mirror of
https://github.com/redis/go-redis.git
synced 2025-12-02 06:22:31 +03:00
Merge branch 'master' into implement-tls-url-parameters-pr2076
This commit is contained in:
2
.github/release-drafter-config.yml
vendored
2
.github/release-drafter-config.yml
vendored
@@ -36,6 +36,8 @@ categories:
|
||||
change-template: '- $TITLE (#$NUMBER)'
|
||||
exclude-labels:
|
||||
- 'skip-changelog'
|
||||
exclude-contributors:
|
||||
- 'dependabot'
|
||||
template: |
|
||||
# Changes
|
||||
|
||||
|
||||
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Set up ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
|
||||
|
||||
6
.github/workflows/codeql-analysis.yml
vendored
6
.github/workflows/codeql-analysis.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
uses: github/codeql-action/init@v4
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
||||
# If this step fails, then you should remove it and run the build manually (see below)
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v3
|
||||
uses: github/codeql-action/autobuild@v4
|
||||
|
||||
# ℹ️ Command-line programs to run using the OS shell.
|
||||
# 📚 https://git.io/JvXDl
|
||||
@@ -64,4 +64,4 @@ jobs:
|
||||
# make release
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
uses: github/codeql-action/analyze@v4
|
||||
|
||||
2
.github/workflows/doctests.yaml
vendored
2
.github/workflows/doctests.yaml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Set up ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
|
||||
|
||||
2
.github/workflows/spellcheck.yml
vendored
2
.github/workflows/spellcheck.yml
vendored
@@ -8,7 +8,7 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
- name: Check Spelling
|
||||
uses: rojopolis/spellcheck-github-actions@0.51.0
|
||||
uses: rojopolis/spellcheck-github-actions@0.52.0
|
||||
with:
|
||||
config_path: .github/spellcheck-settings.yml
|
||||
task_name: Markdown
|
||||
|
||||
2
.github/workflows/test-redis-enterprise.yml
vendored
2
.github/workflows/test-redis-enterprise.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
path: redis-ee
|
||||
|
||||
- name: Set up ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -9,3 +9,7 @@ coverage.txt
|
||||
**/coverage.txt
|
||||
.vscode
|
||||
tmp/*
|
||||
*.test
|
||||
|
||||
# maintenanceNotifications upgrade documentation (temporary)
|
||||
maintenanceNotifications/docs/
|
||||
|
||||
119
RELEASE-NOTES.md
119
RELEASE-NOTES.md
@@ -1,5 +1,124 @@
|
||||
# Release Notes
|
||||
|
||||
# 9.15.0-beta.3 (2025-09-26)
|
||||
|
||||
## Highlights
|
||||
This beta release includes a pre-production version of processing push notifications and hitless upgrades.
|
||||
|
||||
# Changes
|
||||
|
||||
- chore: Update hash_commands.go ([#3523](https://github.com/redis/go-redis/pull/3523))
|
||||
|
||||
## 🚀 New Features
|
||||
|
||||
- feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418))
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- fix: pipeline repeatedly sets the error ([#3525](https://github.com/redis/go-redis/pull/3525))
|
||||
|
||||
## 🧰 Maintenance
|
||||
|
||||
- chore(deps): bump rojopolis/spellcheck-github-actions from 0.51.0 to 0.52.0 ([#3520](https://github.com/redis/go-redis/pull/3520))
|
||||
- feat(e2e-testing): maintnotifications e2e and refactor ([#3526](https://github.com/redis/go-redis/pull/3526))
|
||||
- feat(tag.sh): Improved resiliency of the release process ([#3530](https://github.com/redis/go-redis/pull/3530))
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@cxljs](https://github.com/cxljs), [@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), and [@omid-h70](https://github.com/omid-h70)
|
||||
|
||||
|
||||
# 9.15.0-beta.1 (2025-09-10)
|
||||
|
||||
## Highlights
|
||||
This beta release includes a pre-production version of processing push notifications and hitless upgrades.
|
||||
|
||||
### Hitless Upgrades
|
||||
Hitless upgrades is a major new feature that allows for zero-downtime upgrades in Redis clusters.
|
||||
You can find more information in the [Hitless Upgrades documentation](https://github.com/redis/go-redis/tree/master/hitless).
|
||||
|
||||
# Changes
|
||||
|
||||
## 🚀 New Features
|
||||
- [CAE-1088] & [CAE-1072] feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418))
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), [@ofekshenawa](https://github.com/ofekshenawa)
|
||||
|
||||
|
||||
# 9.14.0 (2025-09-10)
|
||||
|
||||
## Highlights
|
||||
- Added batch process method to the pipeline ([#3510](https://github.com/redis/go-redis/pull/3510))
|
||||
|
||||
# Changes
|
||||
|
||||
## 🚀 New Features
|
||||
|
||||
- Added batch process method to the pipeline ([#3510](https://github.com/redis/go-redis/pull/3510))
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- fix: SetErr on Cmd if the command cannot be queued correctly in multi/exec ([#3509](https://github.com/redis/go-redis/pull/3509))
|
||||
|
||||
## 🧰 Maintenance
|
||||
|
||||
- Updates release drafter config to exclude dependabot ([#3511](https://github.com/redis/go-redis/pull/3511))
|
||||
- chore(deps): bump actions/setup-go from 5 to 6 ([#3504](https://github.com/redis/go-redis/pull/3504))
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@elena-kolevska](https://github.com/elena-kolevksa), [@htemelski-redis](https://github.com/htemelski-redis) and [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
|
||||
# 9.13.0 (2025-09-03)
|
||||
|
||||
## Highlights
|
||||
- Pipeliner expose queued commands ([#3496](https://github.com/redis/go-redis/pull/3496))
|
||||
- Ensure that JSON.GET returns Nil response ([#3470](https://github.com/redis/go-redis/pull/3470))
|
||||
- Fixes on Read and Write buffer sizes and UniversalOptions
|
||||
|
||||
## Changes
|
||||
- Pipeliner expose queued commands ([#3496](https://github.com/redis/go-redis/pull/3496))
|
||||
- fix(test): fix a timing issue in pubsub test ([#3498](https://github.com/redis/go-redis/pull/3498))
|
||||
- Allow users to enable read-write splitting in failover mode. ([#3482](https://github.com/redis/go-redis/pull/3482))
|
||||
- Set the read/write buffer size of the sentinel client to 4KiB ([#3476](https://github.com/redis/go-redis/pull/3476))
|
||||
|
||||
## 🚀 New Features
|
||||
|
||||
- fix(otel): register wait metrics ([#3499](https://github.com/redis/go-redis/pull/3499))
|
||||
- Support subscriptions against cluster slave nodes ([#3480](https://github.com/redis/go-redis/pull/3480))
|
||||
- Add wait metrics to otel ([#3493](https://github.com/redis/go-redis/pull/3493))
|
||||
- Clean failing timeout implementation ([#3472](https://github.com/redis/go-redis/pull/3472))
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- Do not assume that all non-IP hosts are loopbacks ([#3085](https://github.com/redis/go-redis/pull/3085))
|
||||
- Ensure that JSON.GET returns Nil response ([#3470](https://github.com/redis/go-redis/pull/3470))
|
||||
|
||||
## 🧰 Maintenance
|
||||
|
||||
- fix(otel): register wait metrics ([#3499](https://github.com/redis/go-redis/pull/3499))
|
||||
- fix(make test): Add default env in makefile ([#3491](https://github.com/redis/go-redis/pull/3491))
|
||||
- Update the introduction to running tests in README.md ([#3495](https://github.com/redis/go-redis/pull/3495))
|
||||
- test: Add comprehensive edge case tests for IncrByFloat command ([#3477](https://github.com/redis/go-redis/pull/3477))
|
||||
- Set the default read/write buffer size of Redis connection to 32KiB ([#3483](https://github.com/redis/go-redis/pull/3483))
|
||||
- Bumps test image to 8.2.1-pre ([#3478](https://github.com/redis/go-redis/pull/3478))
|
||||
- fix UniversalOptions miss ReadBufferSize and WriteBufferSize options ([#3485](https://github.com/redis/go-redis/pull/3485))
|
||||
- chore(deps): bump actions/checkout from 4 to 5 ([#3484](https://github.com/redis/go-redis/pull/3484))
|
||||
- Removes dry run for stale issues policy ([#3471](https://github.com/redis/go-redis/pull/3471))
|
||||
- Update otel metrics URL ([#3474](https://github.com/redis/go-redis/pull/3474))
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@LINKIWI](https://github.com/LINKIWI), [@cxljs](https://github.com/cxljs), [@cybersmeashish](https://github.com/cybersmeashish), [@elena-kolevska](https://github.com/elena-kolevska), [@htemelski-redis](https://github.com/htemelski-redis), [@mwhooker](https://github.com/mwhooker), [@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@suever](https://github.com/suever)
|
||||
|
||||
|
||||
# 9.12.1 (2025-08-11)
|
||||
## 🚀 Highlights
|
||||
In the last version (9.12.0) the client introduced bigger write and read buffer sized. The default value we set was 512KiB.
|
||||
|
||||
111
adapters.go
Normal file
111
adapters.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/interfaces"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
|
||||
var ErrInvalidCommand = errors.New("invalid command type")
|
||||
|
||||
// ErrInvalidPool is returned when the pool type is not supported.
|
||||
var ErrInvalidPool = errors.New("invalid pool type")
|
||||
|
||||
// newClientAdapter creates a new client adapter for regular Redis clients.
|
||||
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
|
||||
return &clientAdapter{client: client}
|
||||
}
|
||||
|
||||
// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
|
||||
type clientAdapter struct {
|
||||
client *baseClient
|
||||
}
|
||||
|
||||
// GetOptions returns the client options.
|
||||
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
|
||||
return &optionsAdapter{options: ca.client.opt}
|
||||
}
|
||||
|
||||
// GetPushProcessor returns the client's push notification processor.
|
||||
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
|
||||
return &pushProcessorAdapter{processor: ca.client.pushProcessor}
|
||||
}
|
||||
|
||||
// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
|
||||
type optionsAdapter struct {
|
||||
options *Options
|
||||
}
|
||||
|
||||
// GetReadTimeout returns the read timeout.
|
||||
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
|
||||
return oa.options.ReadTimeout
|
||||
}
|
||||
|
||||
// GetWriteTimeout returns the write timeout.
|
||||
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
|
||||
return oa.options.WriteTimeout
|
||||
}
|
||||
|
||||
// GetNetwork returns the network type.
|
||||
func (oa *optionsAdapter) GetNetwork() string {
|
||||
return oa.options.Network
|
||||
}
|
||||
|
||||
// GetAddr returns the connection address.
|
||||
func (oa *optionsAdapter) GetAddr() string {
|
||||
return oa.options.Addr
|
||||
}
|
||||
|
||||
// IsTLSEnabled returns true if TLS is enabled.
|
||||
func (oa *optionsAdapter) IsTLSEnabled() bool {
|
||||
return oa.options.TLSConfig != nil
|
||||
}
|
||||
|
||||
// GetProtocol returns the protocol version.
|
||||
func (oa *optionsAdapter) GetProtocol() int {
|
||||
return oa.options.Protocol
|
||||
}
|
||||
|
||||
// GetPoolSize returns the connection pool size.
|
||||
func (oa *optionsAdapter) GetPoolSize() int {
|
||||
return oa.options.PoolSize
|
||||
}
|
||||
|
||||
// NewDialer returns a new dialer function for the connection.
|
||||
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
|
||||
baseDialer := oa.options.NewDialer()
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
// Extract network and address from the options
|
||||
network := oa.options.Network
|
||||
addr := oa.options.Addr
|
||||
return baseDialer(ctx, network, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
|
||||
type pushProcessorAdapter struct {
|
||||
processor push.NotificationProcessor
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a specific push notification name.
|
||||
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
|
||||
if pushHandler, ok := handler.(push.NotificationHandler); ok {
|
||||
return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
|
||||
}
|
||||
return errors.New("handler must implement push.NotificationHandler")
|
||||
}
|
||||
|
||||
// UnregisterHandler removes a handler for a specific push notification name.
|
||||
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
|
||||
return ppa.processor.UnregisterHandler(pushNotificationName)
|
||||
}
|
||||
|
||||
// GetHandler returns the handler for a specific push notification name.
|
||||
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
|
||||
return ppa.processor.GetHandler(pushNotificationName)
|
||||
}
|
||||
353
async_handoff_integration_test.go
Normal file
353
async_handoff_integration_test.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
)
|
||||
|
||||
// mockNetConn implements net.Conn for testing
|
||||
type mockNetConn struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
|
||||
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
|
||||
func (m *mockNetConn) Close() error { return nil }
|
||||
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
|
||||
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
|
||||
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
type mockAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (m *mockAddr) Network() string { return "tcp" }
|
||||
func (m *mockAddr) String() string { return m.addr }
|
||||
|
||||
// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow
|
||||
func TestEventDrivenHandoffIntegration(t *testing.T) {
|
||||
t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) {
|
||||
// Create a base dialer for testing
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
// Create processor with event-driven handoff support
|
||||
processor := maintnotifications.NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create a test pool with hooks
|
||||
hookManager := pool.NewPoolHookManager()
|
||||
hookManager.AddHook(processor)
|
||||
|
||||
testPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &mockNetConn{addr: "original:6379"}, nil
|
||||
},
|
||||
PoolSize: int32(5),
|
||||
PoolTimeout: time.Second,
|
||||
})
|
||||
|
||||
// Add the hook to the pool after creation
|
||||
testPool.AddPoolHook(processor)
|
||||
defer testPool.Close()
|
||||
|
||||
// Set the pool reference in the processor for connection removal on handoff failure
|
||||
processor.SetPool(testPool)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get a connection and mark it for handoff
|
||||
conn, err := testPool.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get connection: %v", err)
|
||||
}
|
||||
|
||||
// Set initialization function with a small delay to ensure handoff is pending
|
||||
initConnCalled := false
|
||||
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
|
||||
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
|
||||
initConnCalled = true
|
||||
return nil
|
||||
}
|
||||
conn.SetInitConnFunc(initConnFunc)
|
||||
|
||||
// Mark connection for handoff
|
||||
err = conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Return connection to pool - this should queue handoff
|
||||
testPool.Put(ctx, conn)
|
||||
|
||||
// Give the on-demand worker a moment to start processing
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Verify handoff was queued
|
||||
if !processor.IsHandoffPending(conn) {
|
||||
t.Error("Handoff should be queued in pending map")
|
||||
}
|
||||
|
||||
// Try to get the same connection - should be skipped due to pending handoff
|
||||
conn2, err := testPool.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get second connection: %v", err)
|
||||
}
|
||||
|
||||
// Should get a different connection (the pending one should be skipped)
|
||||
if conn == conn2 {
|
||||
t.Error("Should have gotten a different connection while handoff is pending")
|
||||
}
|
||||
|
||||
// Return the second connection
|
||||
testPool.Put(ctx, conn2)
|
||||
|
||||
// Wait for handoff to complete
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify handoff completed (removed from pending map)
|
||||
if processor.IsHandoffPending(conn) {
|
||||
t.Error("Handoff should have completed and been removed from pending map")
|
||||
}
|
||||
|
||||
if !initConnCalled {
|
||||
t.Error("InitConn should have been called during handoff")
|
||||
}
|
||||
|
||||
// Now the original connection should be available again
|
||||
conn3, err := testPool.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get third connection: %v", err)
|
||||
}
|
||||
|
||||
// Could be the original connection (now handed off) or a new one
|
||||
testPool.Put(ctx, conn3)
|
||||
})
|
||||
|
||||
t.Run("ConcurrentHandoffs", func(t *testing.T) {
|
||||
// Create a base dialer that simulates slow handoffs
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
time.Sleep(50 * time.Millisecond) // Simulate network delay
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := maintnotifications.NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create hooks manager and add processor as hook
|
||||
hookManager := pool.NewPoolHookManager()
|
||||
hookManager.AddHook(processor)
|
||||
|
||||
testPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &mockNetConn{addr: "original:6379"}, nil
|
||||
},
|
||||
|
||||
PoolSize: int32(10),
|
||||
PoolTimeout: time.Second,
|
||||
})
|
||||
defer testPool.Close()
|
||||
|
||||
// Add the hook to the pool after creation
|
||||
testPool.AddPoolHook(processor)
|
||||
|
||||
// Set the pool reference in the processor
|
||||
processor.SetPool(testPool)
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start multiple concurrent handoffs
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Get connection
|
||||
conn, err := testPool.Get(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get conn[%d]: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Set initialization function
|
||||
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
}
|
||||
conn.SetInitConnFunc(initConnFunc)
|
||||
|
||||
// Mark for handoff
|
||||
conn.MarkForHandoff("new-endpoint:6379", int64(id))
|
||||
|
||||
// Return to pool (starts async handoff)
|
||||
testPool.Put(ctx, conn)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Wait for all handoffs to complete
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Verify pool is still functional
|
||||
conn, err := testPool.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err)
|
||||
}
|
||||
testPool.Put(ctx, conn)
|
||||
})
|
||||
|
||||
t.Run("HandoffFailureRecovery", func(t *testing.T) {
|
||||
// Create a failing base dialer
|
||||
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}}
|
||||
}
|
||||
|
||||
processor := maintnotifications.NewPoolHook(failingDialer, "tcp", nil, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create hooks manager and add processor as hook
|
||||
hookManager := pool.NewPoolHookManager()
|
||||
hookManager.AddHook(processor)
|
||||
|
||||
testPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &mockNetConn{addr: "original:6379"}, nil
|
||||
},
|
||||
|
||||
PoolSize: int32(3),
|
||||
PoolTimeout: time.Second,
|
||||
})
|
||||
defer testPool.Close()
|
||||
|
||||
// Add the hook to the pool after creation
|
||||
testPool.AddPoolHook(processor)
|
||||
|
||||
// Set the pool reference in the processor
|
||||
processor.SetPool(testPool)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get connection and mark for handoff
|
||||
conn, err := testPool.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get connection: %v", err)
|
||||
}
|
||||
|
||||
conn.MarkForHandoff("unreachable-endpoint:6379", 12345)
|
||||
|
||||
// Return to pool (starts async handoff that will fail)
|
||||
testPool.Put(ctx, conn)
|
||||
|
||||
// Wait for handoff to fail
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Connection should be removed from pending map after failed handoff
|
||||
if processor.IsHandoffPending(conn) {
|
||||
t.Error("Connection should be removed from pending map after failed handoff")
|
||||
}
|
||||
|
||||
// Pool should still be functional
|
||||
conn2, err := testPool.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Pool should still be functional: %v", err)
|
||||
}
|
||||
|
||||
// In event-driven approach, the original connection remains in pool
|
||||
// even after failed handoff (it's still a valid connection)
|
||||
// We might get the same connection or a different one
|
||||
testPool.Put(ctx, conn2)
|
||||
})
|
||||
|
||||
t.Run("GracefulShutdown", func(t *testing.T) {
|
||||
// Create a slow base dialer
|
||||
slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := maintnotifications.NewPoolHook(slowDialer, "tcp", nil, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create hooks manager and add processor as hook
|
||||
hookManager := pool.NewPoolHookManager()
|
||||
hookManager.AddHook(processor)
|
||||
|
||||
testPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &mockNetConn{addr: "original:6379"}, nil
|
||||
},
|
||||
|
||||
PoolSize: int32(2),
|
||||
PoolTimeout: time.Second,
|
||||
})
|
||||
defer testPool.Close()
|
||||
|
||||
// Add the hook to the pool after creation
|
||||
testPool.AddPoolHook(processor)
|
||||
|
||||
// Set the pool reference in the processor
|
||||
processor.SetPool(testPool)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Start a handoff
|
||||
conn, err := testPool.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get connection: %v", err)
|
||||
}
|
||||
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a mock initialization function with delay to ensure handoff is pending
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
|
||||
return nil
|
||||
})
|
||||
|
||||
testPool.Put(ctx, conn)
|
||||
|
||||
// Give the on-demand worker a moment to start and begin processing
|
||||
// The handoff should be pending because the slowDialer takes 100ms
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Verify handoff was queued and is being processed
|
||||
if !processor.IsHandoffPending(conn) {
|
||||
t.Error("Handoff should be queued in pending map")
|
||||
}
|
||||
|
||||
// Give the handoff a moment to start processing
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Shutdown processor gracefully
|
||||
// Use a longer timeout to account for slow dialer (100ms) plus processing overhead
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = processor.Shutdown(shutdownCtx)
|
||||
if err != nil {
|
||||
t.Errorf("Graceful shutdown should succeed: %v", err)
|
||||
}
|
||||
|
||||
// Handoff should have completed (removed from pending map)
|
||||
if processor.IsHandoffPending(conn) {
|
||||
t.Error("Handoff should have completed and been removed from pending map after shutdown")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func init() {
|
||||
logging.Disable()
|
||||
}
|
||||
@@ -1,316 +0,0 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
)
|
||||
|
||||
var ctx = context.TODO()
|
||||
|
||||
type ClientStub struct {
|
||||
Cmdable
|
||||
resp []byte
|
||||
}
|
||||
|
||||
var initHello = []byte("%1\r\n+proto\r\n:3\r\n")
|
||||
|
||||
func NewClientStub(resp []byte) *ClientStub {
|
||||
stub := &ClientStub{
|
||||
resp: resp,
|
||||
}
|
||||
|
||||
stub.Cmdable = NewClient(&Options{
|
||||
PoolSize: 128,
|
||||
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return stub.stubConn(initHello), nil
|
||||
},
|
||||
DisableIdentity: true,
|
||||
})
|
||||
return stub
|
||||
}
|
||||
|
||||
func NewClusterClientStub(resp []byte) *ClientStub {
|
||||
stub := &ClientStub{
|
||||
resp: resp,
|
||||
}
|
||||
|
||||
client := NewClusterClient(&ClusterOptions{
|
||||
PoolSize: 128,
|
||||
Addrs: []string{":6379"},
|
||||
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return stub.stubConn(initHello), nil
|
||||
},
|
||||
DisableIdentity: true,
|
||||
|
||||
ClusterSlots: func(_ context.Context) ([]ClusterSlot, error) {
|
||||
return []ClusterSlot{
|
||||
{
|
||||
Start: 0,
|
||||
End: 16383,
|
||||
Nodes: []ClusterNode{{Addr: "127.0.0.1:6379"}},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
stub.Cmdable = client
|
||||
return stub
|
||||
}
|
||||
|
||||
func (c *ClientStub) stubConn(init []byte) *ConnStub {
|
||||
return &ConnStub{
|
||||
init: init,
|
||||
resp: c.resp,
|
||||
}
|
||||
}
|
||||
|
||||
type ConnStub struct {
|
||||
init []byte
|
||||
resp []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (c *ConnStub) Read(b []byte) (n int, err error) {
|
||||
// Return conn.init()
|
||||
if len(c.init) > 0 {
|
||||
n = copy(b, c.init)
|
||||
c.init = c.init[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if len(c.resp) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if c.pos >= len(c.resp) {
|
||||
c.pos = 0
|
||||
}
|
||||
n = copy(b, c.resp[c.pos:])
|
||||
c.pos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *ConnStub) Write(b []byte) (n int, err error) { return len(b), nil }
|
||||
func (c *ConnStub) Close() error { return nil }
|
||||
func (c *ConnStub) LocalAddr() net.Addr { return nil }
|
||||
func (c *ConnStub) RemoteAddr() net.Addr { return nil }
|
||||
func (c *ConnStub) SetDeadline(_ time.Time) error { return nil }
|
||||
func (c *ConnStub) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (c *ConnStub) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
|
||||
type ClientStubFunc func([]byte) *ClientStub
|
||||
|
||||
func BenchmarkDecode(b *testing.B) {
|
||||
type Benchmark struct {
|
||||
name string
|
||||
stub ClientStubFunc
|
||||
}
|
||||
|
||||
benchmarks := []Benchmark{
|
||||
{"server", NewClientStub},
|
||||
{"cluster", NewClusterClientStub},
|
||||
}
|
||||
|
||||
for _, bench := range benchmarks {
|
||||
b.Run(fmt.Sprintf("RespError-%s", bench.name), func(b *testing.B) {
|
||||
respError(b, bench.stub)
|
||||
})
|
||||
b.Run(fmt.Sprintf("RespStatus-%s", bench.name), func(b *testing.B) {
|
||||
respStatus(b, bench.stub)
|
||||
})
|
||||
b.Run(fmt.Sprintf("RespInt-%s", bench.name), func(b *testing.B) {
|
||||
respInt(b, bench.stub)
|
||||
})
|
||||
b.Run(fmt.Sprintf("RespString-%s", bench.name), func(b *testing.B) {
|
||||
respString(b, bench.stub)
|
||||
})
|
||||
b.Run(fmt.Sprintf("RespArray-%s", bench.name), func(b *testing.B) {
|
||||
respArray(b, bench.stub)
|
||||
})
|
||||
b.Run(fmt.Sprintf("RespPipeline-%s", bench.name), func(b *testing.B) {
|
||||
respPipeline(b, bench.stub)
|
||||
})
|
||||
b.Run(fmt.Sprintf("RespTxPipeline-%s", bench.name), func(b *testing.B) {
|
||||
respTxPipeline(b, bench.stub)
|
||||
})
|
||||
|
||||
// goroutine
|
||||
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=5", bench.name), func(b *testing.B) {
|
||||
dynamicGoroutine(b, bench.stub, 5)
|
||||
})
|
||||
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=20", bench.name), func(b *testing.B) {
|
||||
dynamicGoroutine(b, bench.stub, 20)
|
||||
})
|
||||
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=50", bench.name), func(b *testing.B) {
|
||||
dynamicGoroutine(b, bench.stub, 50)
|
||||
})
|
||||
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=100", bench.name), func(b *testing.B) {
|
||||
dynamicGoroutine(b, bench.stub, 100)
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=5", bench.name), func(b *testing.B) {
|
||||
staticGoroutine(b, bench.stub, 5)
|
||||
})
|
||||
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=20", bench.name), func(b *testing.B) {
|
||||
staticGoroutine(b, bench.stub, 20)
|
||||
})
|
||||
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=50", bench.name), func(b *testing.B) {
|
||||
staticGoroutine(b, bench.stub, 50)
|
||||
})
|
||||
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=100", bench.name), func(b *testing.B) {
|
||||
staticGoroutine(b, bench.stub, 100)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func respError(b *testing.B, stub ClientStubFunc) {
|
||||
rdb := stub([]byte("-ERR test error\r\n"))
|
||||
respErr := proto.RedisError("ERR test error")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := rdb.Get(ctx, "key").Err(); err != respErr {
|
||||
b.Fatalf("response error, got %q, want %q", err, respErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func respStatus(b *testing.B, stub ClientStubFunc) {
|
||||
rdb := stub([]byte("+OK\r\n"))
|
||||
var val string
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if val = rdb.Set(ctx, "key", "value", 0).Val(); val != "OK" {
|
||||
b.Fatalf("response error, got %q, want OK", val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func respInt(b *testing.B, stub ClientStubFunc) {
|
||||
rdb := stub([]byte(":10\r\n"))
|
||||
var val int64
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if val = rdb.Incr(ctx, "key").Val(); val != 10 {
|
||||
b.Fatalf("response error, got %q, want 10", val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func respString(b *testing.B, stub ClientStubFunc) {
|
||||
rdb := stub([]byte("$5\r\nhello\r\n"))
|
||||
var val string
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if val = rdb.Get(ctx, "key").Val(); val != "hello" {
|
||||
b.Fatalf("response error, got %q, want hello", val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func respArray(b *testing.B, stub ClientStubFunc) {
|
||||
rdb := stub([]byte("*3\r\n$5\r\nhello\r\n:10\r\n+OK\r\n"))
|
||||
var val []interface{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if val = rdb.MGet(ctx, "key").Val(); len(val) != 3 {
|
||||
b.Fatalf("response error, got len(%d), want len(3)", len(val))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func respPipeline(b *testing.B, stub ClientStubFunc) {
|
||||
rdb := stub([]byte("+OK\r\n$5\r\nhello\r\n:1\r\n"))
|
||||
var pipe Pipeliner
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pipe = rdb.Pipeline()
|
||||
set := pipe.Set(ctx, "key", "value", 0)
|
||||
get := pipe.Get(ctx, "key")
|
||||
del := pipe.Del(ctx, "key")
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
b.Fatalf("response error, got %q, want nil", err)
|
||||
}
|
||||
if set.Val() != "OK" || get.Val() != "hello" || del.Val() != 1 {
|
||||
b.Fatal("response error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func respTxPipeline(b *testing.B, stub ClientStubFunc) {
|
||||
rdb := stub([]byte("+OK\r\n+QUEUED\r\n+QUEUED\r\n+QUEUED\r\n*3\r\n+OK\r\n$5\r\nhello\r\n:1\r\n"))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var set *StatusCmd
|
||||
var get *StringCmd
|
||||
var del *IntCmd
|
||||
_, err := rdb.TxPipelined(ctx, func(pipe Pipeliner) error {
|
||||
set = pipe.Set(ctx, "key", "value", 0)
|
||||
get = pipe.Get(ctx, "key")
|
||||
del = pipe.Del(ctx, "key")
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("response error, got %q, want nil", err)
|
||||
}
|
||||
if set.Val() != "OK" || get.Val() != "hello" || del.Val() != 1 {
|
||||
b.Fatal("response error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dynamicGoroutine(b *testing.B, stub ClientStubFunc, concurrency int) {
|
||||
rdb := stub([]byte("$5\r\nhello\r\n"))
|
||||
c := make(chan struct{}, concurrency)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
c <- struct{}{}
|
||||
go func() {
|
||||
if val := rdb.Get(ctx, "key").Val(); val != "hello" {
|
||||
panic(fmt.Sprintf("response error, got %q, want hello", val))
|
||||
}
|
||||
<-c
|
||||
}()
|
||||
}
|
||||
// Here no longer wait for all goroutines to complete, it will not affect the test results.
|
||||
close(c)
|
||||
}
|
||||
|
||||
func staticGoroutine(b *testing.B, stub ClientStubFunc, concurrency int) {
|
||||
rdb := stub([]byte("$5\r\nhello\r\n"))
|
||||
c := make(chan struct{}, concurrency)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
for {
|
||||
_, ok := <-c
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if val := rdb.Get(ctx, "key").Val(); val != "hello" {
|
||||
panic(fmt.Sprintf("response error, got %q, want hello", val))
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
c <- struct{}{}
|
||||
}
|
||||
close(c)
|
||||
}
|
||||
18
commands.go
18
commands.go
@@ -193,6 +193,7 @@ type Cmdable interface {
|
||||
ClientID(ctx context.Context) *IntCmd
|
||||
ClientUnblock(ctx context.Context, id int64) *IntCmd
|
||||
ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
|
||||
ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
|
||||
ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
|
||||
ConfigResetStat(ctx context.Context) *StatusCmd
|
||||
ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
|
||||
@@ -519,6 +520,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ClientMaintNotifications enables or disables maintenance notifications for maintenance upgrades.
|
||||
// When enabled, the client will receive push notifications about Redis maintenance events.
|
||||
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
|
||||
args := []interface{}{"client", "maint_notifications"}
|
||||
if enabled {
|
||||
if endpointType == "" {
|
||||
endpointType = "none"
|
||||
}
|
||||
args = append(args, "on", "moving-endpoint-type", endpointType)
|
||||
} else {
|
||||
args = append(args, "off")
|
||||
}
|
||||
cmd := NewStatusCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
|
||||
func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {
|
||||
|
||||
@@ -3019,7 +3019,8 @@ var _ = Describe("Commands", func() {
|
||||
|
||||
res, err = client.HPTTL(ctx, "myhash", "key1", "key2", "key200").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res[0]).To(BeNumerically("~", 10*time.Second.Milliseconds(), 1))
|
||||
// overhead of the push notification check is about 1-2ms for 100 commands
|
||||
Expect(res[0]).To(BeNumerically("~", 10*time.Second.Milliseconds(), 2))
|
||||
})
|
||||
|
||||
It("should HGETDEL", Label("hash", "HGETDEL"), func() {
|
||||
|
||||
@@ -5,7 +5,7 @@ go 1.18
|
||||
replace github.com/redis/go-redis/v9 => ../..
|
||||
|
||||
require (
|
||||
github.com/redis/go-redis/v9 v9.12.1
|
||||
github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
go.uber.org/zap v1.24.0
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ go 1.18
|
||||
|
||||
replace github.com/redis/go-redis/v9 => ../..
|
||||
|
||||
require github.com/redis/go-redis/v9 v9.12.1
|
||||
require github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
|
||||
@@ -6,7 +6,7 @@ replace github.com/redis/go-redis/v9 => ../..
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1
|
||||
github.com/redis/go-redis/v9 v9.12.1
|
||||
github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
@@ -4,7 +4,7 @@ go 1.18
|
||||
|
||||
replace github.com/redis/go-redis/v9 => ../..
|
||||
|
||||
require github.com/redis/go-redis/v9 v9.12.1
|
||||
require github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
|
||||
@@ -11,8 +11,8 @@ replace github.com/redis/go-redis/extra/redisotel/v9 => ../../extra/redisotel
|
||||
replace github.com/redis/go-redis/extra/rediscmd/v9 => ../../extra/rediscmd
|
||||
|
||||
require (
|
||||
github.com/redis/go-redis/extra/redisotel/v9 v9.12.1
|
||||
github.com/redis/go-redis/v9 v9.12.1
|
||||
github.com/redis/go-redis/extra/redisotel/v9 v9.16.0-beta.1
|
||||
github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
github.com/uptrace/uptrace-go v1.21.0
|
||||
go.opentelemetry.io/otel v1.22.0
|
||||
)
|
||||
@@ -25,7 +25,7 @@ require (
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 // indirect
|
||||
github.com/redis/go-redis/extra/rediscmd/v9 v9.12.1 // indirect
|
||||
github.com/redis/go-redis/extra/rediscmd/v9 v9.16.0-beta.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/runtime v0.46.1 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v0.44.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0 // indirect
|
||||
|
||||
12
example/pubsub/go.mod
Normal file
12
example/pubsub/go.mod
Normal file
@@ -0,0 +1,12 @@
|
||||
module github.com/redis/go-redis/example/pubsub
|
||||
|
||||
go 1.18
|
||||
|
||||
replace github.com/redis/go-redis/v9 => ../..
|
||||
|
||||
require github.com/redis/go-redis/v9 v9.11.0
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
)
|
||||
6
example/pubsub/go.sum
Normal file
6
example/pubsub/go.sum
Normal file
@@ -0,0 +1,6 @@
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
175
example/pubsub/main.go
Normal file
175
example/pubsub/main.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
var cntErrors atomic.Int64
|
||||
var cntSuccess atomic.Int64
|
||||
var startTime = time.Now()
|
||||
|
||||
// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management.
|
||||
// It was used to find regressions in pool management in maintnotifications mode.
|
||||
// Please don't use it as a reference for how to use pubsub.
|
||||
func main() {
|
||||
startTime = time.Now()
|
||||
wg := &sync.WaitGroup{}
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: ":6379",
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeEnabled,
|
||||
EndpointType: maintnotifications.EndpointTypeExternalIP,
|
||||
HandoffTimeout: 10 * time.Second,
|
||||
RelaxedTimeout: 10 * time.Second,
|
||||
PostHandoffRelaxedDuration: 10 * time.Second,
|
||||
},
|
||||
})
|
||||
_ = rdb.FlushDB(ctx).Err()
|
||||
maintnotificationsManager := rdb.GetMaintNotificationsManager()
|
||||
if maintnotificationsManager == nil {
|
||||
panic("maintnotifications manager is nil")
|
||||
}
|
||||
loggingHook := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
|
||||
maintnotificationsManager.AddNotificationHook(loggingHook)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(2 * time.Second)
|
||||
fmt.Printf("pool stats: %+v\n", rdb.PoolStats())
|
||||
}
|
||||
}()
|
||||
err := rdb.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println("published", rdb.Get(ctx, "published").Val())
|
||||
fmt.Println("received", rdb.Get(ctx, "received").Val())
|
||||
subCtx, cancelSubCtx := context.WithCancel(ctx)
|
||||
pubCtx, cancelPublishers := context.WithCancel(ctx)
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go subscribe(subCtx, rdb, "test", i, wg)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
cancelSubCtx()
|
||||
time.Sleep(time.Second)
|
||||
subCtx, cancelSubCtx = context.WithCancel(ctx)
|
||||
for i := 0; i < 10; i++ {
|
||||
if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
|
||||
fmt.Println("incr error:", err)
|
||||
cntErrors.Add(1)
|
||||
}
|
||||
wg.Add(1)
|
||||
go floodThePool(pubCtx, rdb, wg)
|
||||
}
|
||||
|
||||
for i := 0; i < 500; i++ {
|
||||
if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
|
||||
fmt.Println("incr error:", err)
|
||||
cntErrors.Add(1)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go subscribe(subCtx, rdb, "test2", i, wg)
|
||||
}
|
||||
time.Sleep(120 * time.Second)
|
||||
fmt.Println("canceling publishers")
|
||||
cancelPublishers()
|
||||
time.Sleep(10 * time.Second)
|
||||
fmt.Println("canceling subscribers")
|
||||
cancelSubCtx()
|
||||
wg.Wait()
|
||||
published, err := rdb.Get(ctx, "published").Result()
|
||||
received, err := rdb.Get(ctx, "received").Result()
|
||||
publishers, err := rdb.Get(ctx, "publishers").Result()
|
||||
subscribers, err := rdb.Get(ctx, "subscribers").Result()
|
||||
fmt.Printf("publishers: %s\n", publishers)
|
||||
fmt.Printf("published: %s\n", published)
|
||||
fmt.Printf("subscribers: %s\n", subscribers)
|
||||
fmt.Printf("received: %s\n", received)
|
||||
publishedInt, err := rdb.Get(ctx, "published").Int()
|
||||
subscribersInt, err := rdb.Get(ctx, "subscribers").Int()
|
||||
fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
fmt.Println("errors:", cntErrors.Load())
|
||||
fmt.Println("success:", cntSuccess.Load())
|
||||
fmt.Println("time:", time.Since(startTime))
|
||||
}
|
||||
|
||||
func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
err := rdb.Publish(ctx, "test2", "hello").Err()
|
||||
if err != nil {
|
||||
if err.Error() != "context canceled" {
|
||||
log.Println("publish error:", err)
|
||||
cntErrors.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
err = rdb.Incr(ctx, "published").Err()
|
||||
if err != nil {
|
||||
if err.Error() != "context canceled" {
|
||||
log.Println("incr error:", err)
|
||||
cntErrors.Add(1)
|
||||
}
|
||||
}
|
||||
time.Sleep(10 * time.Nanosecond)
|
||||
}
|
||||
}
|
||||
|
||||
func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
rec := rdb.Subscribe(ctx, topic)
|
||||
recChan := rec.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
rec.Close()
|
||||
return
|
||||
default:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
rec.Close()
|
||||
return
|
||||
case msg := <-recChan:
|
||||
err := rdb.Incr(ctx, "received").Err()
|
||||
if err != nil {
|
||||
if err.Error() != "context canceled" {
|
||||
log.Printf("%s\n", err.Error())
|
||||
cntErrors.Add(1)
|
||||
}
|
||||
}
|
||||
_ = msg // Use the message to avoid unused variable warning
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,7 @@ go 1.18
|
||||
|
||||
replace github.com/redis/go-redis/v9 => ../..
|
||||
|
||||
require github.com/redis/go-redis/v9 v9.12.1
|
||||
require github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
|
||||
@@ -6,7 +6,7 @@ replace github.com/redis/go-redis/v9 => ../..
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1
|
||||
github.com/redis/go-redis/v9 v9.12.1
|
||||
github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
@@ -57,6 +57,8 @@ func Example_instrumentation() {
|
||||
// finished dialing tcp :6379
|
||||
// starting processing: <[hello 3]>
|
||||
// finished processing: <[hello 3]>
|
||||
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
|
||||
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
|
||||
// finished processing: <[ping]>
|
||||
}
|
||||
|
||||
@@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() {
|
||||
// finished dialing tcp :6379
|
||||
// starting processing: <[hello 3]>
|
||||
// finished processing: <[hello 3]>
|
||||
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
|
||||
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
|
||||
// pipeline finished processing: [[ping] [ping]]
|
||||
}
|
||||
|
||||
@@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() {
|
||||
// finished dialing tcp :6379
|
||||
// starting processing: <[hello 3]>
|
||||
// finished processing: <[hello 3]>
|
||||
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
|
||||
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
|
||||
// finished processing: <[watch foo]>
|
||||
// starting processing: <[ping]>
|
||||
// finished processing: <[ping]>
|
||||
|
||||
@@ -7,8 +7,8 @@ replace github.com/redis/go-redis/v9 => ../..
|
||||
replace github.com/redis/go-redis/extra/rediscmd/v9 => ../rediscmd
|
||||
|
||||
require (
|
||||
github.com/redis/go-redis/extra/rediscmd/v9 v9.12.1
|
||||
github.com/redis/go-redis/v9 v9.12.1
|
||||
github.com/redis/go-redis/extra/rediscmd/v9 v9.16.0-beta.1
|
||||
github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
go.opencensus.io v0.24.0
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ require (
|
||||
)
|
||||
|
||||
retract (
|
||||
v9.7.2 // This version was accidentally released.
|
||||
v9.5.3 // This version was accidentally released.
|
||||
v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead.
|
||||
v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead.
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ replace github.com/redis/go-redis/v9 => ../..
|
||||
require (
|
||||
github.com/bsm/ginkgo/v2 v2.12.0
|
||||
github.com/bsm/gomega v1.27.10
|
||||
github.com/redis/go-redis/v9 v9.12.1
|
||||
github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -16,6 +16,7 @@ require (
|
||||
)
|
||||
|
||||
retract (
|
||||
v9.7.2 // This version was accidentally released.
|
||||
v9.5.3 // This version was accidentally released.
|
||||
v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead.
|
||||
v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead.
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package redisotel
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
@@ -21,6 +24,7 @@ type config struct {
|
||||
|
||||
dbStmtEnabled bool
|
||||
callerEnabled bool
|
||||
filter func(cmd redis.Cmder) bool
|
||||
|
||||
// Metrics options.
|
||||
|
||||
@@ -124,6 +128,37 @@ func WithCallerEnabled(on bool) TracingOption {
|
||||
})
|
||||
}
|
||||
|
||||
// WithCommandFilter allows filtering of commands when tracing to omit commands that may have sensitive details like
|
||||
// passwords.
|
||||
func WithCommandFilter(filter func(cmd redis.Cmder) bool) TracingOption {
|
||||
return tracingOption(func(conf *config) {
|
||||
conf.filter = filter
|
||||
})
|
||||
}
|
||||
|
||||
func BasicCommandFilter(cmd redis.Cmder) bool {
|
||||
if strings.ToLower(cmd.Name()) == "auth" {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.ToLower(cmd.Name()) == "hello" {
|
||||
if len(cmd.Args()) < 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
arg, exists := cmd.Args()[2].(string)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.ToLower(arg) == "auth" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type MetricsOption interface {
|
||||
|
||||
@@ -7,8 +7,8 @@ replace github.com/redis/go-redis/v9 => ../..
|
||||
replace github.com/redis/go-redis/extra/rediscmd/v9 => ../rediscmd
|
||||
|
||||
require (
|
||||
github.com/redis/go-redis/extra/rediscmd/v9 v9.12.1
|
||||
github.com/redis/go-redis/v9 v9.12.1
|
||||
github.com/redis/go-redis/extra/rediscmd/v9 v9.16.0-beta.1
|
||||
github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
go.opentelemetry.io/otel v1.22.0
|
||||
go.opentelemetry.io/otel/metric v1.22.0
|
||||
go.opentelemetry.io/otel/sdk v1.22.0
|
||||
@@ -24,6 +24,7 @@ require (
|
||||
)
|
||||
|
||||
retract (
|
||||
v9.7.2 // This version was accidentally released.
|
||||
v9.5.3 // This version was accidentally released.
|
||||
v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead.
|
||||
v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead.
|
||||
)
|
||||
|
||||
|
||||
@@ -220,6 +220,8 @@ func reportPoolStats(rdb *redis.Client, conf *config) (metric.Registration, erro
|
||||
idleMin,
|
||||
connsMax,
|
||||
usage,
|
||||
waits,
|
||||
waitsDuration,
|
||||
timeouts,
|
||||
hits,
|
||||
misses,
|
||||
|
||||
@@ -102,6 +102,12 @@ func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook {
|
||||
func (th *tracingHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
|
||||
return func(ctx context.Context, cmd redis.Cmder) error {
|
||||
|
||||
// Check if the command should be filtered out
|
||||
if th.conf.filter != nil && th.conf.filter(cmd) {
|
||||
// If so, just call the next hook
|
||||
return hook(ctx, cmd)
|
||||
}
|
||||
|
||||
attrs := make([]attribute.KeyValue, 0, 8)
|
||||
if th.conf.callerEnabled {
|
||||
fn, file, line := funcFileLine("github.com/redis/go-redis")
|
||||
|
||||
@@ -95,6 +95,138 @@ func TestWithoutCaller(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCommandFilter(t *testing.T) {
|
||||
|
||||
t.Run("filter out ping command", func(t *testing.T) {
|
||||
provider := sdktrace.NewTracerProvider()
|
||||
hook := newTracingHook(
|
||||
"",
|
||||
WithTracerProvider(provider),
|
||||
WithCommandFilter(func(cmd redis.Cmder) bool {
|
||||
return cmd.Name() == "ping"
|
||||
}),
|
||||
)
|
||||
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
|
||||
cmd := redis.NewCmd(ctx, "ping")
|
||||
defer span.End()
|
||||
|
||||
processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error {
|
||||
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
|
||||
if innerSpan.Name() != "redis-test" || innerSpan.Name() == "ping" {
|
||||
t.Fatalf("ping command should not be traced")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
err := processHook(ctx, cmd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("do not filter ping command", func(t *testing.T) {
|
||||
provider := sdktrace.NewTracerProvider()
|
||||
hook := newTracingHook(
|
||||
"",
|
||||
WithTracerProvider(provider),
|
||||
WithCommandFilter(func(cmd redis.Cmder) bool {
|
||||
return false // never filter
|
||||
}),
|
||||
)
|
||||
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
|
||||
cmd := redis.NewCmd(ctx, "ping")
|
||||
defer span.End()
|
||||
|
||||
processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error {
|
||||
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
|
||||
if innerSpan.Name() != "ping" {
|
||||
t.Fatalf("ping command should be traced")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
err := processHook(ctx, cmd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auth command filtered with basic command filter", func(t *testing.T) {
|
||||
provider := sdktrace.NewTracerProvider()
|
||||
hook := newTracingHook(
|
||||
"",
|
||||
WithTracerProvider(provider),
|
||||
WithCommandFilter(BasicCommandFilter),
|
||||
)
|
||||
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
|
||||
cmd := redis.NewCmd(ctx, "auth", "test-password")
|
||||
defer span.End()
|
||||
|
||||
processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error {
|
||||
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
|
||||
if innerSpan.Name() != "redis-test" || innerSpan.Name() == "auth" {
|
||||
t.Fatalf("auth command should not be traced by default")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
err := processHook(ctx, cmd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hello command filtered with basic command filter when sensitive", func(t *testing.T) {
|
||||
provider := sdktrace.NewTracerProvider()
|
||||
hook := newTracingHook(
|
||||
"",
|
||||
WithTracerProvider(provider),
|
||||
WithCommandFilter(BasicCommandFilter),
|
||||
)
|
||||
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
|
||||
cmd := redis.NewCmd(ctx, "hello", 3, "AUTH", "test-user", "test-password")
|
||||
defer span.End()
|
||||
|
||||
processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error {
|
||||
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
|
||||
if innerSpan.Name() != "redis-test" || innerSpan.Name() == "hello" {
|
||||
t.Fatalf("auth command should not be traced by default")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
err := processHook(ctx, cmd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hello command not filtered with basic command filter when not sensitive", func(t *testing.T) {
|
||||
provider := sdktrace.NewTracerProvider()
|
||||
hook := newTracingHook(
|
||||
"",
|
||||
WithTracerProvider(provider),
|
||||
WithCommandFilter(BasicCommandFilter),
|
||||
)
|
||||
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
|
||||
cmd := redis.NewCmd(ctx, "hello", 3)
|
||||
defer span.End()
|
||||
|
||||
processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error {
|
||||
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
|
||||
if innerSpan.Name() != "hello" {
|
||||
t.Fatalf("hello command should be traced")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
err := processHook(ctx, cmd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTracingHook_DialHook(t *testing.T) {
|
||||
imsb := tracetest.NewInMemoryExporter()
|
||||
provider := sdktrace.NewTracerProvider(sdktrace.WithSyncer(imsb))
|
||||
|
||||
@@ -6,7 +6,7 @@ replace github.com/redis/go-redis/v9 => ../..
|
||||
|
||||
require (
|
||||
github.com/prometheus/client_golang v1.14.0
|
||||
github.com/redis/go-redis/v9 v9.12.1
|
||||
github.com/redis/go-redis/v9 v9.16.0-beta.1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -23,6 +23,7 @@ require (
|
||||
)
|
||||
|
||||
retract (
|
||||
v9.7.2 // This version was accidentally released.
|
||||
v9.5.3 // This version was accidentally released.
|
||||
v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead.
|
||||
v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead.
|
||||
)
|
||||
|
||||
|
||||
2
go.mod
@@ -10,6 +10,8 @@ require (
|
||||
)
|
||||
|
||||
retract (
|
||||
v9.15.1 // This version is used to retract v9.15.0
|
||||
v9.15.0 // This version was accidentally released. It is identical to 9.15.0-beta.2
|
||||
v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead.
|
||||
v9.5.4 // This version was accidentally released. Please use version 9.6.0 instead.
|
||||
v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead.
|
||||
|
||||
@@ -116,16 +116,16 @@ func (c cmdable) HMGet(ctx context.Context, key string, fields ...string) *Slice
|
||||
|
||||
// HSet accepts values in following formats:
|
||||
//
|
||||
// - HSet("myhash", "key1", "value1", "key2", "value2")
|
||||
// - HSet(ctx, "myhash", "key1", "value1", "key2", "value2")
|
||||
//
|
||||
// - HSet("myhash", []string{"key1", "value1", "key2", "value2"})
|
||||
// - HSet(ctx, "myhash", []string{"key1", "value1", "key2", "value2"})
|
||||
//
|
||||
// - HSet("myhash", map[string]interface{}{"key1": "value1", "key2": "value2"})
|
||||
// - HSet(ctx, "myhash", map[string]interface{}{"key1": "value1", "key2": "value2"})
|
||||
//
|
||||
// Playing struct With "redis" tag.
|
||||
// type MyHash struct { Key1 string `redis:"key1"`; Key2 int `redis:"key2"` }
|
||||
//
|
||||
// - HSet("myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0
|
||||
// - HSet(ctx, "myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0
|
||||
//
|
||||
// For struct, can be a structure pointer type, we only parse the field whose tag is redis.
|
||||
// if you don't want the field to be read, you can use the `redis:"-"` flag to ignore it,
|
||||
|
||||
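To make the accepted forms above concrete, a short usage sketch (illustrative only; it exercises nothing beyond what the doc comment already describes):

package hsetexample

import (
	"context"

	"github.com/redis/go-redis/v9"
)

// MyHash mirrors the struct form described in the HSet doc comment.
type MyHash struct {
	Key1 string `redis:"key1"`
	Key2 int    `redis:"key2"`
	Note string `redis:"-"` // ignored: fields tagged "-" are not written
}

func hsetExamples(ctx context.Context, rdb *redis.Client) error {
	// Variadic field/value pairs.
	if err := rdb.HSet(ctx, "myhash", "key1", "value1", "key2", 2).Err(); err != nil {
		return err
	}
	// Map form.
	if err := rdb.HSet(ctx, "myhash", map[string]interface{}{"key1": "value1", "key2": 2}).Err(); err != nil {
		return err
	}
	// Struct form with "redis" tags (redis-server >= 4.0).
	return rdb.HSet(ctx, "myhash", MyHash{Key1: "value1", Key2: 2}).Err()
}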
245
hset_benchmark_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// HSET Benchmark Tests
|
||||
//
|
||||
// This file contains benchmark tests for Redis HSET operations with different scales:
|
||||
// 1, 10, 100, 1000, 10000, 100000 operations
|
||||
//
|
||||
// Prerequisites:
|
||||
// - Redis server running on localhost:6379
|
||||
// - No authentication required
|
||||
//
|
||||
// Usage:
|
||||
// go test -bench=BenchmarkHSET -v ./hset_benchmark_test.go
|
||||
// go test -bench=BenchmarkHSETPipelined -v ./hset_benchmark_test.go
|
||||
// go test -bench=. -v ./hset_benchmark_test.go # Run all benchmarks
|
||||
//
|
||||
// Example output:
|
||||
// BenchmarkHSET/HSET_1_operations-8 5000 250000 ns/op 1000000.00 ops/sec
|
||||
// BenchmarkHSET/HSET_100_operations-8 100 10000000 ns/op 100000.00 ops/sec
|
||||
//
|
||||
// The benchmarks test three different approaches:
// 1. Individual HSET commands (BenchmarkHSET)
// 2. Pipelined HSET commands (BenchmarkHSETPipelined)
// 3. RESP2 variants of both (BenchmarkHSET_RESP2, BenchmarkHSETPipelined_RESP2)
|
||||
|
||||
// BenchmarkHSET benchmarks HSET operations with different scales
|
||||
func BenchmarkHSET(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup Redis client
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
DB: 0,
|
||||
})
|
||||
defer rdb.Close()
|
||||
|
||||
// Test connection
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
b.Skipf("Redis server not available: %v", err)
|
||||
}
|
||||
|
||||
// Clean up before and after tests
|
||||
defer func() {
|
||||
rdb.FlushDB(ctx)
|
||||
}()
|
||||
|
||||
scales := []int{1, 10, 100, 1000, 10000, 100000}
|
||||
|
||||
for _, scale := range scales {
|
||||
b.Run(fmt.Sprintf("HSET_%d_operations", scale), func(b *testing.B) {
|
||||
benchmarkHSETOperations(b, rdb, ctx, scale)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkHSETOperations performs the actual HSET benchmark for a given scale
|
||||
func benchmarkHSETOperations(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) {
|
||||
hashKey := fmt.Sprintf("benchmark_hash_%d", operations)
|
||||
|
||||
b.ResetTimer()
|
||||
b.StartTimer()
|
||||
totalTimes := []time.Duration{}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
// Clean up the hash before each iteration
|
||||
rdb.Del(ctx, hashKey)
|
||||
b.StartTimer()
|
||||
|
||||
startTime := time.Now()
|
||||
// Perform the specified number of HSET operations
|
||||
for j := 0; j < operations; j++ {
|
||||
field := fmt.Sprintf("field_%d", j)
|
||||
value := fmt.Sprintf("value_%d", j)
|
||||
|
||||
err := rdb.HSet(ctx, hashKey, field, value).Err()
|
||||
if err != nil {
|
||||
b.Fatalf("HSET operation failed: %v", err)
|
||||
}
|
||||
}
|
||||
totalTimes = append(totalTimes, time.Since(startTime))
|
||||
}
|
||||
|
||||
// Stop the timer to calculate metrics
|
||||
b.StopTimer()
|
||||
|
||||
// Report operations per second
|
||||
opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds()
|
||||
b.ReportMetric(opsPerSec, "ops/sec")
|
||||
|
||||
// Report average time per operation
|
||||
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
|
||||
b.ReportMetric(float64(avgTimePerOp), "ns/op")
|
||||
// report the average wall-clock time per iteration in milliseconds, averaged over totalTimes
var totalElapsed time.Duration
for _, d := range totalTimes {
totalElapsed += d
}
avgTimePerOpMs := totalElapsed.Milliseconds() / int64(len(totalTimes))
b.ReportMetric(float64(avgTimePerOpMs), "ms")
|
||||
}
|
||||
|
||||
// BenchmarkHSETPipelined benchmarks HSET operations using pipelining for better performance
|
||||
func BenchmarkHSETPipelined(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup Redis client
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
DB: 0,
|
||||
})
|
||||
defer rdb.Close()
|
||||
|
||||
// Test connection
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
b.Skipf("Redis server not available: %v", err)
|
||||
}
|
||||
|
||||
// Clean up before and after tests
|
||||
defer func() {
|
||||
rdb.FlushDB(ctx)
|
||||
}()
|
||||
|
||||
scales := []int{1, 10, 100, 1000, 10000, 100000}
|
||||
|
||||
for _, scale := range scales {
|
||||
b.Run(fmt.Sprintf("HSET_Pipelined_%d_operations", scale), func(b *testing.B) {
|
||||
benchmarkHSETPipelined(b, rdb, ctx, scale)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkHSETPipelined performs HSET benchmark using pipelining
|
||||
func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) {
|
||||
hashKey := fmt.Sprintf("benchmark_hash_pipelined_%d", operations)
|
||||
|
||||
b.ResetTimer()
|
||||
b.StartTimer()
|
||||
totalTimes := []time.Duration{}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
// Clean up the hash before each iteration
|
||||
rdb.Del(ctx, hashKey)
|
||||
b.StartTimer()
|
||||
|
||||
startTime := time.Now()
|
||||
// Use pipelining for better performance
|
||||
pipe := rdb.Pipeline()
|
||||
|
||||
// Add all HSET operations to the pipeline
|
||||
for j := 0; j < operations; j++ {
|
||||
field := fmt.Sprintf("field_%d", j)
|
||||
value := fmt.Sprintf("value_%d", j)
|
||||
pipe.HSet(ctx, hashKey, field, value)
|
||||
}
|
||||
|
||||
// Execute all operations at once
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
b.Fatalf("Pipeline execution failed: %v", err)
|
||||
}
|
||||
totalTimes = append(totalTimes, time.Since(startTime))
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
|
||||
// Report operations per second
|
||||
opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds()
|
||||
b.ReportMetric(opsPerSec, "ops/sec")
|
||||
|
||||
// Report average time per operation
|
||||
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
|
||||
b.ReportMetric(float64(avgTimePerOp), "ns/op")
|
||||
// report the average wall-clock time per iteration in milliseconds, averaged over totalTimes
var totalElapsed time.Duration
for _, d := range totalTimes {
totalElapsed += d
}
avgTimePerOpMs := totalElapsed.Milliseconds() / int64(len(totalTimes))
b.ReportMetric(float64(avgTimePerOpMs), "ms")
|
||||
}
|
||||
|
||||
// add same tests but with RESP2
|
||||
func BenchmarkHSET_RESP2(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup Redis client
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Password: "", // no password docs
|
||||
DB: 0, // use default DB
|
||||
Protocol: 2,
|
||||
})
|
||||
defer rdb.Close()
|
||||
|
||||
// Test connection
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
b.Skipf("Redis server not available: %v", err)
|
||||
}
|
||||
|
||||
// Clean up before and after tests
|
||||
defer func() {
|
||||
rdb.FlushDB(ctx)
|
||||
}()
|
||||
|
||||
scales := []int{1, 10, 100, 1000, 10000, 100000}
|
||||
|
||||
for _, scale := range scales {
|
||||
b.Run(fmt.Sprintf("HSET_RESP2_%d_operations", scale), func(b *testing.B) {
|
||||
benchmarkHSETOperations(b, rdb, ctx, scale)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHSETPipelined_RESP2(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup Redis client
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Password: "", // no password docs
|
||||
DB: 0, // use default DB
|
||||
Protocol: 2,
|
||||
})
|
||||
defer rdb.Close()
|
||||
|
||||
// Test connection
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
b.Skipf("Redis server not available: %v", err)
|
||||
}
|
||||
|
||||
// Clean up before and after tests
|
||||
defer func() {
|
||||
rdb.FlushDB(ctx)
|
||||
}()
|
||||
|
||||
scales := []int{1, 10, 100, 1000, 10000, 100000}
|
||||
|
||||
for _, scale := range scales {
|
||||
b.Run(fmt.Sprintf("HSET_Pipelined_RESP2_%d_operations", scale), func(b *testing.B) {
|
||||
benchmarkHSETPipelined(b, rdb, ctx, scale)
|
||||
})
|
||||
}
|
||||
}
|
||||
54
internal/interfaces/interfaces.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// Package interfaces provides shared interfaces used by both the main redis package
|
||||
// and the maintnotifications upgrade package to avoid circular dependencies.
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NotificationProcessor mirrors push.NotificationProcessor. It is redeclared
// here as a forward declaration to avoid a circular import.
|
||||
type NotificationProcessor interface {
|
||||
RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
|
||||
UnregisterHandler(pushNotificationName string) error
|
||||
GetHandler(pushNotificationName string) interface{}
|
||||
}
|
||||
|
||||
// ClientInterface defines the interface that clients must implement for maintnotifications upgrades.
|
||||
type ClientInterface interface {
|
||||
// GetOptions returns the client options.
|
||||
GetOptions() OptionsInterface
|
||||
|
||||
// GetPushProcessor returns the client's push notification processor.
|
||||
GetPushProcessor() NotificationProcessor
|
||||
}
|
||||
|
||||
// OptionsInterface defines the interface for client options.
|
||||
// Uses an adapter pattern to avoid circular dependencies.
|
||||
type OptionsInterface interface {
|
||||
// GetReadTimeout returns the read timeout.
|
||||
GetReadTimeout() time.Duration
|
||||
|
||||
// GetWriteTimeout returns the write timeout.
|
||||
GetWriteTimeout() time.Duration
|
||||
|
||||
// GetNetwork returns the network type.
|
||||
GetNetwork() string
|
||||
|
||||
// GetAddr returns the connection address.
|
||||
GetAddr() string
|
||||
|
||||
// IsTLSEnabled returns true if TLS is enabled.
|
||||
IsTLSEnabled() bool
|
||||
|
||||
// GetProtocol returns the protocol version.
|
||||
GetProtocol() int
|
||||
|
||||
// GetPoolSize returns the connection pool size.
|
||||
GetPoolSize() int
|
||||
|
||||
// NewDialer returns a new dialer function for the connection.
|
||||
NewDialer() func(context.Context) (net.Conn, error)
|
||||
}
|
||||
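To make the adapter pattern concrete, a hypothetical adapter over plain option values could satisfy OptionsInterface as sketched below. The type, its fields, and the dialer are illustrative only and not part of this change; the sketch would live inside the module, since internal packages are not importable from user code.

package interfaces_test

import (
	"context"
	"net"
	"time"

	"github.com/redis/go-redis/v9/internal/interfaces"
)

// optsAdapter is a hypothetical adapter exposing static values through OptionsInterface.
type optsAdapter struct {
	addr         string
	readTimeout  time.Duration
	writeTimeout time.Duration
}

func (a *optsAdapter) GetReadTimeout() time.Duration  { return a.readTimeout }
func (a *optsAdapter) GetWriteTimeout() time.Duration { return a.writeTimeout }
func (a *optsAdapter) GetNetwork() string             { return "tcp" }
func (a *optsAdapter) GetAddr() string                { return a.addr }
func (a *optsAdapter) IsTLSEnabled() bool             { return false }
func (a *optsAdapter) GetProtocol() int               { return 3 }
func (a *optsAdapter) GetPoolSize() int               { return 10 }

func (a *optsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
	return func(ctx context.Context) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, "tcp", a.addr)
	}
}

// Compile-time check that the adapter satisfies the interface.
var _ interfaces.OptionsInterface = (*optsAdapter)(nil)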
@@ -7,20 +7,73 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// TODO (ned): Revisit logging
|
||||
// Add more standardized approach with log levels and configurability
|
||||
|
||||
type Logging interface {
|
||||
Printf(ctx context.Context, format string, v ...interface{})
|
||||
}
|
||||
|
||||
type logger struct {
|
||||
type DefaultLogger struct {
|
||||
log *log.Logger
|
||||
}
|
||||
|
||||
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
_ = l.log.Output(2, fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func NewDefaultLogger() Logging {
|
||||
return &DefaultLogger{
|
||||
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
|
||||
}
|
||||
}
|
||||
|
||||
// Logger calls Output to print to the stderr.
|
||||
// Arguments are handled in the manner of fmt.Print.
|
||||
var Logger Logging = &logger{
|
||||
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
|
||||
var Logger Logging = NewDefaultLogger()
|
||||
|
||||
var LogLevel LogLevelT = LogLevelError
|
||||
|
||||
// LogLevelT represents the logging level
|
||||
type LogLevelT int
|
||||
|
||||
// Log level constants for the entire go-redis library
|
||||
const (
|
||||
LogLevelError LogLevelT = iota // 0 - errors only
|
||||
LogLevelWarn // 1 - warnings and errors
|
||||
LogLevelInfo // 2 - info, warnings, and errors
|
||||
LogLevelDebug // 3 - debug, info, warnings, and errors
|
||||
)
|
||||
|
||||
// String returns the string representation of the log level
|
||||
func (l LogLevelT) String() string {
|
||||
switch l {
|
||||
case LogLevelError:
|
||||
return "ERROR"
|
||||
case LogLevelWarn:
|
||||
return "WARN"
|
||||
case LogLevelInfo:
|
||||
return "INFO"
|
||||
case LogLevelDebug:
|
||||
return "DEBUG"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// IsValid returns true if the log level is valid
|
||||
func (l LogLevelT) IsValid() bool {
|
||||
return l >= LogLevelError && l <= LogLevelDebug
|
||||
}
|
||||
|
||||
func (l LogLevelT) WarnOrAbove() bool {
|
||||
return l >= LogLevelWarn
|
||||
}
|
||||
|
||||
func (l LogLevelT) InfoOrAbove() bool {
|
||||
return l >= LogLevelInfo
|
||||
}
|
||||
|
||||
func (l LogLevelT) DebugOrAbove() bool {
|
||||
return l >= LogLevelDebug
|
||||
}
|
||||
|
||||
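For illustration, a custom Logging implementation and the new level helpers might be exercised as below. This is a sketch placed inside the module (the internal package is not importable from user code), and the stdoutLogger type is hypothetical.

package internal_test

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9/internal"
)

// stdoutLogger is a hypothetical Logging implementation writing to stdout.
type stdoutLogger struct{}

func (stdoutLogger) Printf(ctx context.Context, format string, v ...interface{}) {
	fmt.Printf("redis: "+format+"\n", v...)
}

func ExampleLogging() {
	internal.Logger = stdoutLogger{}           // swap the package-level logger
	internal.LogLevel = internal.LogLevelDebug // raise verbosity

	if internal.LogLevel.DebugOrAbove() { // gate verbose output on the level helpers
		internal.Logger.Printf(context.Background(), "debug logging enabled (%s)", internal.LogLevel)
	}
	// Output: redis: debug logging enabled (DEBUG)
}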
625
internal/maintnotifications/logs/log_messages.go
Normal file
@@ -0,0 +1,625 @@
|
||||
package logs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
)
|
||||
|
||||
// appendJSONIfDebug appends JSON data to a message only if the global log level is Debug
|
||||
func appendJSONIfDebug(message string, data map[string]interface{}) string {
|
||||
if internal.LogLevel.DebugOrAbove() {
|
||||
jsonData, _ := json.Marshal(data)
|
||||
return fmt.Sprintf("%s %s", message, string(jsonData))
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
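A quick sketch of what this helper produces at the two ends of the log level; the message text matches the HandoffStarted example used by ExtractDataFromLogMessage further down, and the exact output is illustrative.

// With internal.LogLevel = internal.LogLevelError (the default):
//   conn[123] handoff started to localhost:6379
// With internal.LogLevel = internal.LogLevelDebug:
//   conn[123] handoff started to localhost:6379 {"connID":123,"endpoint":"localhost:6379"}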
const (
|
||||
// ========================================
|
||||
// CIRCUIT_BREAKER.GO - Circuit breaker management
|
||||
// ========================================
|
||||
CircuitBreakerTransitioningToHalfOpenMessage = "circuit breaker transitioning to half-open"
|
||||
CircuitBreakerOpenedMessage = "circuit breaker opened"
|
||||
CircuitBreakerReopenedMessage = "circuit breaker reopened"
|
||||
CircuitBreakerClosedMessage = "circuit breaker closed"
|
||||
CircuitBreakerCleanupMessage = "circuit breaker cleanup"
|
||||
CircuitBreakerOpenMessage = "circuit breaker is open, failing fast"
|
||||
|
||||
// ========================================
|
||||
// CONFIG.GO - Configuration and debug
|
||||
// ========================================
|
||||
DebugLoggingEnabledMessage = "debug logging enabled"
|
||||
ConfigDebugMessage = "config debug"
|
||||
|
||||
// ========================================
|
||||
// ERRORS.GO - Error message constants
|
||||
// ========================================
|
||||
InvalidRelaxedTimeoutErrorMessage = "relaxed timeout must be greater than 0"
|
||||
InvalidHandoffTimeoutErrorMessage = "handoff timeout must be greater than 0"
|
||||
InvalidHandoffWorkersErrorMessage = "MaxWorkers must be greater than or equal to 0"
|
||||
InvalidHandoffQueueSizeErrorMessage = "handoff queue size must be greater than 0"
|
||||
InvalidPostHandoffRelaxedDurationErrorMessage = "post-handoff relaxed duration must be greater than or equal to 0"
|
||||
InvalidEndpointTypeErrorMessage = "invalid endpoint type"
|
||||
InvalidMaintNotificationsErrorMessage = "invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')"
|
||||
InvalidHandoffRetriesErrorMessage = "MaxHandoffRetries must be between 1 and 10"
|
||||
InvalidClientErrorMessage = "invalid client type"
|
||||
InvalidNotificationErrorMessage = "invalid notification format"
|
||||
MaxHandoffRetriesReachedErrorMessage = "max handoff retries reached"
|
||||
HandoffQueueFullErrorMessage = "handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration"
|
||||
InvalidCircuitBreakerFailureThresholdErrorMessage = "circuit breaker failure threshold must be >= 1"
|
||||
InvalidCircuitBreakerResetTimeoutErrorMessage = "circuit breaker reset timeout must be >= 0"
|
||||
InvalidCircuitBreakerMaxRequestsErrorMessage = "circuit breaker max requests must be >= 1"
|
||||
ConnectionMarkedForHandoffErrorMessage = "connection marked for handoff"
|
||||
ConnectionInvalidHandoffStateErrorMessage = "connection is in invalid state for handoff"
|
||||
ShutdownErrorMessage = "shutdown"
|
||||
CircuitBreakerOpenErrorMessage = "circuit breaker is open, failing fast"
|
||||
|
||||
// ========================================
|
||||
// EXAMPLE_HOOKS.GO - Example metrics hooks
|
||||
// ========================================
|
||||
MetricsHookProcessingNotificationMessage = "metrics hook processing"
|
||||
MetricsHookRecordedErrorMessage = "metrics hook recorded error"
|
||||
|
||||
// ========================================
|
||||
// HANDOFF_WORKER.GO - Connection handoff processing
|
||||
// ========================================
|
||||
HandoffStartedMessage = "handoff started"
|
||||
HandoffFailedMessage = "handoff failed"
|
||||
ConnectionNotMarkedForHandoffMessage = "is not marked for handoff and has no retries"
|
||||
ConnectionNotMarkedForHandoffErrorMessage = "is not marked for handoff"
|
||||
HandoffRetryAttemptMessage = "Performing handoff"
|
||||
CannotQueueHandoffForRetryMessage = "can't queue handoff for retry"
|
||||
HandoffQueueFullMessage = "handoff queue is full"
|
||||
FailedToDialNewEndpointMessage = "failed to dial new endpoint"
|
||||
ApplyingRelaxedTimeoutDueToPostHandoffMessage = "applying relaxed timeout due to post-handoff"
|
||||
HandoffSuccessMessage = "handoff succeeded"
|
||||
RemovingConnectionFromPoolMessage = "removing connection from pool"
|
||||
NoPoolProvidedMessageCannotRemoveMessage = "no pool provided, cannot remove connection, closing it"
|
||||
WorkerExitingDueToShutdownMessage = "worker exiting due to shutdown"
|
||||
WorkerExitingDueToShutdownWhileProcessingMessage = "worker exiting due to shutdown while processing request"
|
||||
WorkerPanicRecoveredMessage = "worker panic recovered"
|
||||
WorkerExitingDueToInactivityTimeoutMessage = "worker exiting due to inactivity timeout"
|
||||
ReachedMaxHandoffRetriesMessage = "reached max handoff retries"
|
||||
|
||||
// ========================================
|
||||
// MANAGER.GO - Moving operation tracking and handler registration
|
||||
// ========================================
|
||||
DuplicateMovingOperationMessage = "duplicate MOVING operation ignored"
|
||||
TrackingMovingOperationMessage = "tracking MOVING operation"
|
||||
UntrackingMovingOperationMessage = "untracking MOVING operation"
|
||||
OperationNotTrackedMessage = "operation not tracked"
|
||||
FailedToRegisterHandlerMessage = "failed to register handler"
|
||||
|
||||
// ========================================
|
||||
// HOOKS.GO - Notification processing hooks
|
||||
// ========================================
|
||||
ProcessingNotificationMessage = "processing notification started"
|
||||
ProcessingNotificationFailedMessage = "processing notification failed"
|
||||
ProcessingNotificationSucceededMessage = "processing notification succeeded"
|
||||
|
||||
// ========================================
|
||||
// POOL_HOOK.GO - Pool connection management
|
||||
// ========================================
|
||||
FailedToQueueHandoffMessage = "failed to queue handoff"
|
||||
MarkedForHandoffMessage = "connection marked for handoff"
|
||||
|
||||
// ========================================
|
||||
// PUSH_NOTIFICATION_HANDLER.GO - Push notification validation and processing
|
||||
// ========================================
|
||||
InvalidNotificationFormatMessage = "invalid notification format"
|
||||
InvalidNotificationTypeFormatMessage = "invalid notification type format"
|
||||
InvalidSeqIDInMovingNotificationMessage = "invalid seqID in MOVING notification"
|
||||
InvalidTimeSInMovingNotificationMessage = "invalid timeS in MOVING notification"
|
||||
InvalidNewEndpointInMovingNotificationMessage = "invalid newEndpoint in MOVING notification"
|
||||
NoConnectionInHandlerContextMessage = "no connection in handler context"
|
||||
InvalidConnectionTypeInHandlerContextMessage = "invalid connection type in handler context"
|
||||
SchedulingHandoffToCurrentEndpointMessage = "scheduling handoff to current endpoint"
|
||||
RelaxedTimeoutDueToNotificationMessage = "applying relaxed timeout due to notification"
|
||||
UnrelaxedTimeoutMessage = "clearing relaxed timeout"
|
||||
ManagerNotInitializedMessage = "manager not initialized"
|
||||
FailedToMarkForHandoffMessage = "failed to mark connection for handoff"
|
||||
|
||||
// ========================================
|
||||
// used in pool/conn
|
||||
// ========================================
|
||||
UnrelaxedTimeoutAfterDeadlineMessage = "clearing relaxed timeout after deadline"
|
||||
)
|
||||
|
||||
func HandoffStarted(connID uint64, newEndpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffStartedMessage, newEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": newEndpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func HandoffFailed(connID uint64, newEndpoint string, attempt int, maxAttempts int, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s (attempt %d/%d): %v", connID, HandoffFailedMessage, newEndpoint, attempt, maxAttempts, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": newEndpoint,
|
||||
"attempt": attempt,
|
||||
"maxAttempts": maxAttempts,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func HandoffSucceeded(connID uint64, newEndpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffSuccessMessage, newEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": newEndpoint,
|
||||
})
|
||||
}
|
||||
|
||||
// Timeout-related log functions
|
||||
func RelaxedTimeoutDueToNotification(connID uint64, notificationType string, timeout interface{}) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s (%v)", connID, RelaxedTimeoutDueToNotificationMessage, notificationType, timeout)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"notificationType": notificationType,
|
||||
"timeout": fmt.Sprintf("%v", timeout),
|
||||
})
|
||||
}
|
||||
|
||||
func UnrelaxedTimeout(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
func UnrelaxedTimeoutAfterDeadline(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutAfterDeadlineMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
// Handoff queue and marking functions
|
||||
func HandoffQueueFull(queueLen, queueCap int) string {
|
||||
message := fmt.Sprintf("%s (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration", HandoffQueueFullMessage, queueLen, queueCap)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"queueLen": queueLen,
|
||||
"queueCap": queueCap,
|
||||
})
|
||||
}
|
||||
|
||||
func FailedToQueueHandoff(connID uint64, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToQueueHandoffMessage, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func FailedToMarkForHandoff(connID uint64, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToMarkForHandoffMessage, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func FailedToDialNewEndpoint(connID uint64, endpoint string, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s: %v", connID, FailedToDialNewEndpointMessage, endpoint, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func ReachedMaxHandoffRetries(connID uint64, endpoint string, maxRetries int) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s (max retries: %d)", connID, ReachedMaxHandoffRetriesMessage, endpoint, maxRetries)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"maxRetries": maxRetries,
|
||||
})
|
||||
}
|
||||
|
||||
// Notification processing functions
|
||||
func ProcessingNotification(connID uint64, seqID int64, notificationType string, notification interface{}) string {
|
||||
message := fmt.Sprintf("conn[%d] seqID[%d] %s %s: %v", connID, seqID, ProcessingNotificationMessage, notificationType, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seqID": seqID,
|
||||
"notificationType": notificationType,
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func ProcessingNotificationFailed(connID uint64, notificationType string, err error, notification interface{}) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s: %v - %v", connID, ProcessingNotificationFailedMessage, notificationType, err, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"notificationType": notificationType,
|
||||
"error": err.Error(),
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func ProcessingNotificationSucceeded(connID uint64, notificationType string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s", connID, ProcessingNotificationSucceededMessage, notificationType)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"notificationType": notificationType,
|
||||
})
|
||||
}
|
||||
|
||||
// Moving operation tracking functions
|
||||
func DuplicateMovingOperation(connID uint64, endpoint string, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, DuplicateMovingOperationMessage, endpoint, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
func TrackingMovingOperation(connID uint64, endpoint string, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, TrackingMovingOperationMessage, endpoint, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
func UntrackingMovingOperation(connID uint64, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, UntrackingMovingOperationMessage, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
func OperationNotTracked(connID uint64, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, OperationNotTrackedMessage, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
// Connection pool functions
|
||||
func RemovingConnectionFromPool(connID uint64, reason error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, RemovingConnectionFromPoolMessage, reason)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"reason": reason.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func NoPoolProvidedCannotRemove(connID uint64, reason error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, NoPoolProvidedMessageCannotRemoveMessage, reason)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"reason": reason.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Circuit breaker functions
|
||||
func CircuitBreakerOpen(connID uint64, endpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s for %s", connID, CircuitBreakerOpenMessage, endpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
// Additional handoff functions for specific cases
|
||||
func ConnectionNotMarkedForHandoff(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
func ConnectionNotMarkedForHandoffError(connID uint64) string {
|
||||
return fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffErrorMessage)
|
||||
}
|
||||
|
||||
func HandoffRetryAttempt(connID uint64, retries int, newEndpoint string, oldEndpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] Retry %d: %s to %s(was %s)", connID, retries, HandoffRetryAttemptMessage, newEndpoint, oldEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"retries": retries,
|
||||
"newEndpoint": newEndpoint,
|
||||
"oldEndpoint": oldEndpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func CannotQueueHandoffForRetry(err error) string {
|
||||
message := fmt.Sprintf("%s: %v", CannotQueueHandoffForRetryMessage, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Validation and error functions
|
||||
func InvalidNotificationFormat(notification interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidNotificationFormatMessage, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidNotificationTypeFormat(notificationType interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidNotificationTypeFormatMessage, notificationType)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": fmt.Sprintf("%v", notificationType),
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidNotification creates a log message for invalid notifications of any type
|
||||
func InvalidNotification(notificationType string, notification interface{}) string {
|
||||
message := fmt.Sprintf("invalid %s notification: %v", notificationType, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidSeqIDInMovingNotification(seqID interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidSeqIDInMovingNotificationMessage, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"seqID": fmt.Sprintf("%v", seqID),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidTimeSInMovingNotification(timeS interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidTimeSInMovingNotificationMessage, timeS)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"timeS": fmt.Sprintf("%v", timeS),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidNewEndpointInMovingNotification(newEndpoint interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidNewEndpointInMovingNotificationMessage, newEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"newEndpoint": fmt.Sprintf("%v", newEndpoint),
|
||||
})
|
||||
}
|
||||
|
||||
func NoConnectionInHandlerContext(notificationType string) string {
|
||||
message := fmt.Sprintf("%s for %s notification", NoConnectionInHandlerContextMessage, notificationType)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidConnectionTypeInHandlerContext(notificationType string, conn interface{}, handlerCtx interface{}) string {
|
||||
message := fmt.Sprintf("%s for %s notification - %T %#v", InvalidConnectionTypeInHandlerContextMessage, notificationType, conn, handlerCtx)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"connType": fmt.Sprintf("%T", conn),
|
||||
})
|
||||
}
|
||||
|
||||
func SchedulingHandoffToCurrentEndpoint(connID uint64, seconds float64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s in %v seconds", connID, SchedulingHandoffToCurrentEndpointMessage, seconds)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seconds": seconds,
|
||||
})
|
||||
}
|
||||
|
||||
func ManagerNotInitialized() string {
|
||||
return appendJSONIfDebug(ManagerNotInitializedMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func FailedToRegisterHandler(notificationType string, err error) string {
|
||||
message := fmt.Sprintf("%s for %s: %v", FailedToRegisterHandlerMessage, notificationType, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func ShutdownError() string {
|
||||
return appendJSONIfDebug(ShutdownErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Configuration validation error functions
|
||||
func InvalidRelaxedTimeoutError() string {
|
||||
return appendJSONIfDebug(InvalidRelaxedTimeoutErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffTimeoutError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffTimeoutErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffWorkersError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffWorkersErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffQueueSizeError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffQueueSizeErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidPostHandoffRelaxedDurationError() string {
|
||||
return appendJSONIfDebug(InvalidPostHandoffRelaxedDurationErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidEndpointTypeError() string {
|
||||
return appendJSONIfDebug(InvalidEndpointTypeErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidMaintNotificationsError() string {
|
||||
return appendJSONIfDebug(InvalidMaintNotificationsErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffRetriesError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffRetriesErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidClientError() string {
|
||||
return appendJSONIfDebug(InvalidClientErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidNotificationError() string {
|
||||
return appendJSONIfDebug(InvalidNotificationErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func MaxHandoffRetriesReachedError() string {
|
||||
return appendJSONIfDebug(MaxHandoffRetriesReachedErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func HandoffQueueFullError() string {
|
||||
return appendJSONIfDebug(HandoffQueueFullErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidCircuitBreakerFailureThresholdError() string {
|
||||
return appendJSONIfDebug(InvalidCircuitBreakerFailureThresholdErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidCircuitBreakerResetTimeoutError() string {
|
||||
return appendJSONIfDebug(InvalidCircuitBreakerResetTimeoutErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidCircuitBreakerMaxRequestsError() string {
|
||||
return appendJSONIfDebug(InvalidCircuitBreakerMaxRequestsErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Configuration and debug functions
|
||||
func DebugLoggingEnabled() string {
|
||||
return appendJSONIfDebug(DebugLoggingEnabledMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func ConfigDebug(config interface{}) string {
|
||||
message := fmt.Sprintf("%s: %+v", ConfigDebugMessage, config)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"config": fmt.Sprintf("%+v", config),
|
||||
})
|
||||
}
|
||||
|
||||
// Handoff worker functions
|
||||
func WorkerExitingDueToShutdown() string {
|
||||
return appendJSONIfDebug(WorkerExitingDueToShutdownMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func WorkerExitingDueToShutdownWhileProcessing() string {
|
||||
return appendJSONIfDebug(WorkerExitingDueToShutdownWhileProcessingMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func WorkerPanicRecovered(panicValue interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", WorkerPanicRecoveredMessage, panicValue)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"panic": fmt.Sprintf("%v", panicValue),
|
||||
})
|
||||
}
|
||||
|
||||
func WorkerExitingDueToInactivityTimeout(timeout interface{}) string {
|
||||
message := fmt.Sprintf("%s (%v)", WorkerExitingDueToInactivityTimeoutMessage, timeout)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"timeout": fmt.Sprintf("%v", timeout),
|
||||
})
|
||||
}
|
||||
|
||||
func ApplyingRelaxedTimeoutDueToPostHandoff(connID uint64, timeout interface{}, until string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s (%v) until %s", connID, ApplyingRelaxedTimeoutDueToPostHandoffMessage, timeout, until)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"timeout": fmt.Sprintf("%v", timeout),
|
||||
"until": until,
|
||||
})
|
||||
}
|
||||
|
||||
// Example hooks functions
|
||||
func MetricsHookProcessingNotification(notificationType string, connID uint64) string {
|
||||
message := fmt.Sprintf("%s %s notification on conn[%d]", MetricsHookProcessingNotificationMessage, notificationType, connID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
func MetricsHookRecordedError(notificationType string, connID uint64, err error) string {
|
||||
message := fmt.Sprintf("%s for %s notification on conn[%d]: %v", MetricsHookRecordedErrorMessage, notificationType, connID, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"connID": connID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Pool hook functions
|
||||
func MarkedForHandoff(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, MarkedForHandoffMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
// Circuit breaker additional functions
|
||||
func CircuitBreakerTransitioningToHalfOpen(endpoint string) string {
|
||||
message := fmt.Sprintf("%s for %s", CircuitBreakerTransitioningToHalfOpenMessage, endpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerOpened(endpoint string, failures int64) string {
|
||||
message := fmt.Sprintf("%s for endpoint %s after %d failures", CircuitBreakerOpenedMessage, endpoint, failures)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
"failures": failures,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerReopened(endpoint string) string {
|
||||
message := fmt.Sprintf("%s for endpoint %s due to failure in half-open state", CircuitBreakerReopenedMessage, endpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerClosed(endpoint string, successes int64) string {
|
||||
message := fmt.Sprintf("%s for endpoint %s after %d successful requests", CircuitBreakerClosedMessage, endpoint, successes)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
"successes": successes,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerCleanup(removed int, total int) string {
|
||||
message := fmt.Sprintf("%s removed %d/%d entries", CircuitBreakerCleanupMessage, removed, total)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"removed": removed,
|
||||
"total": total,
|
||||
})
|
||||
}
|
||||
|
||||
// ExtractDataFromLogMessage extracts structured data from maintnotifications log messages
|
||||
// Returns a map containing the parsed key-value pairs from the structured data section
|
||||
// Example: "conn[123] handoff started to localhost:6379 {"connID":123,"endpoint":"localhost:6379"}"
|
||||
// Returns: map[string]interface{}{"connID": 123, "endpoint": "localhost:6379"}
|
||||
func ExtractDataFromLogMessage(logMessage string) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// Find the JSON data section at the end of the message
|
||||
re := regexp.MustCompile(`(\{.*\})$`)
|
||||
matches := re.FindStringSubmatch(logMessage)
|
||||
if len(matches) < 2 {
|
||||
return result
|
||||
}
|
||||
|
||||
jsonStr := matches[1]
|
||||
if jsonStr == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
// Parse the JSON directly
|
||||
var jsonResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &jsonResult); err == nil {
|
||||
return jsonResult
|
||||
}
|
||||
|
||||
// If JSON parsing fails, return empty map
|
||||
return result
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package pool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
|
||||
b.Run(bm.String(), func(b *testing.B) {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: bm.poolSize,
|
||||
PoolSize: int32(bm.poolSize),
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Hour,
|
||||
@@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
|
||||
b.Run(bm.String(), func(b *testing.B) {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: bm.poolSize,
|
||||
PoolSize: int32(bm.poolSize),
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Hour,
|
||||
@@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("Bench test remove"))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
It("should use default buffer sizes when not specified", func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 1000,
|
||||
})
|
||||
|
||||
@@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 1000,
|
||||
ReadBufferSize: customReadSize,
|
||||
WriteBufferSize: customWriteSize,
|
||||
@@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
It("should handle zero buffer sizes by using defaults", func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 1000,
|
||||
ReadBufferSize: 0, // Should use default
|
||||
WriteBufferSize: 0, // Should use default
|
||||
@@ -105,7 +105,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
// without setting ReadBufferSize and WriteBufferSize
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 1000,
|
||||
// ReadBufferSize and WriteBufferSize are not set (will be 0)
|
||||
})
|
||||
|
||||
@@ -3,26 +3,88 @@ package pool
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
)
|
||||
|
||||
var noDeadline = time.Time{}
|
||||
|
||||
// Global atomic counter for connection IDs
|
||||
var connIDCounter uint64
|
||||
|
||||
// HandoffState represents the atomic state for connection handoffs
|
||||
// This struct is stored atomically to prevent race conditions between
|
||||
// checking handoff status and reading handoff parameters
|
||||
type HandoffState struct {
|
||||
ShouldHandoff bool // Whether connection should be handed off
|
||||
Endpoint string // New endpoint for handoff
|
||||
SeqID int64 // Sequence ID from MOVING notification
|
||||
}
|
||||
|
||||
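The point of keeping the whole state in one atomic.Value is that readers take a single consistent snapshot instead of reading ShouldHandoff, Endpoint and SeqID through separate atomics. A small illustrative reader (not part of the diff; getHandoffState is defined further down in this file):

// Sketch of the consumer-side pattern this enables: read one snapshot,
// then use all of its fields together.
func exampleProcessHandoff(cn *Conn) {
	state := cn.getHandoffState() // single atomic load
	if !state.ShouldHandoff {
		return
	}
	// Endpoint and SeqID belong to the same MOVING notification as
	// ShouldHandoff, because all three were published in one Store call.
	_ = state.Endpoint
	_ = state.SeqID
}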
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
|
||||
type atomicNetConn struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// generateConnID generates a fast unique identifier for a connection with zero allocations
|
||||
func generateConnID() uint64 {
|
||||
return atomic.AddUint64(&connIDCounter, 1)
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
usedAt int64 // atomic
|
||||
netConn net.Conn
|
||||
usedAt int64 // atomic
|
||||
|
||||
// Lock-free netConn access using atomic.Value
|
||||
// Contains *atomicNetConn wrapper, accessed atomically for better performance
|
||||
netConnAtomic atomic.Value // stores *atomicNetConn
|
||||
|
||||
rd *proto.Reader
|
||||
bw *bufio.Writer
|
||||
wr *proto.Writer
|
||||
|
||||
Inited bool
|
||||
// Lightweight mutex to protect reader operations during handoff
|
||||
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
|
||||
readerMu sync.RWMutex
|
||||
|
||||
Inited atomic.Bool
|
||||
pooled bool
|
||||
pubsub bool
|
||||
closed atomic.Bool
|
||||
createdAt time.Time
|
||||
expiresAt time.Time
|
||||
|
||||
// maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers
|
||||
// Using atomic operations for lock-free access to avoid mutex contention
|
||||
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
|
||||
|
||||
// Counter to track multiple relaxed timeout setters if we have nested calls
|
||||
// will be decremented when ClearRelaxedTimeout is called or deadline is reached
|
||||
// if counter reaches 0, we clear the relaxed timeouts
|
||||
relaxedCounter atomic.Int32
|
||||
|
||||
// Connection initialization function for reconnections
|
||||
initConnFunc func(context.Context, *Conn) error
|
||||
|
||||
// Connection identifier for unique tracking
|
||||
id uint64 // Unique numeric identifier for this connection
|
||||
|
||||
// Handoff state - using atomic operations for lock-free access
|
||||
usableAtomic atomic.Bool // Connection usability state
|
||||
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
|
||||
|
||||
// Atomic handoff state to prevent race conditions
|
||||
// Stores *HandoffState to ensure atomic updates of all handoff-related fields
|
||||
handoffStateAtomic atomic.Value // stores *HandoffState
|
||||
|
||||
onClose func() error
|
||||
}
|
||||
@@ -33,8 +95,8 @@ func NewConn(netConn net.Conn) *Conn {
|
||||
|
||||
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
|
||||
cn := &Conn{
|
||||
netConn: netConn,
|
||||
createdAt: time.Now(),
|
||||
id: generateConnID(), // Generate unique ID for this connection
|
||||
}
|
||||
|
||||
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
|
||||
@@ -50,6 +112,21 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
|
||||
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
|
||||
}
|
||||
|
||||
// Store netConn atomically for lock-free access using wrapper
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
|
||||
// Initialize atomic state
|
||||
cn.usableAtomic.Store(false) // false initially, set to true after initialization
|
||||
cn.handoffRetriesAtomic.Store(0) // 0 initially
|
||||
|
||||
// Initialize handoff state atomically
|
||||
initialHandoffState := &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: "",
|
||||
SeqID: 0,
|
||||
}
|
||||
cn.handoffStateAtomic.Store(initialHandoffState)
|
||||
|
||||
cn.wr = proto.NewWriter(cn.bw)
|
||||
cn.SetUsedAt(time.Now())
|
||||
return cn
|
||||
@@ -64,23 +141,439 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
|
||||
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
||||
}
|
||||
|
||||
// getNetConn returns the current network connection using atomic load (lock-free).
|
||||
// This is the fast path for accessing netConn without mutex overhead.
|
||||
func (cn *Conn) getNetConn() net.Conn {
|
||||
if v := cn.netConnAtomic.Load(); v != nil {
|
||||
if wrapper, ok := v.(*atomicNetConn); ok {
|
||||
return wrapper.conn
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setNetConn stores the network connection atomically (lock-free).
|
||||
// This is used for the fast path of connection replacement.
|
||||
func (cn *Conn) setNetConn(netConn net.Conn) {
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
}
|
||||
|
||||
// Lock-free helper methods for handoff state management
|
||||
|
||||
// isUsable returns true if the connection is safe to use (lock-free).
|
||||
func (cn *Conn) isUsable() bool {
|
||||
return cn.usableAtomic.Load()
|
||||
}
|
||||
|
||||
// setUsable sets the usable flag atomically (lock-free).
|
||||
func (cn *Conn) setUsable(usable bool) {
|
||||
cn.usableAtomic.Store(usable)
|
||||
}
|
||||
|
||||
// getHandoffState returns the current handoff state atomically (lock-free).
|
||||
func (cn *Conn) getHandoffState() *HandoffState {
|
||||
state := cn.handoffStateAtomic.Load()
|
||||
if state == nil {
|
||||
// Return default state if not initialized
|
||||
return &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: "",
|
||||
SeqID: 0,
|
||||
}
|
||||
}
|
||||
return state.(*HandoffState)
|
||||
}
|
||||
|
||||
// setHandoffState sets the handoff state atomically (lock-free).
|
||||
func (cn *Conn) setHandoffState(state *HandoffState) {
|
||||
cn.handoffStateAtomic.Store(state)
|
||||
}
|
||||
|
||||
// shouldHandoff returns true if connection needs handoff (lock-free).
|
||||
func (cn *Conn) shouldHandoff() bool {
|
||||
return cn.getHandoffState().ShouldHandoff
|
||||
}
|
||||
|
||||
// getMovingSeqID returns the sequence ID atomically (lock-free).
|
||||
func (cn *Conn) getMovingSeqID() int64 {
|
||||
return cn.getHandoffState().SeqID
|
||||
}
|
||||
|
||||
// getNewEndpoint returns the new endpoint atomically (lock-free).
|
||||
func (cn *Conn) getNewEndpoint() string {
|
||||
return cn.getHandoffState().Endpoint
|
||||
}
|
||||
|
||||
// setHandoffRetries sets the retry count atomically (lock-free).
|
||||
func (cn *Conn) setHandoffRetries(retries int) {
|
||||
cn.handoffRetriesAtomic.Store(uint32(retries))
|
||||
}
|
||||
|
||||
// incrementHandoffRetries atomically increments and returns the new retry count (lock-free).
|
||||
func (cn *Conn) incrementHandoffRetries(delta int) int {
|
||||
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
|
||||
}
|
||||
|
||||
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
|
||||
func (cn *Conn) IsUsable() bool {
|
||||
return cn.isUsable()
|
||||
}
|
||||
|
||||
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
|
||||
func (cn *Conn) IsPooled() bool {
|
||||
return cn.pooled
|
||||
}
|
||||
|
||||
// IsPubSub returns true if the connection is used for PubSub.
|
||||
func (cn *Conn) IsPubSub() bool {
|
||||
return cn.pubsub
|
||||
}
|
||||
|
||||
func (cn *Conn) IsInited() bool {
|
||||
return cn.Inited.Load()
|
||||
}
|
||||
|
||||
// SetUsable sets the usable flag for the connection (lock-free).
|
||||
func (cn *Conn) SetUsable(usable bool) {
|
||||
cn.setUsable(usable)
|
||||
}
|
||||
|
||||
// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
|
||||
// These timeouts will be used for all subsequent commands until the deadline expires.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
|
||||
cn.relaxedCounter.Add(1)
|
||||
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
|
||||
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
|
||||
}
|
||||
|
||||
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
|
||||
// After the deadline, timeouts automatically revert to normal values.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
|
||||
cn.SetRelaxedTimeout(readTimeout, writeTimeout)
|
||||
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
|
||||
}
|
||||
|
||||
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) ClearRelaxedTimeout() {
|
||||
// Atomically decrement counter and check if we should clear
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
if newCount <= 0 && (deadlineNs == 0 || time.Now().UnixNano() >= deadlineNs) {
|
||||
// Use atomic load to get current value for CAS to avoid stale value race
|
||||
current := cn.relaxedCounter.Load()
|
||||
if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) {
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *Conn) clearRelaxedTimeout() {
|
||||
cn.relaxedReadTimeoutNs.Store(0)
|
||||
cn.relaxedWriteTimeoutNs.Store(0)
|
||||
cn.relaxedDeadlineNs.Store(0)
|
||||
cn.relaxedCounter.Store(0)
|
||||
}
|
||||
|
||||
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
|
||||
// This checks both the timeout values and the deadline (if set).
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) HasRelaxedTimeout() bool {
|
||||
// Fast path: no relaxed timeouts are set
|
||||
if cn.relaxedCounter.Load() <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||
|
||||
// If no relaxed timeouts are set, return false
|
||||
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, relaxed timeouts are active
|
||||
if deadlineNs == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// If deadline is set, check if it's still in the future
|
||||
return time.Now().UnixNano() < deadlineNs
|
||||
}
|
||||
|
||||
// getEffectiveReadTimeout returns the timeout to use for read operations.
|
||||
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
|
||||
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||
|
||||
// Fast path: no relaxed timeout set
|
||||
if readTimeoutNs <= 0 {
|
||||
return normalTimeout
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, use relaxed timeout
|
||||
if deadlineNs == 0 {
|
||||
return time.Duration(readTimeoutNs)
|
||||
}
|
||||
|
||||
nowNs := time.Now().UnixNano()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
return time.Duration(readTimeoutNs)
|
||||
} else {
|
||||
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
if newCount <= 0 {
|
||||
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
return normalTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// getEffectiveWriteTimeout returns the timeout to use for write operations.
|
||||
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
|
||||
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||
|
||||
// Fast path: no relaxed timeout set
|
||||
if writeTimeoutNs <= 0 {
|
||||
return normalTimeout
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, use relaxed timeout
|
||||
if deadlineNs == 0 {
|
||||
return time.Duration(writeTimeoutNs)
|
||||
}
|
||||
|
||||
nowNs := time.Now().UnixNano()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
return time.Duration(writeTimeoutNs)
|
||||
} else {
|
||||
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
if newCount <= 0 {
|
||||
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
return normalTimeout
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *Conn) SetOnClose(fn func() error) {
|
||||
cn.onClose = fn
|
||||
}
|
||||
|
||||
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
|
||||
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
|
||||
cn.initConnFunc = fn
|
||||
}
|
||||
|
||||
// ExecuteInitConn runs the stored connection initialization function if available.
|
||||
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
|
||||
if cn.initConnFunc != nil {
|
||||
return cn.initConnFunc(ctx, cn)
|
||||
}
|
||||
return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID())
|
||||
}
|
||||
|
||||
func (cn *Conn) SetNetConn(netConn net.Conn) {
|
||||
cn.netConn = netConn
|
||||
// Store the new connection atomically first (lock-free)
|
||||
cn.setNetConn(netConn)
|
||||
// Protect reader reset operations to avoid data races
|
||||
// Use write lock since we're modifying the reader state
|
||||
cn.readerMu.Lock()
|
||||
cn.rd.Reset(netConn)
|
||||
cn.readerMu.Unlock()
|
||||
|
||||
cn.bw.Reset(netConn)
|
||||
}
|
||||
|
||||
// GetNetConn safely returns the current network connection using atomic load (lock-free).
|
||||
// This method is used by the pool for health checks and provides better performance.
|
||||
func (cn *Conn) GetNetConn() net.Conn {
|
||||
return cn.getNetConn()
|
||||
}
|
||||
|
||||
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
|
||||
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
|
||||
// New connection is not initialized yet
|
||||
cn.Inited.Store(false)
|
||||
// Replace the underlying connection
|
||||
cn.SetNetConn(netConn)
|
||||
return cn.ExecuteInitConn(ctx)
|
||||
}
|
||||
|
||||
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
|
||||
// Returns an error if the connection is already marked for handoff.
|
||||
// This method uses atomic compare-and-swap to ensure all handoff state is updated atomically.
|
||||
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
|
||||
const maxRetries = 50
|
||||
const baseDelay = time.Microsecond
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
currentState := cn.getHandoffState()
|
||||
|
||||
// Check if already marked for handoff
|
||||
if currentState.ShouldHandoff {
|
||||
return errors.New("connection is already marked for handoff")
|
||||
}
|
||||
|
||||
// Create new state with handoff enabled
|
||||
newState := &HandoffState{
|
||||
ShouldHandoff: true,
|
||||
Endpoint: newEndpoint,
|
||||
SeqID: seqID,
|
||||
}
|
||||
|
||||
// Atomic compare-and-swap to update entire state
|
||||
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If CAS failed, add exponential backoff to reduce contention
|
||||
if attempt < maxRetries-1 {
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
|
||||
time.Sleep(delay)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to mark connection for handoff after %d attempts due to high contention", maxRetries)
|
||||
}
|
||||
|
||||
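The MarkForHandoff loop above is a plain compare-and-swap retry with capped exponential backoff. A minimal standalone sketch of that pattern (not part of this diff; the state type and delays are illustrative), mirroring the structure of MarkForHandoff:

package main

import (
	"errors"
	"fmt"
	"sync/atomic"
	"time"
)

// state stands in for HandoffState: it is replaced wholesale, never mutated.
type state struct{ marked bool }

// casWithBackoff retries an atomic.Value CompareAndSwap under contention,
// sleeping 1us, 2us, 4us, ... between failed attempts; the exponent is taken
// modulo 10, so the delay is capped at 512us.
func casWithBackoff(v *atomic.Value) error {
	const maxRetries = 50
	const baseDelay = time.Microsecond

	for attempt := 0; attempt < maxRetries; attempt++ {
		cur := v.Load().(*state)
		if cur.marked {
			return errors.New("already marked")
		}
		if v.CompareAndSwap(cur, &state{marked: true}) {
			return nil
		}
		if attempt < maxRetries-1 {
			time.Sleep(baseDelay * time.Duration(1<<uint(attempt%10)))
		}
	}
	return fmt.Errorf("gave up after %d attempts", maxRetries)
}

func main() {
	var v atomic.Value
	v.Store(&state{})
	fmt.Println(casWithBackoff(&v)) // <nil>
	fmt.Println(casWithBackoff(&v)) // already marked
}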
func (cn *Conn) MarkQueuedForHandoff() error {
|
||||
const maxRetries = 50
|
||||
const baseDelay = time.Microsecond
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
currentState := cn.getHandoffState()
|
||||
|
||||
// Check if marked for handoff
|
||||
if !currentState.ShouldHandoff {
|
||||
return errors.New("connection was not marked for handoff")
|
||||
}
|
||||
|
||||
// Create new state with handoff disabled (queued)
|
||||
newState := &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: currentState.Endpoint, // Preserve endpoint for handoff processing
|
||||
SeqID: currentState.SeqID, // Preserve seqID for handoff processing
|
||||
}
|
||||
|
||||
// Atomic compare-and-swap to update state
|
||||
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
||||
cn.setUsable(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If CAS failed, add exponential backoff to reduce contention
|
||||
// the delay will be 1, 2, 4... up to 512 microseconds
|
||||
if attempt < maxRetries-1 {
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
|
||||
time.Sleep(delay)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to mark connection as queued for handoff after %d attempts due to high contention", maxRetries)
|
||||
}
|
||||
|
||||
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
|
||||
func (cn *Conn) ShouldHandoff() bool {
|
||||
return cn.shouldHandoff()
|
||||
}
|
||||
|
||||
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
|
||||
func (cn *Conn) GetHandoffEndpoint() string {
|
||||
return cn.getNewEndpoint()
|
||||
}
|
||||
|
||||
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
|
||||
func (cn *Conn) GetMovingSeqID() int64 {
|
||||
return cn.getMovingSeqID()
|
||||
}
|
||||
|
||||
// GetHandoffInfo returns all handoff information atomically (lock-free).
|
||||
// This method prevents race conditions by returning all handoff state in a single atomic operation.
|
||||
// Returns (shouldHandoff, endpoint, seqID).
|
||||
func (cn *Conn) GetHandoffInfo() (bool, string, int64) {
|
||||
state := cn.getHandoffState()
|
||||
return state.ShouldHandoff, state.Endpoint, state.SeqID
|
||||
}
|
||||
|
||||
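GetHandoffInfo exists because reading ShouldHandoff, Endpoint and SeqID through separate getters could observe fields from two different MOVING notifications. A small self-contained sketch of the single-snapshot idea (using atomic.Pointer instead of atomic.Value for brevity; the endpoint and sequence values are made up):

package main

import (
	"fmt"
	"sync/atomic"
)

// HandoffState mirrors the shape used by the pool: every update stores a new
// pointer, so a single Load always yields fields from the same notification.
type HandoffState struct {
	ShouldHandoff bool
	Endpoint      string
	SeqID         int64
}

func main() {
	var state atomic.Pointer[HandoffState]
	state.Store(&HandoffState{})

	// Writer: publish endpoint and seqID together with the flag.
	state.Store(&HandoffState{ShouldHandoff: true, Endpoint: "10.0.0.2:6379", SeqID: 42})

	// Reader: one snapshot, no torn reads across the three fields.
	s := state.Load()
	fmt.Println(s.ShouldHandoff, s.Endpoint, s.SeqID)
}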
// GetID returns the unique identifier for this connection.
|
||||
func (cn *Conn) GetID() uint64 {
|
||||
return cn.id
|
||||
}
|
||||
|
||||
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
|
||||
func (cn *Conn) ClearHandoffState() {
|
||||
// Create clean state
|
||||
cleanState := &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: "",
|
||||
SeqID: 0,
|
||||
}
|
||||
|
||||
// Atomically set clean state
|
||||
cn.setHandoffState(cleanState)
|
||||
cn.setHandoffRetries(0)
|
||||
cn.setUsable(true) // Connection is safe to use again after handoff completes
|
||||
}
|
||||
|
||||
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
|
||||
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
|
||||
return cn.incrementHandoffRetries(n)
|
||||
}
|
||||
|
||||
// GetHandoffRetries returns the current handoff retry count (lock-free).
|
||||
func (cn *Conn) HandoffRetries() int {
|
||||
return int(cn.handoffRetriesAtomic.Load())
|
||||
}
|
||||
|
||||
// HasBufferedData safely checks if the connection has buffered data.
|
||||
// This method is used to avoid data races when checking for push notifications.
|
||||
func (cn *Conn) HasBufferedData() bool {
|
||||
// Use read lock for concurrent access to reader state
|
||||
cn.readerMu.RLock()
|
||||
defer cn.readerMu.RUnlock()
|
||||
return cn.rd.Buffered() > 0
|
||||
}
|
||||
|
||||
// PeekReplyTypeSafe safely peeks at the reply type.
|
||||
// This method is used to avoid data races when checking for push notifications.
|
||||
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
|
||||
// Use read lock for concurrent access to reader state
|
||||
cn.readerMu.RLock()
|
||||
defer cn.readerMu.RUnlock()
|
||||
|
||||
if cn.rd.Buffered() <= 0 {
|
||||
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
|
||||
}
|
||||
return cn.rd.PeekReplyType()
|
||||
}
|
||||
|
||||
func (cn *Conn) Write(b []byte) (int, error) {
|
||||
return cn.netConn.Write(b)
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.Write(b)
|
||||
}
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
func (cn *Conn) RemoteAddr() net.Addr {
|
||||
if cn.netConn != nil {
|
||||
return cn.netConn.RemoteAddr()
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.RemoteAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -89,7 +582,16 @@ func (cn *Conn) WithReader(
|
||||
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
|
||||
) error {
|
||||
if timeout >= 0 {
|
||||
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
|
||||
|
||||
// Get the connection directly from atomic storage
|
||||
netConn := cn.getNetConn()
|
||||
if netConn == nil {
|
||||
return fmt.Errorf("redis: connection not available")
|
||||
}
|
||||
|
||||
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -100,13 +602,26 @@ func (cn *Conn) WithWriter(
|
||||
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
|
||||
) error {
|
||||
if timeout >= 0 {
|
||||
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
|
||||
return err
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
|
||||
|
||||
// Set the write deadline when the underlying connection is available;
|
||||
// if getNetConn() returns nil, fail fast below so the write cannot hang indefinitely.
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// If getNetConn() returns nil, we still need to respect the timeout
|
||||
// Return an error to prevent indefinite blocking
|
||||
return fmt.Errorf("redis: conn[%d] not available for write operation", cn.GetID())
|
||||
}
|
||||
}
|
||||
|
||||
if cn.bw.Buffered() > 0 {
|
||||
cn.bw.Reset(cn.netConn)
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
cn.bw.Reset(netConn)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fn(cn.wr); err != nil {
|
||||
@@ -116,12 +631,33 @@ func (cn *Conn) WithWriter(
|
||||
return cn.bw.Flush()
|
||||
}
|
||||
|
||||
func (cn *Conn) IsClosed() bool {
|
||||
return cn.closed.Load()
|
||||
}
|
||||
|
||||
func (cn *Conn) Close() error {
|
||||
cn.closed.Store(true)
|
||||
if cn.onClose != nil {
|
||||
// ignore error
|
||||
_ = cn.onClose()
|
||||
}
|
||||
return cn.netConn.Close()
|
||||
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaybeHasData tries to peek at the next byte in the socket without consuming it
|
||||
// This is used to check if there are push notifications available
|
||||
// Important: This will work on Linux, but not on Windows
|
||||
func (cn *Conn) MaybeHasData() bool {
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return maybeHasData(netConn)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
|
||||
|
||||
@@ -12,6 +12,9 @@ import (
|
||||
|
||||
var errUnexpectedRead = errors.New("unexpected read from socket")
|
||||
|
||||
// connCheck checks if the connection is still alive and if there is data in the socket
|
||||
// it will try to peek at the next byte without consuming it since we may want to work with it
|
||||
// later on (e.g. push notifications)
|
||||
func connCheck(conn net.Conn) error {
|
||||
// Reset previous timeout.
|
||||
_ = conn.SetDeadline(time.Time{})
|
||||
@@ -29,7 +32,9 @@ func connCheck(conn net.Conn) error {
|
||||
|
||||
if err := rawConn.Read(func(fd uintptr) bool {
|
||||
var buf [1]byte
|
||||
n, err := syscall.Read(int(fd), buf[:])
|
||||
// Use MSG_PEEK to peek at data without consuming it
|
||||
n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK|syscall.MSG_DONTWAIT)
|
||||
|
||||
switch {
|
||||
case n == 0 && err == nil:
|
||||
sysErr = io.EOF
|
||||
@@ -47,3 +52,8 @@ func connCheck(conn net.Conn) error {
|
||||
|
||||
return sysErr
|
||||
}
|
||||
|
||||
// maybeHasData checks if there is data in the socket without consuming it
|
||||
func maybeHasData(conn net.Conn) bool {
|
||||
return connCheck(conn) == errUnexpectedRead
|
||||
}
|
||||
|
||||
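The change from syscall.Read to Recvfrom with MSG_PEEK|MSG_DONTWAIT is what lets connCheck notice pending bytes (possibly a push notification) without consuming them. A standalone Linux-only sketch of the same peek on a raw TCP connection (the loopback address and the "+OK\r\n" payload are illustrative):

//go:build linux

package main

import (
	"fmt"
	"net"
	"syscall"
	"time"
)

// peekOneByte reports whether at least one byte is waiting on the socket
// without removing it from the kernel buffer, using the same
// MSG_PEEK|MSG_DONTWAIT combination as connCheck.
func peekOneByte(conn *net.TCPConn) (bool, error) {
	raw, err := conn.SyscallConn()
	if err != nil {
		return false, err
	}
	var n int
	var sysErr error
	readErr := raw.Read(func(fd uintptr) bool {
		var buf [1]byte
		// MSG_PEEK: inspect the data but leave it queued.
		// MSG_DONTWAIT: never block if nothing is there.
		n, _, sysErr = syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK|syscall.MSG_DONTWAIT)
		return true // done; do not wait for readability
	})
	if readErr != nil {
		return false, readErr
	}
	if sysErr == syscall.EAGAIN || sysErr == syscall.EWOULDBLOCK {
		return false, nil // nothing pending
	}
	return n > 0, sysErr
}

func main() {
	ln, _ := net.Listen("tcp", "127.0.0.1:0")
	defer ln.Close()
	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		_, _ = c.Write([]byte("+OK\r\n")) // pending bytes we never read
	}()
	conn, _ := net.Dial("tcp", ln.Addr().String())
	defer conn.Close()
	time.Sleep(50 * time.Millisecond)             // give the bytes time to arrive
	fmt.Println(peekOneByte(conn.(*net.TCPConn))) // true <nil>
	fmt.Println(peekOneByte(conn.(*net.TCPConn))) // still true: nothing was consumed
}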
@@ -2,8 +2,19 @@
|
||||
|
||||
package pool
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
func connCheck(conn net.Conn) error {
|
||||
// errUnexpectedRead is placeholder error variable for non-unix build constraints
|
||||
var errUnexpectedRead = errors.New("unexpected read from socket")
|
||||
|
||||
func connCheck(_ net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// since we can't check for data on the socket, we just assume there is some
|
||||
func maybeHasData(_ net.Conn) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
internal/pool/conn_relaxed_timeout_test.go (new file, 92 lines)
@@ -0,0 +1,92 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestConcurrentRelaxedTimeoutClearing tests the race condition fix in ClearRelaxedTimeout
|
||||
func TestConcurrentRelaxedTimeoutClearing(t *testing.T) {
|
||||
// Create a dummy connection for testing
|
||||
netConn := &net.TCPConn{}
|
||||
cn := NewConn(netConn)
|
||||
defer cn.Close()
|
||||
|
||||
// Set relaxed timeout multiple times to increase counter
|
||||
cn.SetRelaxedTimeout(time.Second, time.Second)
|
||||
cn.SetRelaxedTimeout(time.Second, time.Second)
|
||||
cn.SetRelaxedTimeout(time.Second, time.Second)
|
||||
|
||||
// Verify counter is 3
|
||||
if count := cn.relaxedCounter.Load(); count != 3 {
|
||||
t.Errorf("Expected relaxed counter to be 3, got %d", count)
|
||||
}
|
||||
|
||||
// Clear timeouts concurrently to test race condition fix
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cn.ClearRelaxedTimeout()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Verify counter is 0 and timeouts are cleared
|
||||
if count := cn.relaxedCounter.Load(); count != 0 {
|
||||
t.Errorf("Expected relaxed counter to be 0 after clearing, got %d", count)
|
||||
}
|
||||
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
|
||||
t.Errorf("Expected relaxed read timeout to be 0, got %d", timeout)
|
||||
}
|
||||
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
|
||||
t.Errorf("Expected relaxed write timeout to be 0, got %d", timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelaxedTimeoutCounterRaceCondition tests the specific race condition scenario
|
||||
func TestRelaxedTimeoutCounterRaceCondition(t *testing.T) {
|
||||
netConn := &net.TCPConn{}
|
||||
cn := NewConn(netConn)
|
||||
defer cn.Close()
|
||||
|
||||
// Set relaxed timeout once
|
||||
cn.SetRelaxedTimeout(time.Second, time.Second)
|
||||
|
||||
// Verify counter is 1
|
||||
if count := cn.relaxedCounter.Load(); count != 1 {
|
||||
t.Errorf("Expected relaxed counter to be 1, got %d", count)
|
||||
}
|
||||
|
||||
// Test concurrent clearing with race condition scenario
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Multiple goroutines try to clear simultaneously
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cn.ClearRelaxedTimeout()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Verify final state is consistent
|
||||
if count := cn.relaxedCounter.Load(); count != 0 {
|
||||
t.Errorf("Expected relaxed counter to be 0 after concurrent clearing, got %d", count)
|
||||
}
|
||||
|
||||
// Verify timeouts are actually cleared
|
||||
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
|
||||
t.Errorf("Expected relaxed read timeout to be cleared, got %d", timeout)
|
||||
}
|
||||
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
|
||||
t.Errorf("Expected relaxed write timeout to be cleared, got %d", timeout)
|
||||
}
|
||||
if deadline := cn.relaxedDeadlineNs.Load(); deadline != 0 {
|
||||
t.Errorf("Expected relaxed deadline to be cleared, got %d", deadline)
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) {
|
||||
}
|
||||
|
||||
func (cn *Conn) NetConn() net.Conn {
|
||||
return cn.netConn
|
||||
return cn.getNetConn()
|
||||
}
|
||||
|
||||
func (p *ConnPool) CheckMinIdleConns() {
|
||||
|
||||
internal/pool/hooks.go (new file, 114 lines)
@@ -0,0 +1,114 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PoolHook defines the interface for connection lifecycle hooks.
|
||||
type PoolHook interface {
|
||||
// OnGet is called when a connection is retrieved from the pool.
|
||||
// It can modify the connection or return an error to prevent its use.
|
||||
// The isNewConn flag indicates whether this is a new connection (rather than an idle one from the pool).
|
||||
// The flag can be used for gathering metrics on pool hit/miss ratio.
|
||||
OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
|
||||
|
||||
// OnPut is called when a connection is returned to the pool.
|
||||
// It returns whether the connection should be pooled and whether it should be removed.
|
||||
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
|
||||
}
|
||||
|
||||
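A hedged example implementation of this interface (the type name and metrics are made up); it counts pool hits versus misses via the isNewConn flag and never vetoes a connection:

package pool

import (
	"context"
	"sync/atomic"
)

// hitMissHook counts how often Get was served from the idle pool versus a
// fresh dial. It never vetoes a connection.
type hitMissHook struct {
	hits   atomic.Int64 // idle connection reused
	misses atomic.Int64 // new connection dialed
}

var _ PoolHook = (*hitMissHook)(nil)

func (h *hitMissHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
	if isNewConn {
		h.misses.Add(1)
	} else {
		h.hits.Add(1)
	}
	return nil // never block the Get
}

func (h *hitMissHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
	return true, false, nil // always pool, never remove
}

Registered via ConnPool.AddPoolHook, such a hook runs on every Get and Put in registration order, alongside any other hooks.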
// PoolHookManager manages multiple pool hooks.
|
||||
type PoolHookManager struct {
|
||||
hooks []PoolHook
|
||||
hooksMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPoolHookManager creates a new pool hook manager.
|
||||
func NewPoolHookManager() *PoolHookManager {
|
||||
return &PoolHookManager{
|
||||
hooks: make([]PoolHook, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddHook adds a pool hook to the manager.
|
||||
// Hooks are called in the order they were added.
|
||||
func (phm *PoolHookManager) AddHook(hook PoolHook) {
|
||||
phm.hooksMu.Lock()
|
||||
defer phm.hooksMu.Unlock()
|
||||
phm.hooks = append(phm.hooks, hook)
|
||||
}
|
||||
|
||||
// RemoveHook removes a pool hook from the manager.
|
||||
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
|
||||
phm.hooksMu.Lock()
|
||||
defer phm.hooksMu.Unlock()
|
||||
|
||||
for i, h := range phm.hooks {
|
||||
if h == hook {
|
||||
// Remove hook by swapping with last element and truncating
|
||||
phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
|
||||
phm.hooks = phm.hooks[:len(phm.hooks)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessOnGet calls all OnGet hooks in order.
|
||||
// If any hook returns an error, processing stops and the error is returned.
|
||||
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
for _, hook := range phm.hooks {
|
||||
if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessOnPut calls all OnPut hooks in order.
|
||||
// The first hook that returns shouldRemove=true or shouldPool=false will stop processing.
|
||||
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
shouldPool = true // Default to pooling the connection
|
||||
|
||||
for _, hook := range phm.hooks {
|
||||
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
|
||||
|
||||
if hookErr != nil {
|
||||
return false, true, hookErr
|
||||
}
|
||||
|
||||
// If any hook says to remove or not pool, respect that decision
|
||||
if hookShouldRemove {
|
||||
return false, true, nil
|
||||
}
|
||||
|
||||
if !hookShouldPool {
|
||||
shouldPool = false
|
||||
}
|
||||
}
|
||||
|
||||
return shouldPool, false, nil
|
||||
}
|
||||
|
||||
// GetHookCount returns the number of registered hooks (for testing).
|
||||
func (phm *PoolHookManager) GetHookCount() int {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
return len(phm.hooks)
|
||||
}
|
||||
|
||||
// GetHooks returns a copy of all registered hooks.
|
||||
func (phm *PoolHookManager) GetHooks() []PoolHook {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
hooks := make([]PoolHook, len(phm.hooks))
|
||||
copy(hooks, phm.hooks)
|
||||
return hooks
|
||||
}
|
||||
internal/pool/hooks_test.go (new file, 213 lines)
@@ -0,0 +1,213 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestHook for testing hook functionality
|
||||
type TestHook struct {
|
||||
OnGetCalled int
|
||||
OnPutCalled int
|
||||
GetError error
|
||||
PutError error
|
||||
ShouldPool bool
|
||||
ShouldRemove bool
|
||||
}
|
||||
|
||||
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
|
||||
th.OnGetCalled++
|
||||
return th.GetError
|
||||
}
|
||||
|
||||
func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
th.OnPutCalled++
|
||||
return th.ShouldPool, th.ShouldRemove, th.PutError
|
||||
}
|
||||
|
||||
func TestPoolHookManager(t *testing.T) {
|
||||
manager := NewPoolHookManager()
|
||||
|
||||
// Test initial state
|
||||
if manager.GetHookCount() != 0 {
|
||||
t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount())
|
||||
}
|
||||
|
||||
// Add hooks
|
||||
hook1 := &TestHook{ShouldPool: true}
|
||||
hook2 := &TestHook{ShouldPool: true}
|
||||
|
||||
manager.AddHook(hook1)
|
||||
manager.AddHook(hook2)
|
||||
|
||||
if manager.GetHookCount() != 2 {
|
||||
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test ProcessOnGet
|
||||
ctx := context.Background()
|
||||
conn := &Conn{} // Mock connection
|
||||
|
||||
err := manager.ProcessOnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessOnGet should not error: %v", err)
|
||||
}
|
||||
|
||||
if hook1.OnGetCalled != 1 {
|
||||
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
|
||||
}
|
||||
|
||||
if hook2.OnGetCalled != 1 {
|
||||
t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled)
|
||||
}
|
||||
|
||||
// Test ProcessOnPut
|
||||
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessOnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
if !shouldPool {
|
||||
t.Error("Expected shouldPool to be true")
|
||||
}
|
||||
|
||||
if shouldRemove {
|
||||
t.Error("Expected shouldRemove to be false")
|
||||
}
|
||||
|
||||
if hook1.OnPutCalled != 1 {
|
||||
t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled)
|
||||
}
|
||||
|
||||
if hook2.OnPutCalled != 1 {
|
||||
t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled)
|
||||
}
|
||||
|
||||
// Remove a hook
|
||||
manager.RemoveHook(hook1)
|
||||
|
||||
if manager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookErrorHandling(t *testing.T) {
|
||||
manager := NewPoolHookManager()
|
||||
|
||||
// Hook that returns error on Get
|
||||
errorHook := &TestHook{
|
||||
GetError: errors.New("test error"),
|
||||
ShouldPool: true,
|
||||
}
|
||||
|
||||
normalHook := &TestHook{ShouldPool: true}
|
||||
|
||||
manager.AddHook(errorHook)
|
||||
manager.AddHook(normalHook)
|
||||
|
||||
ctx := context.Background()
|
||||
conn := &Conn{}
|
||||
|
||||
// Test that error stops processing
|
||||
err := manager.ProcessOnGet(ctx, conn, false)
|
||||
if err == nil {
|
||||
t.Error("Expected error from ProcessOnGet")
|
||||
}
|
||||
|
||||
if errorHook.OnGetCalled != 1 {
|
||||
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
|
||||
}
|
||||
|
||||
// normalHook should not be called due to error
|
||||
if normalHook.OnGetCalled != 0 {
|
||||
t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookShouldRemove(t *testing.T) {
|
||||
manager := NewPoolHookManager()
|
||||
|
||||
// Hook that says to remove connection
|
||||
removeHook := &TestHook{
|
||||
ShouldPool: false,
|
||||
ShouldRemove: true,
|
||||
}
|
||||
|
||||
normalHook := &TestHook{ShouldPool: true}
|
||||
|
||||
manager.AddHook(removeHook)
|
||||
manager.AddHook(normalHook)
|
||||
|
||||
ctx := context.Background()
|
||||
conn := &Conn{}
|
||||
|
||||
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessOnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
if shouldPool {
|
||||
t.Error("Expected shouldPool to be false")
|
||||
}
|
||||
|
||||
if !shouldRemove {
|
||||
t.Error("Expected shouldRemove to be true")
|
||||
}
|
||||
|
||||
if removeHook.OnPutCalled != 1 {
|
||||
t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled)
|
||||
}
|
||||
|
||||
// normalHook should not be called due to early return
|
||||
if normalHook.OnPutCalled != 0 {
|
||||
t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolWithHooks(t *testing.T) {
|
||||
// Create a pool with hooks
|
||||
hookManager := NewPoolHookManager()
|
||||
testHook := &TestHook{ShouldPool: true}
|
||||
hookManager.AddHook(testHook)
|
||||
|
||||
opt := &Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &net.TCPConn{}, nil // Mock connection
|
||||
},
|
||||
PoolSize: 1,
|
||||
DialTimeout: time.Second,
|
||||
}
|
||||
|
||||
pool := NewConnPool(opt)
|
||||
defer pool.Close()
|
||||
|
||||
// Add hook to pool after creation
|
||||
pool.AddPoolHook(testHook)
|
||||
|
||||
// Verify hooks are initialized
|
||||
if pool.hookManager == nil {
|
||||
t.Error("Expected hookManager to be initialized")
|
||||
}
|
||||
|
||||
if pool.hookManager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test adding hook to pool
|
||||
additionalHook := &TestHook{ShouldPool: true}
|
||||
pool.AddPoolHook(additionalHook)
|
||||
|
||||
if pool.hookManager.GetHookCount() != 2 {
|
||||
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test removing hook from pool
|
||||
pool.RemovePoolHook(additionalHook)
|
||||
|
||||
if pool.hookManager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -21,6 +23,23 @@ var (
|
||||
|
||||
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
|
||||
ErrPoolTimeout = errors.New("redis: connection pool timeout")
|
||||
|
||||
// popAttempts is the maximum number of attempts to find a usable connection
|
||||
// when popping from the idle connection pool. This handles cases where connections
|
||||
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).
|
||||
// Value of 50 provides sufficient resilience without excessive overhead.
|
||||
// This is capped by the idle connection count, so we won't loop excessively.
|
||||
popAttempts = 50
|
||||
|
||||
// getAttempts is the maximum number of attempts to get a connection that passes
|
||||
// hook validation (e.g., maintenanceNotifications upgrade hooks). This protects against race conditions
|
||||
// where hooks might temporarily reject connections during cluster transitions.
|
||||
// Value of 3 balances resilience with performance - most hook rejections resolve quickly.
|
||||
getAttempts = 3
|
||||
|
||||
minTime = time.Unix(-2208988800, 0) // Jan 1, 1900
|
||||
maxTime = minTime.Add(1<<63 - 1)
|
||||
noExpiration = maxTime
|
||||
)
|
||||
|
||||
var timers = sync.Pool{
|
||||
@@ -37,11 +56,14 @@ type Stats struct {
|
||||
Misses uint32 // number of times free connection was NOT found in the pool
|
||||
Timeouts uint32 // number of times a wait timeout occurred
|
||||
WaitCount uint32 // number of times a connection was waited
|
||||
Unusable uint32 // number of times a connection was found to be unusable
|
||||
WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds
|
||||
|
||||
TotalConns uint32 // number of total connections in the pool
|
||||
IdleConns uint32 // number of idle connections in the pool
|
||||
StaleConns uint32 // number of stale connections removed from the pool
|
||||
|
||||
PubSubStats PubSubStats
|
||||
}
|
||||
|
||||
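A short in-package sketch of reading the counters, including the new Unusable and WaitDurationNs fields (the log format is illustrative):

package pool

import (
	"fmt"
	"time"
)

// reportStats prints the pool counters. WaitDurationNs is a running total,
// so dividing by WaitCount gives the average time spent waiting for a turn.
func reportStats(p *ConnPool) {
	s := p.Stats()
	avgWait := time.Duration(0)
	if s.WaitCount > 0 {
		avgWait = time.Duration(s.WaitDurationNs / int64(s.WaitCount))
	}
	fmt.Printf("hits=%d misses=%d unusable=%d avg_wait=%s total=%d idle=%d\n",
		s.Hits, s.Misses, s.Unusable, avgWait, s.TotalConns, s.IdleConns)
}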
type Pooler interface {
|
||||
@@ -56,24 +78,35 @@ type Pooler interface {
|
||||
IdleLen() int
|
||||
Stats() *Stats
|
||||
|
||||
AddPoolHook(hook PoolHook)
|
||||
RemovePoolHook(hook PoolHook)
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Dialer func(context.Context) (net.Conn, error)
|
||||
|
||||
PoolFIFO bool
|
||||
PoolSize int
|
||||
DialTimeout time.Duration
|
||||
PoolTimeout time.Duration
|
||||
MinIdleConns int
|
||||
MaxIdleConns int
|
||||
MaxActiveConns int
|
||||
ConnMaxIdleTime time.Duration
|
||||
ConnMaxLifetime time.Duration
|
||||
|
||||
Dialer func(context.Context) (net.Conn, error)
|
||||
ReadBufferSize int
|
||||
WriteBufferSize int
|
||||
|
||||
PoolFIFO bool
|
||||
PoolSize int32
|
||||
DialTimeout time.Duration
|
||||
PoolTimeout time.Duration
|
||||
MinIdleConns int32
|
||||
MaxIdleConns int32
|
||||
MaxActiveConns int32
|
||||
ConnMaxIdleTime time.Duration
|
||||
ConnMaxLifetime time.Duration
|
||||
PushNotificationsEnabled bool
|
||||
|
||||
// DialerRetries is the maximum number of retry attempts when dialing fails.
|
||||
// Default: 5
|
||||
DialerRetries int
|
||||
|
||||
// DialerRetryTimeout is the backoff duration between retry attempts.
|
||||
// Default: 100ms
|
||||
DialerRetryTimeout time.Duration
|
||||
}
|
||||
|
||||
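Because PoolSize, MinIdleConns, MaxIdleConns and MaxActiveConns are now int32 and dial retries are configurable, pool construction looks roughly like the in-package sketch below (the dialer address and durations are placeholders):

package pool

import (
	"context"
	"net"
	"time"
)

// newExamplePool constructs a pool with the new int32 sizing fields and the
// dialer retry knobs.
func newExamplePool() *ConnPool {
	return NewConnPool(&Options{
		Dialer: func(ctx context.Context) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, "tcp", "localhost:6379")
		},
		PoolSize:           int32(10), // int32 after this change
		MinIdleConns:       int32(2),
		PoolTimeout:        4 * time.Second,
		DialTimeout:        5 * time.Second,
		ConnMaxIdleTime:    30 * time.Minute,
		DialerRetries:      5,                      // <= 0 falls back to the default of 5
		DialerRetryTimeout: 100 * time.Millisecond, // <= 0 falls back to the default of 100ms
	})
}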
type lastDialErrorWrap struct {
|
||||
@@ -89,16 +122,21 @@ type ConnPool struct {
|
||||
queue chan struct{}
|
||||
|
||||
connsMu sync.Mutex
|
||||
conns []*Conn
|
||||
conns map[uint64]*Conn
|
||||
idleConns []*Conn
|
||||
|
||||
poolSize int
|
||||
idleConnsLen int
|
||||
poolSize atomic.Int32
|
||||
idleConnsLen atomic.Int32
|
||||
idleCheckInProgress atomic.Bool
|
||||
|
||||
stats Stats
|
||||
waitDurationNs atomic.Int64
|
||||
|
||||
_closed uint32 // atomic
|
||||
|
||||
// Pool hooks manager for flexible connection processing
|
||||
hookManagerMu sync.RWMutex
|
||||
hookManager *PoolHookManager
|
||||
}
|
||||
|
||||
var _ Pooler = (*ConnPool)(nil)
|
||||
@@ -108,34 +146,69 @@ func NewConnPool(opt *Options) *ConnPool {
|
||||
cfg: opt,
|
||||
|
||||
queue: make(chan struct{}, opt.PoolSize),
|
||||
conns: make([]*Conn, 0, opt.PoolSize),
|
||||
conns: make(map[uint64]*Conn),
|
||||
idleConns: make([]*Conn, 0, opt.PoolSize),
|
||||
}
|
||||
|
||||
p.connsMu.Lock()
|
||||
p.checkMinIdleConns()
|
||||
p.connsMu.Unlock()
|
||||
// Only create MinIdleConns if explicitly requested (> 0)
|
||||
// This avoids creating connections during pool initialization for tests
|
||||
if opt.MinIdleConns > 0 {
|
||||
p.connsMu.Lock()
|
||||
p.checkMinIdleConns()
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// initializeHooks sets up the pool hooks system.
|
||||
func (p *ConnPool) initializeHooks() {
|
||||
p.hookManager = NewPoolHookManager()
|
||||
}
|
||||
|
||||
// AddPoolHook adds a pool hook to the pool.
|
||||
func (p *ConnPool) AddPoolHook(hook PoolHook) {
|
||||
p.hookManagerMu.Lock()
|
||||
defer p.hookManagerMu.Unlock()
|
||||
|
||||
if p.hookManager == nil {
|
||||
p.initializeHooks()
|
||||
}
|
||||
p.hookManager.AddHook(hook)
|
||||
}
|
||||
|
||||
// RemovePoolHook removes a pool hook from the pool.
|
||||
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
|
||||
p.hookManagerMu.Lock()
|
||||
defer p.hookManagerMu.Unlock()
|
||||
|
||||
if p.hookManager != nil {
|
||||
p.hookManager.RemoveHook(hook)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ConnPool) checkMinIdleConns() {
|
||||
if !p.idleCheckInProgress.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
defer p.idleCheckInProgress.Store(false)
|
||||
|
||||
if p.cfg.MinIdleConns == 0 {
|
||||
return
|
||||
}
|
||||
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
|
||||
|
||||
// Only create idle connections if we haven't reached the total pool size limit
|
||||
// MinIdleConns should be a subset of PoolSize, not additional connections
|
||||
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
|
||||
select {
|
||||
case p.queue <- struct{}{}:
|
||||
p.poolSize++
|
||||
p.idleConnsLen++
|
||||
|
||||
p.poolSize.Add(1)
|
||||
p.idleConnsLen.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
p.connsMu.Lock()
|
||||
p.poolSize--
|
||||
p.idleConnsLen--
|
||||
p.connsMu.Unlock()
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
|
||||
p.freeTurn()
|
||||
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
|
||||
@@ -144,12 +217,9 @@ func (p *ConnPool) checkMinIdleConns() {
|
||||
|
||||
err := p.addIdleConn()
|
||||
if err != nil && err != ErrClosed {
|
||||
p.connsMu.Lock()
|
||||
p.poolSize--
|
||||
p.idleConnsLen--
|
||||
p.connsMu.Unlock()
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
}
|
||||
|
||||
p.freeTurn()
|
||||
}()
|
||||
default:
|
||||
@@ -166,6 +236,9 @@ func (p *ConnPool) addIdleConn() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
@@ -176,11 +249,15 @@ func (p *ConnPool) addIdleConn() error {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
p.conns = append(p.conns, cn)
|
||||
p.conns[cn.GetID()] = cn
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewConn creates a new connection and returns it to the user.
|
||||
// This will still obey MaxActiveConns but will not include it in the pool and won't increase the pool size.
|
||||
//
|
||||
// NOTE: If you directly get a connection from the pool, it won't be pooled and won't support maintnotifications upgrades.
|
||||
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
|
||||
return p.newConn(ctx, false)
|
||||
}
|
||||
@@ -190,33 +267,44 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
p.connsMu.Lock()
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
|
||||
p.connsMu.Unlock()
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
p.connsMu.Unlock()
|
||||
|
||||
cn, err := p.dialConn(ctx, pooled)
|
||||
dialCtx, cancel := context.WithTimeout(ctx, p.cfg.DialTimeout)
|
||||
defer cancel()
|
||||
cn, err := p.dialConn(dialCtx, pooled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
|
||||
_ = cn.Close()
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
|
||||
p.conns = append(p.conns, cn)
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
if p.closed() {
|
||||
_ = cn.Close()
|
||||
return nil, ErrClosed
|
||||
}
|
||||
// Check if pool was closed while we were waiting for the lock
|
||||
if p.conns == nil {
|
||||
p.conns = make(map[uint64]*Conn)
|
||||
}
|
||||
p.conns[cn.GetID()] = cn
|
||||
|
||||
if pooled {
|
||||
// If pool is full remove the cn on next Put.
|
||||
if p.poolSize >= p.cfg.PoolSize {
|
||||
currentPoolSize := p.poolSize.Load()
|
||||
if currentPoolSize >= p.cfg.PoolSize {
|
||||
cn.pooled = false
|
||||
} else {
|
||||
p.poolSize++
|
||||
p.poolSize.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,18 +320,57 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
return nil, p.getLastDialError()
|
||||
}
|
||||
|
||||
netConn, err := p.cfg.Dialer(ctx)
|
||||
if err != nil {
|
||||
p.setLastDialError(err)
|
||||
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
|
||||
go p.tryDial()
|
||||
}
|
||||
return nil, err
|
||||
// Retry dialing with backoff
|
||||
// the context timeout is already handled by the context passed in
|
||||
// so we may never reach the max retries, higher values don't hurt
|
||||
maxRetries := p.cfg.DialerRetries
|
||||
if maxRetries <= 0 {
|
||||
maxRetries = 5 // Default value
|
||||
}
|
||||
backoffDuration := p.cfg.DialerRetryTimeout
|
||||
if backoffDuration <= 0 {
|
||||
backoffDuration = 100 * time.Millisecond // Default value
|
||||
}
|
||||
|
||||
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
|
||||
cn.pooled = pooled
|
||||
return cn, nil
|
||||
var lastErr error
|
||||
shouldLoop := true
|
||||
// when the timeout is reached, we should stop retrying
|
||||
// but keep the lastErr to return to the caller
|
||||
// instead of a generic context deadline exceeded error
|
||||
for attempt := 0; (attempt < maxRetries) && shouldLoop; attempt++ {
|
||||
netConn, err := p.cfg.Dialer(ctx)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
// Add backoff delay for retry attempts
|
||||
// (back off before retrying; at least one dial attempt is always made)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
shouldLoop = false
|
||||
case <-time.After(backoffDuration):
|
||||
// Continue with retry
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Success - create connection
|
||||
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
|
||||
cn.pooled = pooled
|
||||
if p.cfg.ConnMaxLifetime > 0 {
|
||||
cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime)
|
||||
} else {
|
||||
cn.expiresAt = noExpiration
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr)
|
||||
// All retries failed - handle error tracking
|
||||
p.setLastDialError(lastErr)
|
||||
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
|
||||
go p.tryDial()
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
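The loop above returns the last concrete dial error rather than a bare context deadline error, and it always makes at least one attempt. A test-style in-package sketch of the intended behaviour, using a deliberately flaky dialer (all names and counts are illustrative):

package pool

import (
	"context"
	"errors"
	"net"
	"testing"
	"time"
)

// TestDialRetrySketch is illustrative: the first two dial attempts fail, the
// third succeeds, so NewConn should return a connection rather than an error.
func TestDialRetrySketch(t *testing.T) {
	var calls int
	p := NewConnPool(&Options{
		Dialer: func(ctx context.Context) (net.Conn, error) {
			calls++
			if calls <= 2 {
				return nil, errors.New("transient dial failure")
			}
			return &net.TCPConn{}, nil
		},
		PoolSize:           int32(1),
		PoolTimeout:        time.Second,
		DialTimeout:        time.Second,
		DialerRetries:      5,
		DialerRetryTimeout: 10 * time.Millisecond,
	})
	defer p.Close()

	cn, err := p.NewConn(context.Background())
	if err != nil {
		t.Fatalf("expected dial to succeed after retries, got %v", err)
	}
	_ = cn.Close()
	if calls != 3 {
		t.Errorf("expected 3 dial attempts, got %d", calls)
	}
}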
func (p *ConnPool) tryDial() {
|
||||
@@ -283,6 +410,14 @@ func (p *ConnPool) getLastDialError() error {
|
||||
|
||||
// Get returns existed connection from the pool or creates a new one.
|
||||
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
return p.getConn(ctx)
|
||||
}
|
||||
|
||||
// getConn returns a connection from the pool.
|
||||
func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
var cn *Conn
|
||||
var err error
|
||||
|
||||
if p.closed() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
@@ -291,9 +426,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
attempts := 0
|
||||
for {
|
||||
if attempts >= getAttempts {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
|
||||
break
|
||||
}
|
||||
attempts++
|
||||
|
||||
p.connsMu.Lock()
|
||||
cn, err := p.popIdle()
|
||||
cn, err = p.popIdle()
|
||||
p.connsMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
@@ -305,11 +448,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
break
|
||||
}
|
||||
|
||||
if !p.isHealthyConn(cn) {
|
||||
if !p.isHealthyConn(cn, now) {
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
|
||||
// Failed to process connection, discard it
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddUint32(&p.stats.Hits, 1)
|
||||
return cn, nil
|
||||
}
|
||||
@@ -322,6 +479,19 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
|
||||
// Failed to process connection, discard it
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err)
|
||||
_ = p.CloseConn(newcn)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return newcn, nil
|
||||
}
|
||||
|
||||
@@ -350,7 +520,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
|
||||
}
|
||||
return ctx.Err()
|
||||
case p.queue <- struct{}{}:
|
||||
p.waitDurationNs.Add(time.Since(start).Nanoseconds())
|
||||
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
|
||||
atomic.AddUint32(&p.stats.WaitCount, 1)
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
@@ -370,52 +540,130 @@ func (p *ConnPool) popIdle() (*Conn, error) {
|
||||
if p.closed() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
defer p.checkMinIdleConns()
|
||||
|
||||
n := len(p.idleConns)
|
||||
if n == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var cn *Conn
|
||||
if p.cfg.PoolFIFO {
|
||||
cn = p.idleConns[0]
|
||||
copy(p.idleConns, p.idleConns[1:])
|
||||
p.idleConns = p.idleConns[:n-1]
|
||||
} else {
|
||||
idx := n - 1
|
||||
cn = p.idleConns[idx]
|
||||
p.idleConns = p.idleConns[:idx]
|
||||
attempts := 0
|
||||
|
||||
maxAttempts := util.Min(popAttempts, n)
|
||||
for attempts < maxAttempts {
|
||||
if len(p.idleConns) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.cfg.PoolFIFO {
|
||||
cn = p.idleConns[0]
|
||||
copy(p.idleConns, p.idleConns[1:])
|
||||
p.idleConns = p.idleConns[:len(p.idleConns)-1]
|
||||
} else {
|
||||
idx := len(p.idleConns) - 1
|
||||
cn = p.idleConns[idx]
|
||||
p.idleConns = p.idleConns[:idx]
|
||||
}
|
||||
attempts++
|
||||
|
||||
if cn.IsUsable() {
|
||||
p.idleConnsLen.Add(-1)
|
||||
break
|
||||
}
|
||||
|
||||
// Connection is not usable, put it back in the pool
|
||||
if p.cfg.PoolFIFO {
|
||||
// FIFO: put at end (will be picked up last since we pop from front)
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
} else {
|
||||
// LIFO: put at beginning (will be picked up last since we pop from end)
|
||||
p.idleConns = append([]*Conn{cn}, p.idleConns...)
|
||||
}
|
||||
cn = nil
|
||||
}
|
||||
p.idleConnsLen--
|
||||
p.checkMinIdleConns()
|
||||
|
||||
// If we exhausted all attempts without finding a usable connection, return nil
|
||||
if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() {
|
||||
internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
||||
if cn.rd.Buffered() > 0 {
|
||||
internal.Logger.Printf(ctx, "Conn has unread data")
|
||||
p.Remove(ctx, cn, BadConnError{})
|
||||
// Process connection using the hooks system
|
||||
shouldPool := true
|
||||
shouldRemove := false
|
||||
var err error
|
||||
|
||||
if cn.HasBufferedData() {
|
||||
// Peek at the reply type to check if it's a push notification
|
||||
if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush {
|
||||
// Not a push notification or error peeking, remove connection
|
||||
internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it")
|
||||
p.Remove(ctx, cn, err)
|
||||
}
|
||||
// It's a push notification, allow pooling (client will handle it)
|
||||
}
|
||||
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "Connection hook error: %v", err)
|
||||
p.Remove(ctx, cn, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If hooks say to remove the connection, do so
|
||||
if shouldRemove {
|
||||
p.Remove(ctx, cn, errors.New("hook requested removal"))
|
||||
return
|
||||
}
|
||||
|
||||
// If processor says not to pool the connection, remove it
|
||||
if !shouldPool {
|
||||
p.Remove(ctx, cn, errors.New("hook requested no pooling"))
|
||||
return
|
||||
}
|
||||
|
||||
if !cn.pooled {
|
||||
p.Remove(ctx, cn, nil)
|
||||
p.Remove(ctx, cn, errors.New("connection not pooled"))
|
||||
return
|
||||
}
|
||||
|
||||
var shouldCloseConn bool
|
||||
|
||||
p.connsMu.Lock()
|
||||
|
||||
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
p.idleConnsLen++
|
||||
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
|
||||
// unusable conns are expected to become usable at some point (background process is reconnecting them)
|
||||
// put them at the opposite end of the queue
|
||||
if !cn.IsUsable() {
|
||||
if p.cfg.PoolFIFO {
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
p.connsMu.Unlock()
|
||||
} else {
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append([]*Conn{cn}, p.idleConns...)
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
} else {
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
p.idleConnsLen.Add(1)
|
||||
} else {
|
||||
p.removeConn(cn)
|
||||
p.removeConnWithLock(cn)
|
||||
shouldCloseConn = true
|
||||
}
|
||||
|
||||
p.connsMu.Unlock()
|
||||
|
||||
p.freeTurn()
|
||||
|
||||
if shouldCloseConn {
|
||||
@@ -425,8 +673,13 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
||||
|
||||
func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
|
||||
p.removeConnWithLock(cn)
|
||||
|
||||
p.freeTurn()
|
||||
|
||||
_ = p.closeConn(cn)
|
||||
|
||||
// Check if we need to create new idle connections to maintain MinIdleConns
|
||||
p.checkMinIdleConns()
|
||||
}
|
||||
|
||||
func (p *ConnPool) CloseConn(cn *Conn) error {
|
||||
@@ -441,17 +694,23 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) {
|
||||
}
|
||||
|
||||
func (p *ConnPool) removeConn(cn *Conn) {
|
||||
for i, c := range p.conns {
|
||||
if c == cn {
|
||||
p.conns = append(p.conns[:i], p.conns[i+1:]...)
|
||||
if cn.pooled {
|
||||
p.poolSize--
|
||||
p.checkMinIdleConns()
|
||||
cid := cn.GetID()
|
||||
delete(p.conns, cid)
|
||||
atomic.AddUint32(&p.stats.StaleConns, 1)
|
||||
|
||||
// Decrement pool size counter when removing a connection
|
||||
if cn.pooled {
|
||||
p.poolSize.Add(-1)
|
||||
// the connection may also still be in the idle list; remove it from there as well
|
||||
for idx, ic := range p.idleConns {
|
||||
if ic.GetID() == cid {
|
||||
internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid)
|
||||
p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...)
|
||||
p.idleConnsLen.Add(-1)
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
atomic.AddUint32(&p.stats.StaleConns, 1)
|
||||
}
|
||||
|
||||
func (p *ConnPool) closeConn(cn *Conn) error {
|
||||
@@ -469,9 +728,9 @@ func (p *ConnPool) Len() int {
|
||||
// IdleLen returns number of idle connections.
|
||||
func (p *ConnPool) IdleLen() int {
|
||||
p.connsMu.Lock()
|
||||
n := p.idleConnsLen
|
||||
n := p.idleConnsLen.Load()
|
||||
p.connsMu.Unlock()
|
||||
return n
|
||||
return int(n)
|
||||
}
|
||||
|
||||
func (p *ConnPool) Stats() *Stats {
|
||||
@@ -480,6 +739,7 @@ func (p *ConnPool) Stats() *Stats {
|
||||
Misses: atomic.LoadUint32(&p.stats.Misses),
|
||||
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
|
||||
WaitCount: atomic.LoadUint32(&p.stats.WaitCount),
|
||||
Unusable: atomic.LoadUint32(&p.stats.Unusable),
|
||||
WaitDurationNs: p.waitDurationNs.Load(),
|
||||
|
||||
TotalConns: uint32(p.Len()),
|
||||
@@ -520,28 +780,45 @@ func (p *ConnPool) Close() error {
|
||||
}
|
||||
}
|
||||
p.conns = nil
|
||||
p.poolSize = 0
|
||||
p.poolSize.Store(0)
|
||||
p.idleConns = nil
|
||||
p.idleConnsLen = 0
|
||||
p.idleConnsLen.Store(0)
|
||||
p.connsMu.Unlock()
|
||||
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
|
||||
now := time.Now()
|
||||
|
||||
if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
|
||||
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
|
||||
// slight optimization, check expiresAt first.
|
||||
if cn.expiresAt.Before(now) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if connection has exceeded idle timeout
|
||||
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
|
||||
return false
|
||||
}
|
||||
|
||||
if connCheck(cn.netConn) != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
cn.SetUsedAt(now)
|
||||
// Check basic connection health
|
||||
// Use GetNetConn() to safely access netConn and avoid data races
|
||||
if err := connCheck(cn.getNetConn()); err != nil {
|
||||
// If there's unexpected data, it might be push notifications (RESP3)
|
||||
// However, push notification processing is now handled by the client
|
||||
// before WithReader to ensure proper context is available to handlers
|
||||
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
|
||||
// we know that there is something in the buffer, so peek at the next reply type without
|
||||
// the potential to block
|
||||
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
|
||||
// For RESP3 connections with push notifications, we allow some buffered data
|
||||
// The client will process these notifications before using the connection
|
||||
internal.Logger.Printf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID())
|
||||
return true // Connection is healthy, client will handle notifications
|
||||
}
|
||||
return false // Unexpected data, not push notifications, connection is unhealthy
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
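The buffered-data rule used here also applies when deciding whether an idle connection can be reused: pending bytes are acceptable only if they are a RESP3 push frame. A compact in-package restatement of that rule using the helpers added above (sketch only):

package pool

import (
	"github.com/redis/go-redis/v9/internal/proto"
)

// acceptableBufferedData restates the health-check rule: buffered bytes on an
// otherwise idle connection are tolerable only when they form a RESP3 push
// frame that the client will drain before the next command.
func acceptableBufferedData(cn *Conn) bool {
	if !cn.HasBufferedData() {
		return true // nothing pending
	}
	replyType, err := cn.PeekReplyTypeSafe()
	return err == nil && replyType == proto.RespPush
}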
@@ -1,6 +1,8 @@
package pool

import "context"
import (
	"context"
)

type SingleConnPool struct {
	pool Pooler
@@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int {
func (p *SingleConnPool) Stats() *Stats {
	return &Stats{}
}

func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}

func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}

@@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int {
func (p *StickyConnPool) Stats() *Stats {
	return &Stats{}
}

func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}

func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}

@@ -2,15 +2,17 @@ package pool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/bsm/ginkgo/v2"
|
||||
. "github.com/bsm/gomega"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
)
|
||||
|
||||
var _ = Describe("ConnPool", func() {
|
||||
@@ -20,7 +22,7 @@ var _ = Describe("ConnPool", func() {
|
||||
BeforeEach(func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 10,
|
||||
PoolSize: int32(10),
|
||||
PoolTimeout: time.Hour,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Millisecond,
|
||||
@@ -45,11 +47,11 @@ var _ = Describe("ConnPool", func() {
|
||||
<-closedChan
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 10,
|
||||
PoolSize: int32(10),
|
||||
PoolTimeout: time.Hour,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Millisecond,
|
||||
MinIdleConns: minIdleConns,
|
||||
MinIdleConns: int32(minIdleConns),
|
||||
})
|
||||
wg.Wait()
|
||||
Expect(connPool.Close()).NotTo(HaveOccurred())
|
||||
@@ -105,7 +107,7 @@ var _ = Describe("ConnPool", func() {
|
||||
// ok
|
||||
}
|
||||
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("test"))
|
||||
|
||||
// Check that Get is unblocked.
|
||||
select {
|
||||
@@ -130,8 +132,8 @@ var _ = Describe("MinIdleConns", func() {
|
||||
newConnPool := func() *pool.ConnPool {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: minIdleConns,
|
||||
PoolSize: int32(poolSize),
|
||||
MinIdleConns: int32(minIdleConns),
|
||||
PoolTimeout: 100 * time.Millisecond,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: -1,
|
||||
@@ -168,7 +170,7 @@ var _ = Describe("MinIdleConns", func() {
|
||||
|
||||
Context("after Remove", func() {
|
||||
BeforeEach(func() {
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("test"))
|
||||
})
|
||||
|
||||
It("has idle connections", func() {
|
||||
@@ -245,7 +247,7 @@ var _ = Describe("MinIdleConns", func() {
|
||||
BeforeEach(func() {
|
||||
perform(len(cns), func(i int) {
|
||||
mu.RLock()
|
||||
connPool.Remove(ctx, cns[i], nil)
|
||||
connPool.Remove(ctx, cns[i], errors.New("test"))
|
||||
mu.RUnlock()
|
||||
})
|
||||
|
||||
@@ -309,7 +311,7 @@ var _ = Describe("race", func() {
|
||||
It("does not happen on Get, Put, and Remove", func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 10,
|
||||
PoolSize: int32(10),
|
||||
PoolTimeout: time.Minute,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Millisecond,
|
||||
@@ -328,7 +330,7 @@ var _ = Describe("race", func() {
|
||||
cn, err := connPool.Get(ctx)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
if err == nil {
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("test"))
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -339,15 +341,15 @@ var _ = Describe("race", func() {
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 1000,
|
||||
MinIdleConns: 50,
|
||||
PoolSize: int32(1000),
|
||||
MinIdleConns: int32(50),
|
||||
PoolTimeout: 3 * time.Second,
|
||||
DialTimeout: 1 * time.Second,
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < opt.PoolSize; i++ {
|
||||
for i := int32(0); i < opt.PoolSize; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
@@ -366,8 +368,8 @@ var _ = Describe("race", func() {
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
panic("test panic")
|
||||
},
|
||||
PoolSize: 100,
|
||||
MinIdleConns: 30,
|
||||
PoolSize: int32(100),
|
||||
MinIdleConns: int32(30),
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
|
||||
@@ -377,14 +379,14 @@ var _ = Describe("race", func() {
|
||||
state := p.Stats()
|
||||
return state.TotalConns == 0 && state.IdleConns == 0 && p.QueueLen() == 0
|
||||
}, "3s", "50ms").Should(BeTrue())
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
It("wait", func() {
|
||||
opt := &pool.Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 3 * time.Second,
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
@@ -415,7 +417,7 @@ var _ = Describe("race", func() {
|
||||
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: testPoolTimeout,
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
@@ -435,3 +437,73 @@ var _ = Describe("race", func() {
|
||||
Expect(stats.Timeouts).To(Equal(uint32(1)))
|
||||
})
|
||||
})
|
||||
|
||||
// TestDialerRetryConfiguration tests the new DialerRetries and DialerRetryTimeout options
|
||||
func TestDialerRetryConfiguration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CustomDialerRetries", func(t *testing.T) {
|
||||
var attempts int64
|
||||
failingDialer := func(ctx context.Context) (net.Conn, error) {
|
||||
atomic.AddInt64(&attempts, 1)
|
||||
return nil, errors.New("dial failed")
|
||||
}
|
||||
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: failingDialer,
|
||||
PoolSize: 1,
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: time.Second,
|
||||
DialerRetries: 3, // Custom retry count
|
||||
DialerRetryTimeout: 10 * time.Millisecond, // Fast retries for testing
|
||||
})
|
||||
defer connPool.Close()
|
||||
|
||||
_, err := connPool.Get(ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected error from failing dialer")
|
||||
}
|
||||
|
||||
// Should have attempted at least 3 times (DialerRetries = 3)
|
||||
// There might be additional attempts due to pool logic
|
||||
finalAttempts := atomic.LoadInt64(&attempts)
|
||||
if finalAttempts < 3 {
|
||||
t.Errorf("Expected at least 3 dial attempts, got %d", finalAttempts)
|
||||
}
|
||||
if finalAttempts > 6 {
|
||||
t.Errorf("Expected around 3 dial attempts, got %d (too many)", finalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DefaultDialerRetries", func(t *testing.T) {
|
||||
var attempts int64
|
||||
failingDialer := func(ctx context.Context) (net.Conn, error) {
|
||||
atomic.AddInt64(&attempts, 1)
|
||||
return nil, errors.New("dial failed")
|
||||
}
|
||||
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: failingDialer,
|
||||
PoolSize: 1,
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: time.Second,
|
||||
// DialerRetries and DialerRetryTimeout not set - should use defaults
|
||||
})
|
||||
defer connPool.Close()
|
||||
|
||||
_, err := connPool.Get(ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected error from failing dialer")
|
||||
}
|
||||
|
||||
// Should have attempted 5 times (default DialerRetries = 5)
|
||||
finalAttempts := atomic.LoadInt64(&attempts)
|
||||
if finalAttempts != 5 {
|
||||
t.Errorf("Expected 5 dial attempts (default), got %d", finalAttempts)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func init() {
|
||||
logging.Disable()
|
||||
}
|
||||
|
||||
78
internal/pool/pubsub.go
Normal file
78
internal/pool/pubsub.go
Normal file
@@ -0,0 +1,78 @@
package pool

import (
	"context"
	"net"
	"sync"
	"sync/atomic"
)

type PubSubStats struct {
	Created   uint32
	Untracked uint32
	Active    uint32
}

// PubSubPool manages a pool of PubSub connections.
type PubSubPool struct {
	opt       *Options
	netDialer func(ctx context.Context, network, addr string) (net.Conn, error)

	// Map to track active PubSub connections
	activeConns sync.Map // map[uint64]*Conn (connID -> conn)
	closed      atomic.Bool
	stats       PubSubStats
}

func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
	return &PubSubPool{
		opt:       opt,
		netDialer: netDialer,
	}
}

func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
	if p.closed.Load() {
		return nil, ErrClosed
	}

	netConn, err := p.netDialer(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
	cn.pubsub = true
	atomic.AddUint32(&p.stats.Created, 1)
	return cn, nil
}

func (p *PubSubPool) TrackConn(cn *Conn) {
	atomic.AddUint32(&p.stats.Active, 1)
	p.activeConns.Store(cn.GetID(), cn)
}

func (p *PubSubPool) UntrackConn(cn *Conn) {
	atomic.AddUint32(&p.stats.Active, ^uint32(0))
	atomic.AddUint32(&p.stats.Untracked, 1)
	p.activeConns.Delete(cn.GetID())
}

func (p *PubSubPool) Close() error {
	p.closed.Store(true)
	p.activeConns.Range(func(key, value interface{}) bool {
		cn := value.(*Conn)
		_ = cn.Close()
		return true
	})
	return nil
}

func (p *PubSubPool) Stats() *PubSubStats {
	// load stats atomically
	return &PubSubStats{
		Created:   atomic.LoadUint32(&p.stats.Created),
		Untracked: atomic.LoadUint32(&p.stats.Untracked),
		Active:    atomic.LoadUint32(&p.stats.Active),
	}
}
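A short usage sketch for the PubSubPool above (illustrative: the buffer sizes, address, and error handling are placeholders, and real callers wire the dialer through the client options):

// Sketch: create a PubSubPool, dial a tracked pub/sub connection, and read stats.
pubsubPool := pool.NewPubSubPool(&pool.Options{
	ReadBufferSize:  32 * 1024, // assumed option values; see NewConnWithBufferSize above
	WriteBufferSize: 32 * 1024,
}, func(ctx context.Context, network, addr string) (net.Conn, error) {
	var d net.Dialer
	return d.DialContext(ctx, network, addr)
})

cn, err := pubsubPool.NewConn(ctx, "tcp", "localhost:6379", nil)
if err != nil {
	// handle dial error
}
pubsubPool.TrackConn(cn)         // count the subscriber as active
defer pubsubPool.UntrackConn(cn) // and untrack it when done

fmt.Printf("pubsub stats: %+v\n", pubsubPool.Stats())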
614
internal/proto/peek_push_notification_test.go
Normal file
614
internal/proto/peek_push_notification_test.go
Normal file
@@ -0,0 +1,614 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPeekPushNotificationName tests the updated PeekPushNotificationName method
|
||||
func TestPeekPushNotificationName(t *testing.T) {
|
||||
t.Run("ValidPushNotifications", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
notification string
|
||||
expected string
|
||||
}{
|
||||
{"MOVING", "MOVING", "MOVING"},
|
||||
{"MIGRATING", "MIGRATING", "MIGRATING"},
|
||||
{"MIGRATED", "MIGRATED", "MIGRATED"},
|
||||
{"FAILING_OVER", "FAILING_OVER", "FAILING_OVER"},
|
||||
{"FAILED_OVER", "FAILED_OVER", "FAILED_OVER"},
|
||||
{"message", "message", "message"},
|
||||
{"pmessage", "pmessage", "pmessage"},
|
||||
{"subscribe", "subscribe", "subscribe"},
|
||||
{"unsubscribe", "unsubscribe", "unsubscribe"},
|
||||
{"psubscribe", "psubscribe", "psubscribe"},
|
||||
{"punsubscribe", "punsubscribe", "punsubscribe"},
|
||||
{"smessage", "smessage", "smessage"},
|
||||
{"ssubscribe", "ssubscribe", "ssubscribe"},
|
||||
{"sunsubscribe", "sunsubscribe", "sunsubscribe"},
|
||||
{"custom", "custom", "custom"},
|
||||
{"short", "a", "a"},
|
||||
{"empty", "", ""},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
buf := createValidPushNotification(tc.notification, "data")
|
||||
reader := NewReader(buf)
|
||||
|
||||
// Prime the buffer by peeking first
|
||||
_, _ = reader.rd.Peek(1)
|
||||
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should not error for valid notification: %v", err)
|
||||
}
|
||||
|
||||
if name != tc.expected {
|
||||
t.Errorf("Expected notification name '%s', got '%s'", tc.expected, name)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NotificationWithMultipleArguments", func(t *testing.T) {
|
||||
// Create push notification with multiple arguments
|
||||
buf := createPushNotificationWithArgs("MOVING", "slot", "123", "from", "node1", "to", "node2")
|
||||
reader := NewReader(buf)
|
||||
|
||||
// Prime the buffer
|
||||
_, _ = reader.rd.Peek(1)
|
||||
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should not error: %v", err)
|
||||
}
|
||||
|
||||
if name != "MOVING" {
|
||||
t.Errorf("Expected 'MOVING', got '%s'", name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SingleElementNotification", func(t *testing.T) {
|
||||
// Create push notification with single element
|
||||
buf := createSingleElementPushNotification("TEST")
|
||||
reader := NewReader(buf)
|
||||
|
||||
// Prime the buffer
|
||||
_, _ = reader.rd.Peek(1)
|
||||
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should not error: %v", err)
|
||||
}
|
||||
|
||||
if name != "TEST" {
|
||||
t.Errorf("Expected 'TEST', got '%s'", name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ErrorDetection", func(t *testing.T) {
|
||||
t.Run("NotPushNotification", func(t *testing.T) {
|
||||
// Test with regular array instead of push notification
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("*2\r\n$6\r\nMOVING\r\n$4\r\ndata\r\n")
|
||||
reader := NewReader(buf)
|
||||
|
||||
_, err := reader.PeekPushNotificationName()
|
||||
if err == nil {
|
||||
t.Error("PeekPushNotificationName should error for non-push notification")
|
||||
}
|
||||
|
||||
// The error might be "no data available" or "can't parse push notification"
|
||||
if !strings.Contains(err.Error(), "can't peek push notification name") {
|
||||
t.Errorf("Error should mention push notification parsing, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsufficientData", func(t *testing.T) {
|
||||
// Test with buffer smaller than peek size - this might panic due to bounds checking
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(">")
|
||||
reader := NewReader(buf)
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("PeekPushNotificationName panicked as expected for insufficient data: %v", r)
|
||||
}
|
||||
}()
|
||||
_, err := reader.PeekPushNotificationName()
|
||||
if err == nil {
|
||||
t.Error("PeekPushNotificationName should error for insufficient data")
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
t.Run("EmptyBuffer", func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
reader := NewReader(buf)
|
||||
|
||||
_, err := reader.PeekPushNotificationName()
|
||||
if err == nil {
|
||||
t.Error("PeekPushNotificationName should error for empty buffer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DifferentRESPTypes", func(t *testing.T) {
|
||||
// Test with different RESP types that should be rejected
|
||||
respTypes := []byte{'+', '-', ':', '$', '*', '%', '~', '|', '('}
|
||||
|
||||
for _, respType := range respTypes {
|
||||
t.Run(fmt.Sprintf("Type_%c", respType), func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteByte(respType)
|
||||
buf.WriteString("test data that fills the buffer completely")
|
||||
reader := NewReader(buf)
|
||||
|
||||
_, err := reader.PeekPushNotificationName()
|
||||
if err == nil {
|
||||
t.Errorf("PeekPushNotificationName should error for RESP type '%c'", respType)
|
||||
}
|
||||
|
||||
// The error might be "no data available" or "can't parse push notification"
|
||||
if !strings.Contains(err.Error(), "can't peek push notification name") {
|
||||
t.Errorf("Error should mention push notification parsing, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("EdgeCases", func(t *testing.T) {
|
||||
t.Run("ZeroLengthArray", func(t *testing.T) {
|
||||
// Create push notification with zero elements: >0\r\n
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(">0\r\npadding_data_to_fill_buffer_completely")
|
||||
reader := NewReader(buf)
|
||||
|
||||
_, err := reader.PeekPushNotificationName()
|
||||
if err == nil {
|
||||
t.Error("PeekPushNotificationName should error for zero-length array")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyNotificationName", func(t *testing.T) {
|
||||
// Create push notification with empty name: >1\r\n$0\r\n\r\n
|
||||
buf := createValidPushNotification("", "data")
|
||||
reader := NewReader(buf)
|
||||
|
||||
// Prime the buffer
|
||||
_, _ = reader.rd.Peek(1)
|
||||
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should not error for empty name: %v", err)
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
t.Errorf("Expected empty notification name, got '%s'", name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CorruptedData", func(t *testing.T) {
|
||||
corruptedCases := []struct {
|
||||
name string
|
||||
data string
|
||||
}{
|
||||
{"CorruptedLength", ">abc\r\n$6\r\nMOVING\r\n"},
|
||||
{"MissingCRLF", ">2$6\r\nMOVING\r\n$4\r\ndata\r\n"},
|
||||
{"InvalidStringLength", ">2\r\n$abc\r\nMOVING\r\n$4\r\ndata\r\n"},
|
||||
{"NegativeStringLength", ">2\r\n$-1\r\n$4\r\ndata\r\n"},
|
||||
{"IncompleteString", ">1\r\n$6\r\nMOV"},
|
||||
}
|
||||
|
||||
for _, tc := range corruptedCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(tc.data)
|
||||
reader := NewReader(buf)
|
||||
|
||||
// Some corrupted data might not error but return unexpected results
|
||||
// This is acceptable behavior for malformed input
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Logf("PeekPushNotificationName errored for corrupted data %s: %v (DATA: %s)", tc.name, err, tc.data)
|
||||
} else {
|
||||
t.Logf("PeekPushNotificationName returned '%s' for corrupted data NAME: %s, DATA: %s", name, tc.name, tc.data)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("BoundaryConditions", func(t *testing.T) {
|
||||
t.Run("ExactlyPeekSize", func(t *testing.T) {
|
||||
// Create buffer that is exactly 36 bytes (the peek window size)
|
||||
buf := &bytes.Buffer{}
|
||||
// ">1\r\n$4\r\nTEST\r\n" = 14 bytes, need 22 more
|
||||
buf.WriteString(">1\r\n$4\r\nTEST\r\n1234567890123456789012")
|
||||
if buf.Len() != 36 {
|
||||
t.Errorf("Expected buffer length 36, got %d", buf.Len())
|
||||
}
|
||||
|
||||
reader := NewReader(buf)
|
||||
// Prime the buffer
|
||||
_, _ = reader.rd.Peek(1)
|
||||
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should work for exact peek size: %v", err)
|
||||
}
|
||||
|
||||
if name != "TEST" {
|
||||
t.Errorf("Expected 'TEST', got '%s'", name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LessThanPeekSize", func(t *testing.T) {
|
||||
// Create buffer smaller than 36 bytes but with complete notification
|
||||
buf := createValidPushNotification("TEST", "")
|
||||
reader := NewReader(buf)
|
||||
|
||||
// Prime the buffer
|
||||
_, _ = reader.rd.Peek(1)
|
||||
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should work for complete notification: %v", err)
|
||||
}
|
||||
|
||||
if name != "TEST" {
|
||||
t.Errorf("Expected 'TEST', got '%s'", name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LongNotificationName", func(t *testing.T) {
|
||||
// Test with notification name that might exceed peek window
|
||||
longName := strings.Repeat("A", 20) // 20 character name (safe size)
|
||||
buf := createValidPushNotification(longName, "data")
|
||||
reader := NewReader(buf)
|
||||
|
||||
// Prime the buffer
|
||||
_, _ = reader.rd.Peek(1)
|
||||
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should work for long name: %v", err)
|
||||
}
|
||||
|
||||
if name != longName {
|
||||
t.Errorf("Expected '%s', got '%s'", longName, name)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions to create test data
|
||||
|
||||
// createValidPushNotification creates a valid RESP3 push notification
|
||||
func createValidPushNotification(notificationName, data string) *bytes.Buffer {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
simpleOrString := rand.Intn(2) == 0
|
||||
|
||||
if data == "" {
|
||||
|
||||
// Single element notification
|
||||
buf.WriteString(">1\r\n")
|
||||
if simpleOrString {
|
||||
buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
|
||||
} else {
|
||||
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
|
||||
}
|
||||
	} else {
		// Two element notification
		buf.WriteString(">2\r\n")
		if simpleOrString {
			buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
			buf.WriteString(fmt.Sprintf("+%s\r\n", data))
		} else {
			buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
			// second element carries the data payload, not the name again
			buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(data), data))
		}
	}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// createReaderWithPrimedBuffer creates a reader and primes the buffer
|
||||
func createReaderWithPrimedBuffer(buf *bytes.Buffer) *Reader {
|
||||
reader := NewReader(buf)
|
||||
// Prime the buffer by peeking first
|
||||
_, _ = reader.rd.Peek(1)
|
||||
return reader
|
||||
}
|
||||
|
||||
// createPushNotificationWithArgs creates a push notification with multiple arguments
|
||||
func createPushNotificationWithArgs(notificationName string, args ...string) *bytes.Buffer {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
totalElements := 1 + len(args)
|
||||
buf.WriteString(fmt.Sprintf(">%d\r\n", totalElements))
|
||||
|
||||
// Write notification name
|
||||
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
|
||||
|
||||
// Write arguments
|
||||
for _, arg := range args {
|
||||
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(arg), arg))
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// createSingleElementPushNotification creates a push notification with single element
|
||||
func createSingleElementPushNotification(notificationName string) *bytes.Buffer {
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(">1\r\n")
|
||||
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
|
||||
return buf
|
||||
}
|
||||
|
||||
// BenchmarkPeekPushNotificationName benchmarks the method performance
|
||||
func BenchmarkPeekPushNotificationName(b *testing.B) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
notification string
|
||||
}{
|
||||
{"Short", "TEST"},
|
||||
{"Medium", "MOVING_NOTIFICATION"},
|
||||
{"Long", "VERY_LONG_NOTIFICATION_NAME_FOR_TESTING"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
buf := createValidPushNotification(tc.notification, "data")
|
||||
data := buf.Bytes()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
reader := NewReader(bytes.NewReader(data))
|
||||
_, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
b.Errorf("PeekPushNotificationName should not error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPeekPushNotificationNameSpecialCases tests special cases and realistic scenarios
|
||||
func TestPeekPushNotificationNameSpecialCases(t *testing.T) {
|
||||
t.Run("RealisticNotifications", func(t *testing.T) {
|
||||
// Test realistic Redis push notifications
|
||||
realisticCases := []struct {
|
||||
name string
|
||||
notification []string
|
||||
expected string
|
||||
}{
|
||||
{"MovingSlot", []string{"MOVING", "slot", "123", "from", "127.0.0.1:7000", "to", "127.0.0.1:7001"}, "MOVING"},
|
||||
{"MigratingSlot", []string{"MIGRATING", "slot", "456", "from", "127.0.0.1:7001", "to", "127.0.0.1:7002"}, "MIGRATING"},
|
||||
{"MigratedSlot", []string{"MIGRATED", "slot", "789", "from", "127.0.0.1:7002", "to", "127.0.0.1:7000"}, "MIGRATED"},
|
||||
{"FailingOver", []string{"FAILING_OVER", "node", "127.0.0.1:7000"}, "FAILING_OVER"},
|
||||
{"FailedOver", []string{"FAILED_OVER", "node", "127.0.0.1:7000"}, "FAILED_OVER"},
|
||||
{"PubSubMessage", []string{"message", "mychannel", "hello world"}, "message"},
|
||||
{"PubSubPMessage", []string{"pmessage", "pattern*", "mychannel", "hello world"}, "pmessage"},
|
||||
{"Subscribe", []string{"subscribe", "mychannel", "1"}, "subscribe"},
|
||||
{"Unsubscribe", []string{"unsubscribe", "mychannel", "0"}, "unsubscribe"},
|
||||
}
|
||||
|
||||
for _, tc := range realisticCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
buf := createPushNotificationWithArgs(tc.notification[0], tc.notification[1:]...)
|
||||
reader := createReaderWithPrimedBuffer(buf)
|
||||
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should not error for %s: %v", tc.name, err)
|
||||
}
|
||||
|
||||
if name != tc.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tc.expected, name)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SpecialCharactersInName", func(t *testing.T) {
|
||||
specialCases := []struct {
|
||||
name string
|
||||
notification string
|
||||
}{
|
||||
{"WithUnderscore", "test_notification"},
|
||||
{"WithDash", "test-notification"},
|
||||
{"WithNumbers", "test123"},
|
||||
{"WithDots", "test.notification"},
|
||||
{"WithColon", "test:notification"},
|
||||
{"WithSlash", "test/notification"},
|
||||
{"MixedCase", "TestNotification"},
|
||||
{"AllCaps", "TESTNOTIFICATION"},
|
||||
{"AllLower", "testnotification"},
|
||||
{"Unicode", "tëst"},
|
||||
}
|
||||
|
||||
for _, tc := range specialCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
buf := createValidPushNotification(tc.notification, "data")
|
||||
reader := createReaderWithPrimedBuffer(buf)
|
||||
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should not error for '%s': %v", tc.notification, err)
|
||||
}
|
||||
|
||||
if name != tc.notification {
|
||||
t.Errorf("Expected '%s', got '%s'", tc.notification, name)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IdempotentPeek", func(t *testing.T) {
|
||||
// Test that multiple peeks return the same result
|
||||
buf := createValidPushNotification("MOVING", "data")
|
||||
reader := createReaderWithPrimedBuffer(buf)
|
||||
|
||||
// First peek
|
||||
name1, err1 := reader.PeekPushNotificationName()
|
||||
if err1 != nil {
|
||||
t.Errorf("First PeekPushNotificationName should not error: %v", err1)
|
||||
}
|
||||
|
||||
// Second peek should return the same result
|
||||
name2, err2 := reader.PeekPushNotificationName()
|
||||
if err2 != nil {
|
||||
t.Errorf("Second PeekPushNotificationName should not error: %v", err2)
|
||||
}
|
||||
|
||||
if name1 != name2 {
|
||||
t.Errorf("Peek should be idempotent: first='%s', second='%s'", name1, name2)
|
||||
}
|
||||
|
||||
if name1 != "MOVING" {
|
||||
t.Errorf("Expected 'MOVING', got '%s'", name1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPeekPushNotificationNamePerformance tests performance characteristics
|
||||
func TestPeekPushNotificationNamePerformance(t *testing.T) {
|
||||
t.Run("RepeatedCalls", func(t *testing.T) {
|
||||
// Test that repeated calls work correctly
|
||||
buf := createValidPushNotification("TEST", "data")
|
||||
reader := createReaderWithPrimedBuffer(buf)
|
||||
|
||||
// Call multiple times
|
||||
for i := 0; i < 10; i++ {
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should not error on call %d: %v", i, err)
|
||||
}
|
||||
if name != "TEST" {
|
||||
t.Errorf("Expected 'TEST' on call %d, got '%s'", i, name)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LargeNotifications", func(t *testing.T) {
|
||||
// Test with large notification data
|
||||
largeData := strings.Repeat("x", 1000)
|
||||
buf := createValidPushNotification("LARGE", largeData)
|
||||
reader := createReaderWithPrimedBuffer(buf)
|
||||
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should not error for large notification: %v", err)
|
||||
}
|
||||
|
||||
if name != "LARGE" {
|
||||
t.Errorf("Expected 'LARGE', got '%s'", name)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPeekPushNotificationNameBehavior documents the method's behavior
|
||||
func TestPeekPushNotificationNameBehavior(t *testing.T) {
|
||||
t.Run("MethodBehavior", func(t *testing.T) {
|
||||
// Test that the method works as intended:
|
||||
// 1. Peek at the buffer without consuming it
|
||||
// 2. Detect push notifications (RESP type '>')
|
||||
// 3. Extract the notification name from the first element
|
||||
// 4. Return the name for filtering decisions
|
||||
|
||||
buf := createValidPushNotification("MOVING", "slot_data")
|
||||
reader := createReaderWithPrimedBuffer(buf)
|
||||
|
||||
// Peek should not consume the buffer
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should not error: %v", err)
|
||||
}
|
||||
|
||||
if name != "MOVING" {
|
||||
t.Errorf("Expected 'MOVING', got '%s'", name)
|
||||
}
|
||||
|
||||
// Buffer should still be available for normal reading
|
||||
replyType, err := reader.PeekReplyType()
|
||||
if err != nil {
|
||||
t.Errorf("PeekReplyType should work after PeekPushNotificationName: %v", err)
|
||||
}
|
||||
|
||||
if replyType != RespPush {
|
||||
t.Errorf("Expected RespPush, got %v", replyType)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BufferNotConsumed", func(t *testing.T) {
|
||||
// Verify that peeking doesn't consume the buffer
|
||||
buf := createValidPushNotification("TEST", "data")
|
||||
originalData := buf.Bytes()
|
||||
reader := createReaderWithPrimedBuffer(buf)
|
||||
|
||||
// Peek the notification name
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Errorf("PeekPushNotificationName should not error: %v", err)
|
||||
}
|
||||
|
||||
if name != "TEST" {
|
||||
t.Errorf("Expected 'TEST', got '%s'", name)
|
||||
}
|
||||
|
||||
// Read the actual notification
|
||||
reply, err := reader.ReadReply()
|
||||
if err != nil {
|
||||
t.Errorf("ReadReply should work after peek: %v", err)
|
||||
}
|
||||
|
||||
// Verify we got the complete notification
|
||||
if replySlice, ok := reply.([]interface{}); ok {
|
||||
if len(replySlice) != 2 {
|
||||
t.Errorf("Expected 2 elements, got %d", len(replySlice))
|
||||
}
|
||||
if replySlice[0] != "TEST" {
|
||||
t.Errorf("Expected 'TEST', got %v", replySlice[0])
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Expected slice reply, got %T", reply)
|
||||
}
|
||||
|
||||
// Verify buffer was properly consumed
|
||||
if buf.Len() != 0 {
|
||||
t.Errorf("Buffer should be empty after reading, but has %d bytes: %q", buf.Len(), buf.Bytes())
|
||||
}
|
||||
|
||||
t.Logf("Original buffer size: %d bytes", len(originalData))
|
||||
t.Logf("Successfully peeked and then read complete notification")
|
||||
})
|
||||
|
||||
t.Run("ImplementationSuccess", func(t *testing.T) {
|
||||
// Document that the implementation is now working correctly
|
||||
t.Log("PeekPushNotificationName implementation status:")
|
||||
t.Log("1. ✅ Correctly parses RESP3 push notifications")
|
||||
t.Log("2. ✅ Extracts notification names properly")
|
||||
t.Log("3. ✅ Handles buffer peeking without consumption")
|
||||
t.Log("4. ✅ Works with various notification types")
|
||||
t.Log("5. ✅ Supports empty notification names")
|
||||
t.Log("")
|
||||
t.Log("RESP3 format parsing:")
|
||||
t.Log(">2\\r\\n$6\\r\\nMOVING\\r\\n$4\\r\\ndata\\r\\n")
|
||||
t.Log("✅ Correctly identifies push notification marker (>)")
|
||||
t.Log("✅ Skips array length (2)")
|
||||
t.Log("✅ Parses string marker ($) and length (6)")
|
||||
t.Log("✅ Extracts notification name (MOVING)")
|
||||
t.Log("✅ Returns name without consuming buffer")
|
||||
t.Log("")
|
||||
t.Log("Note: Buffer must be primed with a peek operation first")
|
||||
})
|
||||
}
|
||||
@@ -99,6 +99,92 @@ func (r *Reader) PeekReplyType() (byte, error) {
	return b[0], nil
}

func (r *Reader) PeekPushNotificationName() (string, error) {
	// "prime" the buffer by peeking at the next byte
	c, err := r.Peek(1)
	if err != nil {
		return "", err
	}
	if c[0] != RespPush {
		return "", fmt.Errorf("redis: can't peek push notification name, next reply is not a push notification")
	}

	// peek 36 bytes at most, should be enough to read the push notification name
	toPeek := 36
	buffered := r.Buffered()
	if buffered == 0 {
		return "", fmt.Errorf("redis: can't peek push notification name, no data available")
	}
	if buffered < toPeek {
		toPeek = buffered
	}
	buf, err := r.rd.Peek(toPeek)
	if err != nil {
		return "", err
	}
	if buf[0] != RespPush {
		return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
	}

	if len(buf) < 3 {
		return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
	}

	// remove push notification type
	buf = buf[1:]
	// remove first line - the element count, e.g. 2\r\n
	for i := 0; i < len(buf)-1; i++ {
		if buf[i] == '\r' && buf[i+1] == '\n' {
			buf = buf[i+2:]
			break
		} else {
			if buf[i] < '0' || buf[i] > '9' {
				return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
			}
		}
	}
	if len(buf) < 2 {
		return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
	}
	// next line should be $<length>\r\n<name>\r\n (bulk string) or +<name>\r\n (simple string),
	// i.e. the type of the push notification name and, for bulk strings, its length
	if buf[0] != RespString && buf[0] != RespStatus {
		return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
	}
	typeOfName := buf[0]
	// remove the type of the push notification name
	buf = buf[1:]
	if typeOfName == RespString {
		// remove the length of the string
		if len(buf) < 2 {
			return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
		}
		for i := 0; i < len(buf)-1; i++ {
			if buf[i] == '\r' && buf[i+1] == '\n' {
				buf = buf[i+2:]
				break
			} else {
				if buf[i] < '0' || buf[i] > '9' {
					return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
				}
			}
		}
	}

	if len(buf) < 2 {
		return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
	}
	// keep only the notification name
	for i := 0; i < len(buf)-1; i++ {
		if buf[i] == '\r' && buf[i+1] == '\n' {
			buf = buf[:i]
			break
		}
	}

	return util.BytesToString(buf), nil
}

// ReadLine Return a valid reply, it will check the protocol or redis error,
// and discard the attribute type.
func (r *Reader) ReadLine() ([]byte, error) {

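For reference, the frame shape this parser peeks at, plus a test-style sketch (the proto package is internal, so this mirrors the package-local tests earlier in this diff rather than external usage):

// Wire form of a two-element push notification:
//   >2\r\n$6\r\nMOVING\r\n$4\r\ndata\r\n
rd := NewReader(bytes.NewBufferString(">2\r\n$6\r\nMOVING\r\n$4\r\ndata\r\n"))

name, err := rd.PeekPushNotificationName()
// name == "MOVING", err == nil. The frame is only peeked, not consumed, so a
// following rd.ReadReply() still returns the full []interface{}{"MOVING", "data"}.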
3
internal/redis.go
Normal file
3
internal/redis.go
Normal file
@@ -0,0 +1,3 @@
package internal

const RedisNull = "<nil>"
@@ -28,3 +28,14 @@ func MustParseFloat(s string) float64 {
	}
	return f
}

// SafeIntToInt32 safely converts an int to int32, returning an error if overflow would occur.
func SafeIntToInt32(value int, fieldName string) (int32, error) {
	if value > math.MaxInt32 {
		return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32)
	}
	if value < math.MinInt32 {
		return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32)
	}
	return int32(value), nil
}
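A hedged example of how this helper is meant to be used when narrowing user-supplied options into the pool's int32 fields (the option and field names here are illustrative):

// Sketch: validate a user-provided pool size before storing it in an int32 counter.
poolSize32, err := SafeIntToInt32(opt.PoolSize, "PoolSize")
if err != nil {
	return err // e.g. "redis: PoolSize value 3000000000 exceeds maximum allowed value 2147483647"
}
p.poolSize.Store(poolSize32)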
17
internal/util/math.go
Normal file
17
internal/util/math.go
Normal file
@@ -0,0 +1,17 @@
package util

// Max returns the maximum of two integers
func Max(a, b int) int {
	if a > b {
		return a
	}
	return b
}

// Min returns the minimum of two integers
func Min(a, b int) int {
	if a < b {
		return a
	}
	return b
}
@@ -16,6 +16,8 @@ import (
|
||||
. "github.com/bsm/gomega"
|
||||
)
|
||||
|
||||
var ctx = context.TODO()
|
||||
|
||||
var _ = Describe("newClusterState", func() {
|
||||
var state *clusterState
|
||||
|
||||
|
||||
91
logging/logging.go
Normal file
91
logging/logging.go
Normal file
@@ -0,0 +1,91 @@
|
||||
// Package logging provides logging level constants and utilities for the go-redis library.
|
||||
// This package centralizes logging configuration to ensure consistency across all components.
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
)
|
||||
|
||||
type LogLevelT = internal.LogLevelT
|
||||
|
||||
const (
|
||||
LogLevelError = internal.LogLevelError
|
||||
LogLevelWarn = internal.LogLevelWarn
|
||||
LogLevelInfo = internal.LogLevelInfo
|
||||
LogLevelDebug = internal.LogLevelDebug
|
||||
)
|
||||
|
||||
// VoidLogger is a logger that does nothing.
|
||||
// Used to disable logging and thus speed up the library.
|
||||
type VoidLogger struct{}
|
||||
|
||||
func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
// Disable disables logging by setting the internal logger to a void logger.
|
||||
// This can be used to speed up the library if logging is not needed.
|
||||
// It will override any custom logger that was set before and set the VoidLogger.
|
||||
func Disable() {
|
||||
internal.Logger = &VoidLogger{}
|
||||
}
|
||||
|
||||
// Enable enables logging by setting the internal logger to the default logger.
|
||||
// This is the default behavior.
|
||||
// You can use redis.SetLogger to set a custom logger.
|
||||
//
|
||||
// NOTE: This function is not thread-safe.
|
||||
// It will override any custom logger that was set before and set the DefaultLogger.
|
||||
func Enable() {
|
||||
internal.Logger = internal.NewDefaultLogger()
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the library.
|
||||
func SetLogLevel(logLevel LogLevelT) {
|
||||
internal.LogLevel = logLevel
|
||||
}
|
||||
|
||||
// NewBlacklistLogger returns a new logger that filters out messages containing any of the substrings.
|
||||
// This can be used to filter out messages containing sensitive information.
|
||||
func NewBlacklistLogger(substr []string) internal.Logging {
|
||||
l := internal.NewDefaultLogger()
|
||||
return &filterLogger{logger: l, substr: substr, blacklist: true}
|
||||
}
|
||||
|
||||
// NewWhitelistLogger returns a new logger that only logs messages containing any of the substrings.
|
||||
// This can be used to only log messages related to specific commands or patterns.
|
||||
func NewWhitelistLogger(substr []string) internal.Logging {
|
||||
l := internal.NewDefaultLogger()
|
||||
return &filterLogger{logger: l, substr: substr, blacklist: false}
|
||||
}
|
||||
|
||||
type filterLogger struct {
|
||||
logger internal.Logging
|
||||
blacklist bool
|
||||
substr []string
|
||||
}
|
||||
|
||||
func (l *filterLogger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
found := false
|
||||
for _, substr := range l.substr {
|
||||
if strings.Contains(msg, substr) {
|
||||
found = true
|
||||
if l.blacklist {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// whitelist, only log if one of the substrings is present
|
||||
if !l.blacklist && !found {
|
||||
return
|
||||
}
|
||||
if l.logger != nil {
|
||||
l.logger.Printf(ctx, format, v...)
|
||||
return
|
||||
}
|
||||
}
|
||||
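A brief usage sketch for the logging package above (hedged: redis.SetLogger is the client's existing hook for installing a custom logger, and the filter substrings are examples):

// Silence all library logging, e.g. in benchmarks or latency-sensitive paths.
logging.Disable()

// Or keep logging but raise verbosity and drop lines that may contain secrets.
logging.SetLogLevel(logging.LogLevelDebug)
redis.SetLogger(logging.NewBlacklistLogger([]string{"AUTH", "password"}))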
59
logging/logging_test.go
Normal file
59
logging/logging_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package logging
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLogLevel_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
level LogLevelT
|
||||
expected string
|
||||
}{
|
||||
{LogLevelError, "ERROR"},
|
||||
{LogLevelWarn, "WARN"},
|
||||
{LogLevelInfo, "INFO"},
|
||||
{LogLevelDebug, "DEBUG"},
|
||||
{LogLevelT(99), "UNKNOWN"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if got := test.level.String(); got != test.expected {
|
||||
t.Errorf("LogLevel(%d).String() = %q, want %q", test.level, got, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLevel_IsValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
level LogLevelT
|
||||
expected bool
|
||||
}{
|
||||
{LogLevelError, true},
|
||||
{LogLevelWarn, true},
|
||||
{LogLevelInfo, true},
|
||||
{LogLevelDebug, true},
|
||||
{LogLevelT(-1), false},
|
||||
{LogLevelT(4), false},
|
||||
{LogLevelT(99), false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if got := test.level.IsValid(); got != test.expected {
|
||||
t.Errorf("LogLevel(%d).IsValid() = %v, want %v", test.level, got, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLevelConstants(t *testing.T) {
|
||||
// Test that constants have expected values
|
||||
if LogLevelError != 0 {
|
||||
t.Errorf("LogLevelError = %d, want 0", LogLevelError)
|
||||
}
|
||||
if LogLevelWarn != 1 {
|
||||
t.Errorf("LogLevelWarn = %d, want 1", LogLevelWarn)
|
||||
}
|
||||
if LogLevelInfo != 2 {
|
||||
t.Errorf("LogLevelInfo = %d, want 2", LogLevelInfo)
|
||||
}
|
||||
if LogLevelDebug != 3 {
|
||||
t.Errorf("LogLevelDebug = %d, want 3", LogLevelDebug)
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
. "github.com/bsm/ginkgo/v2"
|
||||
. "github.com/bsm/gomega"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -102,6 +103,7 @@ var _ = BeforeSuite(func() {
|
||||
fmt.Printf("RCEDocker: %v\n", RCEDocker)
|
||||
fmt.Printf("REDIS_VERSION: %.1f\n", RedisVersion)
|
||||
fmt.Printf("CLIENT_LIBS_TEST_IMAGE: %v\n", os.Getenv("CLIENT_LIBS_TEST_IMAGE"))
|
||||
logging.Disable()
|
||||
|
||||
if RedisVersion < 7.0 || RedisVersion > 9 {
|
||||
panic("incorrect or not supported redis version")
|
||||
|
||||
100
maintnotifications/README.md
Normal file
100
maintnotifications/README.md
Normal file
@@ -0,0 +1,100 @@
# Maintenance Notifications

Seamless Redis connection handoffs during cluster maintenance operations without dropping connections.

## ⚠️ **Important Note**

**Maintenance notifications are currently supported only in standalone Redis clients.** Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support this functionality.

## Quick Start

```go
client := redis.NewClient(&redis.Options{
	Addr:     "localhost:6379",
	Protocol: 3, // RESP3 required
	MaintNotificationsConfig: &maintnotifications.Config{
		Mode: maintnotifications.ModeEnabled,
	},
})
```

## Modes

- **`ModeDisabled`** - Maintenance notifications disabled
- **`ModeEnabled`** - Forcefully enabled (fails if server doesn't support)
- **`ModeAuto`** - Auto-detect server support (default)

## Configuration

```go
&maintnotifications.Config{
	Mode:                       maintnotifications.ModeAuto,
	EndpointType:               maintnotifications.EndpointTypeAuto,
	RelaxedTimeout:             10 * time.Second,
	HandoffTimeout:             15 * time.Second,
	MaxHandoffRetries:          3,
	MaxWorkers:                 0, // Auto-calculated
	HandoffQueueSize:           0, // Auto-calculated
	PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout
}
```

### Endpoint Types

- **`EndpointTypeAuto`** - Auto-detect based on connection (default)
- **`EndpointTypeInternalIP`** - Internal IP address
- **`EndpointTypeInternalFQDN`** - Internal FQDN
- **`EndpointTypeExternalIP`** - External IP address
- **`EndpointTypeExternalFQDN`** - External FQDN
- **`EndpointTypeNone`** - No endpoint (reconnect with current config)

### Auto-Scaling

**Workers**: `min(PoolSize/2, max(10, PoolSize/3))` when auto-calculated
**Queue**: `max(20×Workers, PoolSize)` capped by `MaxActiveConns+1` or `5×PoolSize`

**Examples:**
- Pool 100: 33 workers, 660 queue (capped at 500)
- Pool 100 + MaxActiveConns 150: 33 workers, 151 queue
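To make the sizing rules above concrete, here is a small, self-contained sketch (not library code; the helper names are illustrative) that reproduces both examples. It mirrors the `util.Max`/`util.Min` helpers added elsewhere in this diff rather than relying on newer Go builtins.

```go
package main

import "fmt"

func maxInt(a, b int) int {
	if a > b {
		return a
	}
	return b
}

func minInt(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// autoScale reproduces the documented sizing rules; it is an illustration,
// not the library's implementation.
func autoScale(poolSize, maxActiveConns int) (workers, queue int) {
	workers = minInt(poolSize/2, maxInt(10, poolSize/3))
	queue = maxInt(20*workers, poolSize)
	limit := 5 * poolSize
	if maxActiveConns > 0 {
		limit = maxActiveConns + 1
	}
	return workers, minInt(queue, limit)
}

func main() {
	fmt.Println(autoScale(100, 0))   // 33 500
	fmt.Println(autoScale(100, 150)) // 33 151
}
```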
## How It Works

1. Redis sends push notifications about cluster maintenance operations
2. Client creates new connections to updated endpoints
3. Active operations transfer to new connections
4. Old connections close gracefully

## Supported Notifications

- `MOVING` - Slot moving to new node
- `MIGRATING` - Slot in migration state
- `MIGRATED` - Migration completed
- `FAILING_OVER` - Node failing over
- `FAILED_OVER` - Failover completed

## Hooks (Optional)

Monitor and customize maintenance notification operations:

```go
type NotificationHook interface {
	PreHook(ctx, notificationCtx, notificationType, notification) ([]interface{}, bool)
	PostHook(ctx, notificationCtx, notificationType, notification, result)
}

// Add custom hook
manager.AddNotificationHook(&MyHook{})
```

### Metrics Hook Example

```go
// Create metrics hook
metricsHook := maintnotifications.NewMetricsHook()
manager.AddNotificationHook(metricsHook)

// Access collected metrics
metrics := metricsHook.GetMetrics()
fmt.Printf("Notification counts: %v\n", metrics["notification_counts"])
fmt.Printf("Processing times: %v\n", metrics["processing_times"])
fmt.Printf("Error counts: %v\n", metrics["error_counts"])
```
353
maintnotifications/circuit_breaker.go
Normal file
353
maintnotifications/circuit_breaker.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
)
|
||||
|
||||
// CircuitBreakerState represents the state of a circuit breaker
|
||||
type CircuitBreakerState int32
|
||||
|
||||
const (
|
||||
// CircuitBreakerClosed - normal operation, requests allowed
|
||||
CircuitBreakerClosed CircuitBreakerState = iota
|
||||
// CircuitBreakerOpen - failing fast, requests rejected
|
||||
CircuitBreakerOpen
|
||||
// CircuitBreakerHalfOpen - testing if service recovered
|
||||
CircuitBreakerHalfOpen
|
||||
)
|
||||
|
||||
func (s CircuitBreakerState) String() string {
|
||||
switch s {
|
||||
case CircuitBreakerClosed:
|
||||
return "closed"
|
||||
case CircuitBreakerOpen:
|
||||
return "open"
|
||||
case CircuitBreakerHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling
|
||||
type CircuitBreaker struct {
|
||||
// Configuration
|
||||
failureThreshold int // Number of failures before opening
|
||||
resetTimeout time.Duration // How long to stay open before testing
|
||||
maxRequests int // Max requests allowed in half-open state
|
||||
|
||||
// State tracking (atomic for lock-free access)
|
||||
state atomic.Int32 // CircuitBreakerState
|
||||
failures atomic.Int64 // Current failure count
|
||||
successes atomic.Int64 // Success count in half-open state
|
||||
requests atomic.Int64 // Request count in half-open state
|
||||
lastFailureTime atomic.Int64 // Unix timestamp of last failure
|
||||
lastSuccessTime atomic.Int64 // Unix timestamp of last success
|
||||
|
||||
// Endpoint identification
|
||||
endpoint string
|
||||
config *Config
|
||||
}
|
||||
|
||||
// newCircuitBreaker creates a new circuit breaker for an endpoint
|
||||
func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker {
|
||||
// Use configuration values with sensible defaults
|
||||
failureThreshold := 5
|
||||
resetTimeout := 60 * time.Second
|
||||
maxRequests := 3
|
||||
|
||||
if config != nil {
|
||||
failureThreshold = config.CircuitBreakerFailureThreshold
|
||||
resetTimeout = config.CircuitBreakerResetTimeout
|
||||
maxRequests = config.CircuitBreakerMaxRequests
|
||||
}
|
||||
|
||||
return &CircuitBreaker{
|
||||
failureThreshold: failureThreshold,
|
||||
resetTimeout: resetTimeout,
|
||||
maxRequests: maxRequests,
|
||||
endpoint: endpoint,
|
||||
config: config,
|
||||
state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0)
|
||||
}
|
||||
}
|
||||
|
||||
// IsOpen returns true if the circuit breaker is open (rejecting requests)
|
||||
func (cb *CircuitBreaker) IsOpen() bool {
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
return state == CircuitBreakerOpen
|
||||
}
|
||||
|
||||
// shouldAttemptReset checks if enough time has passed to attempt reset
|
||||
func (cb *CircuitBreaker) shouldAttemptReset() bool {
|
||||
lastFailure := time.Unix(cb.lastFailureTime.Load(), 0)
|
||||
return time.Since(lastFailure) >= cb.resetTimeout
|
||||
}
|
||||
|
||||
// Execute runs the given function with circuit breaker protection
|
||||
func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
// Single atomic state load for consistency
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
|
||||
switch state {
|
||||
case CircuitBreakerOpen:
|
||||
if cb.shouldAttemptReset() {
|
||||
// Attempt transition to half-open
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
|
||||
cb.requests.Store(0)
|
||||
cb.successes.Store(0)
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint))
|
||||
}
|
||||
// Fall through to half-open logic
|
||||
} else {
|
||||
return ErrCircuitBreakerOpen
|
||||
}
|
||||
} else {
|
||||
return ErrCircuitBreakerOpen
|
||||
}
|
||||
fallthrough
|
||||
case CircuitBreakerHalfOpen:
|
||||
requests := cb.requests.Add(1)
|
||||
if requests > int64(cb.maxRequests) {
|
||||
cb.requests.Add(-1) // Revert the increment
|
||||
return ErrCircuitBreakerOpen
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the function with consistent state
|
||||
err := fn()
|
||||
|
||||
if err != nil {
|
||||
cb.recordFailure()
|
||||
return err
|
||||
}
|
||||
|
||||
cb.recordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
// recordFailure records a failure and potentially opens the circuit
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
cb.lastFailureTime.Store(time.Now().Unix())
|
||||
failures := cb.failures.Add(1)
|
||||
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
if failures >= int64(cb.failureThreshold) {
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures))
|
||||
}
|
||||
}
|
||||
}
|
||||
case CircuitBreakerHalfOpen:
|
||||
// Any failure in half-open state immediately opens the circuit
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordSuccess records a success and potentially closes the circuit
|
||||
func (cb *CircuitBreaker) recordSuccess() {
|
||||
cb.lastSuccessTime.Store(time.Now().Unix())
|
||||
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
// Reset failure count on success in closed state
|
||||
cb.failures.Store(0)
|
||||
case CircuitBreakerHalfOpen:
|
||||
successes := cb.successes.Add(1)
|
||||
|
||||
// If we've had enough successful requests, close the circuit
|
||||
if successes >= int64(cb.maxRequests) {
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
|
||||
cb.failures.Store(0)
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current state of the circuit breaker
|
||||
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
|
||||
return CircuitBreakerState(cb.state.Load())
|
||||
}
|
||||
|
||||
// GetStats returns current statistics for monitoring
|
||||
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
|
||||
return CircuitBreakerStats{
|
||||
Endpoint: cb.endpoint,
|
||||
State: cb.GetState(),
|
||||
Failures: cb.failures.Load(),
|
||||
Successes: cb.successes.Load(),
|
||||
Requests: cb.requests.Load(),
|
||||
LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0),
|
||||
LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0),
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerStats provides statistics about a circuit breaker
|
||||
type CircuitBreakerStats struct {
|
||||
Endpoint string
|
||||
State CircuitBreakerState
|
||||
Failures int64
|
||||
Successes int64
|
||||
Requests int64
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
}
|
||||
|
||||
// CircuitBreakerEntry wraps a circuit breaker with access tracking
|
||||
type CircuitBreakerEntry struct {
|
||||
breaker *CircuitBreaker
|
||||
lastAccess atomic.Int64 // Unix timestamp
|
||||
created time.Time
|
||||
}
|
||||
|
||||
// CircuitBreakerManager manages circuit breakers for multiple endpoints
|
||||
type CircuitBreakerManager struct {
|
||||
breakers sync.Map // map[string]*CircuitBreakerEntry
|
||||
config *Config
|
||||
cleanupStop chan struct{}
|
||||
cleanupMu sync.Mutex
|
||||
lastCleanup atomic.Int64 // Unix timestamp
|
||||
}
|
||||
|
||||
// newCircuitBreakerManager creates a new circuit breaker manager
|
||||
func newCircuitBreakerManager(config *Config) *CircuitBreakerManager {
|
||||
cbm := &CircuitBreakerManager{
|
||||
config: config,
|
||||
cleanupStop: make(chan struct{}),
|
||||
}
|
||||
cbm.lastCleanup.Store(time.Now().Unix())
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go cbm.cleanupLoop()
|
||||
|
||||
return cbm
|
||||
}
|
||||
|
||||
// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary
|
||||
func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker {
|
||||
now := time.Now().Unix()
|
||||
|
||||
if entry, ok := cbm.breakers.Load(endpoint); ok {
|
||||
cbEntry := entry.(*CircuitBreakerEntry)
|
||||
cbEntry.lastAccess.Store(now)
|
||||
return cbEntry.breaker
|
||||
}
|
||||
|
||||
// Create new circuit breaker with metadata
|
||||
newBreaker := newCircuitBreaker(endpoint, cbm.config)
|
||||
newEntry := &CircuitBreakerEntry{
|
||||
breaker: newBreaker,
|
||||
created: time.Now(),
|
||||
}
|
||||
newEntry.lastAccess.Store(now)
|
||||
|
||||
actual, _ := cbm.breakers.LoadOrStore(endpoint, newEntry)
|
||||
return actual.(*CircuitBreakerEntry).breaker
|
||||
}
|
||||
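A hedged sketch of how the manager and breaker shown above are typically combined around per-endpoint work such as a handoff attempt (manager and doHandoff are placeholders):

// Sketch: guard an endpoint-specific operation with its circuit breaker.
cb := manager.GetCircuitBreaker("10.0.0.1:6379")
err := cb.Execute(func() error {
	return doHandoff() // placeholder for the real handoff / dial attempt
})
if errors.Is(err, ErrCircuitBreakerOpen) {
	// Endpoint is failing fast: back off or try another endpoint.
}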
|
||||
// GetAllStats returns statistics for all circuit breakers
|
||||
func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats {
|
||||
var stats []CircuitBreakerStats
|
||||
cbm.breakers.Range(func(key, value interface{}) bool {
|
||||
entry := value.(*CircuitBreakerEntry)
|
||||
stats = append(stats, entry.breaker.GetStats())
|
||||
return true
|
||||
})
|
||||
return stats
|
||||
}
|
||||
|
||||
// cleanupLoop runs background cleanup of unused circuit breakers
|
||||
func (cbm *CircuitBreakerManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cbm.cleanup()
|
||||
case <-cbm.cleanupStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes circuit breakers that haven't been accessed recently
|
||||
func (cbm *CircuitBreakerManager) cleanup() {
|
||||
// Prevent concurrent cleanups
|
||||
if !cbm.cleanupMu.TryLock() {
|
||||
return
|
||||
}
|
||||
defer cbm.cleanupMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-30 * time.Minute).Unix() // 30 minute TTL
|
||||
|
||||
var toDelete []string
|
||||
count := 0
|
||||
|
||||
cbm.breakers.Range(func(key, value interface{}) bool {
|
||||
endpoint := key.(string)
|
||||
entry := value.(*CircuitBreakerEntry)
|
||||
|
||||
count++
|
||||
|
||||
// Remove if not accessed recently
|
||||
if entry.lastAccess.Load() < cutoff {
|
||||
toDelete = append(toDelete, endpoint)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// Delete expired entries
|
||||
for _, endpoint := range toDelete {
|
||||
cbm.breakers.Delete(endpoint)
|
||||
}
|
||||
|
||||
// Log cleanup results
|
||||
if len(toDelete) > 0 && internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count))
|
||||
}
|
||||
|
||||
cbm.lastCleanup.Store(now.Unix())
|
||||
}
|
||||
|
||||
// Shutdown stops the cleanup goroutine
|
||||
func (cbm *CircuitBreakerManager) Shutdown() {
|
||||
close(cbm.cleanupStop)
|
||||
}
|
||||
|
||||
// Reset resets all circuit breakers (useful for testing)
|
||||
func (cbm *CircuitBreakerManager) Reset() {
|
||||
cbm.breakers.Range(func(key, value interface{}) bool {
|
||||
entry := value.(*CircuitBreakerEntry)
|
||||
breaker := entry.breaker
|
||||
breaker.state.Store(int32(CircuitBreakerClosed))
|
||||
breaker.failures.Store(0)
|
||||
breaker.successes.Store(0)
|
||||
breaker.requests.Store(0)
|
||||
breaker.lastFailureTime.Store(0)
|
||||
breaker.lastSuccessTime.Store(0)
|
||||
return true
|
||||
})
|
||||
}
|
||||
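
For orientation, here is a minimal usage sketch (not part of the diff) of how package-internal code might route an endpoint operation through the manager above; the endpoint string and the guarded operation are illustrative only.

```go
// Hypothetical, package-internal illustration of the API above.
cbm := newCircuitBreakerManager(DefaultConfig().ApplyDefaults())
defer cbm.Shutdown()

cb := cbm.GetCircuitBreaker("redis-1.internal:6379") // example endpoint
err := cb.Execute(func() error {
	// Attempt the guarded operation, e.g. a handoff dial.
	return nil
})
if err == ErrCircuitBreakerOpen {
	// The endpoint is failing fast; skip it until the reset timeout elapses.
}
_ = cbm.GetAllStats() // per-endpoint counters for monitoring
```
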
348
maintnotifications/circuit_breaker_test.go
Normal file
348
maintnotifications/circuit_breaker_test.go
Normal file
@@ -0,0 +1,348 @@
package maintnotifications

import (
	"errors"
	"testing"
	"time"
)

func TestCircuitBreaker(t *testing.T) {
	config := &Config{
		CircuitBreakerFailureThreshold: 5,
		CircuitBreakerResetTimeout:     60 * time.Second,
		CircuitBreakerMaxRequests:      3,
	}

	t.Run("InitialState", func(t *testing.T) {
		cb := newCircuitBreaker("test-endpoint:6379", config)

		if cb.IsOpen() {
			t.Error("Circuit breaker should start in closed state")
		}

		if cb.GetState() != CircuitBreakerClosed {
			t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
		}
	})

	t.Run("SuccessfulExecution", func(t *testing.T) {
		cb := newCircuitBreaker("test-endpoint:6379", config)

		err := cb.Execute(func() error {
			return nil // Success
		})

		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}

		if cb.GetState() != CircuitBreakerClosed {
			t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
		}
	})

	t.Run("FailureThreshold", func(t *testing.T) {
		cb := newCircuitBreaker("test-endpoint:6379", config)
		testError := errors.New("test error")

		// Fail 4 times (below threshold of 5)
		for i := 0; i < 4; i++ {
			err := cb.Execute(func() error {
				return testError
			})
			if err != testError {
				t.Errorf("Expected test error, got %v", err)
			}
			if cb.GetState() != CircuitBreakerClosed {
				t.Errorf("Circuit should still be closed after %d failures", i+1)
			}
		}

		// 5th failure should open the circuit
		err := cb.Execute(func() error {
			return testError
		})
		if err != testError {
			t.Errorf("Expected test error, got %v", err)
		}

		if cb.GetState() != CircuitBreakerOpen {
			t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
		}
	})

	t.Run("OpenCircuitFailsFast", func(t *testing.T) {
		cb := newCircuitBreaker("test-endpoint:6379", config)
		testError := errors.New("test error")

		// Force circuit to open
		for i := 0; i < 5; i++ {
			cb.Execute(func() error { return testError })
		}

		// Now it should fail fast
		err := cb.Execute(func() error {
			t.Error("Function should not be called when circuit is open")
			return nil
		})

		if err != ErrCircuitBreakerOpen {
			t.Errorf("Expected ErrCircuitBreakerOpen, got %v", err)
		}
	})

	t.Run("HalfOpenTransition", func(t *testing.T) {
		testConfig := &Config{
			CircuitBreakerFailureThreshold: 5,
			CircuitBreakerResetTimeout:     100 * time.Millisecond, // Short timeout for testing
			CircuitBreakerMaxRequests:      3,
		}
		cb := newCircuitBreaker("test-endpoint:6379", testConfig)
		testError := errors.New("test error")

		// Force circuit to open
		for i := 0; i < 5; i++ {
			cb.Execute(func() error { return testError })
		}

		if cb.GetState() != CircuitBreakerOpen {
			t.Error("Circuit should be open")
		}

		// Wait for reset timeout
		time.Sleep(150 * time.Millisecond)

		// Next call should transition to half-open
		executed := false
		err := cb.Execute(func() error {
			executed = true
			return nil // Success
		})

		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}

		if !executed {
			t.Error("Function should have been executed in half-open state")
		}
	})

	t.Run("HalfOpenToClosedTransition", func(t *testing.T) {
		testConfig := &Config{
			CircuitBreakerFailureThreshold: 5,
			CircuitBreakerResetTimeout:     50 * time.Millisecond,
			CircuitBreakerMaxRequests:      3,
		}
		cb := newCircuitBreaker("test-endpoint:6379", testConfig)
		testError := errors.New("test error")

		// Force circuit to open
		for i := 0; i < 5; i++ {
			cb.Execute(func() error { return testError })
		}

		// Wait for reset timeout
		time.Sleep(100 * time.Millisecond)

		// Execute successful requests in half-open state
		for i := 0; i < 3; i++ {
			err := cb.Execute(func() error {
				return nil // Success
			})
			if err != nil {
				t.Errorf("Expected no error on attempt %d, got %v", i+1, err)
			}
		}

		// Circuit should now be closed
		if cb.GetState() != CircuitBreakerClosed {
			t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
		}
	})

	t.Run("HalfOpenToOpenOnFailure", func(t *testing.T) {
		testConfig := &Config{
			CircuitBreakerFailureThreshold: 5,
			CircuitBreakerResetTimeout:     50 * time.Millisecond,
			CircuitBreakerMaxRequests:      3,
		}
		cb := newCircuitBreaker("test-endpoint:6379", testConfig)
		testError := errors.New("test error")

		// Force circuit to open
		for i := 0; i < 5; i++ {
			cb.Execute(func() error { return testError })
		}

		// Wait for reset timeout
		time.Sleep(100 * time.Millisecond)

		// First request in half-open state fails
		err := cb.Execute(func() error {
			return testError
		})

		if err != testError {
			t.Errorf("Expected test error, got %v", err)
		}

		// Circuit should be open again
		if cb.GetState() != CircuitBreakerOpen {
			t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
		}
	})

	t.Run("Stats", func(t *testing.T) {
		cb := newCircuitBreaker("test-endpoint:6379", config)
		testError := errors.New("test error")

		// Execute some operations
		cb.Execute(func() error { return testError }) // Failure
		cb.Execute(func() error { return testError }) // Failure

		stats := cb.GetStats()

		if stats.Endpoint != "test-endpoint:6379" {
			t.Errorf("Expected endpoint 'test-endpoint:6379', got %s", stats.Endpoint)
		}

		if stats.Failures != 2 {
			t.Errorf("Expected 2 failures, got %d", stats.Failures)
		}

		if stats.State != CircuitBreakerClosed {
			t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, stats.State)
		}

		// Test that success resets failure count
		cb.Execute(func() error { return nil }) // Success
		stats = cb.GetStats()

		if stats.Failures != 0 {
			t.Errorf("Expected 0 failures after success, got %d", stats.Failures)
		}
	})
}

func TestCircuitBreakerManager(t *testing.T) {
	config := &Config{
		CircuitBreakerFailureThreshold: 5,
		CircuitBreakerResetTimeout:     60 * time.Second,
		CircuitBreakerMaxRequests:      3,
	}

	t.Run("GetCircuitBreaker", func(t *testing.T) {
		manager := newCircuitBreakerManager(config)

		cb1 := manager.GetCircuitBreaker("endpoint1:6379")
		cb2 := manager.GetCircuitBreaker("endpoint2:6379")
		cb3 := manager.GetCircuitBreaker("endpoint1:6379") // Same as cb1

		if cb1 == cb2 {
			t.Error("Different endpoints should have different circuit breakers")
		}

		if cb1 != cb3 {
			t.Error("Same endpoint should return the same circuit breaker")
		}
	})

	t.Run("GetAllStats", func(t *testing.T) {
		manager := newCircuitBreakerManager(config)

		// Create circuit breakers for different endpoints
		cb1 := manager.GetCircuitBreaker("endpoint1:6379")
		cb2 := manager.GetCircuitBreaker("endpoint2:6379")

		// Execute some operations
		cb1.Execute(func() error { return nil })
		cb2.Execute(func() error { return errors.New("test error") })

		stats := manager.GetAllStats()

		if len(stats) != 2 {
			t.Errorf("Expected 2 circuit breaker stats, got %d", len(stats))
		}

		// Check that we have stats for both endpoints
		endpoints := make(map[string]bool)
		for _, stat := range stats {
			endpoints[stat.Endpoint] = true
		}

		if !endpoints["endpoint1:6379"] || !endpoints["endpoint2:6379"] {
			t.Error("Missing stats for expected endpoints")
		}
	})

	t.Run("Reset", func(t *testing.T) {
		manager := newCircuitBreakerManager(config)
		testError := errors.New("test error")

		cb := manager.GetCircuitBreaker("test-endpoint:6379")

		// Force circuit to open
		for i := 0; i < 5; i++ {
			cb.Execute(func() error { return testError })
		}

		if cb.GetState() != CircuitBreakerOpen {
			t.Error("Circuit should be open")
		}

		// Reset all circuit breakers
		manager.Reset()

		if cb.GetState() != CircuitBreakerClosed {
			t.Error("Circuit should be closed after reset")
		}

		if cb.failures.Load() != 0 {
			t.Error("Failure count should be reset to 0")
		}
	})

	t.Run("ConfigurableParameters", func(t *testing.T) {
		config := &Config{
			CircuitBreakerFailureThreshold: 10,
			CircuitBreakerResetTimeout:     30 * time.Second,
			CircuitBreakerMaxRequests:      5,
		}

		cb := newCircuitBreaker("test-endpoint:6379", config)

		// Test that configuration values are used
		if cb.failureThreshold != 10 {
			t.Errorf("Expected failureThreshold=10, got %d", cb.failureThreshold)
		}
		if cb.resetTimeout != 30*time.Second {
			t.Errorf("Expected resetTimeout=30s, got %v", cb.resetTimeout)
		}
		if cb.maxRequests != 5 {
			t.Errorf("Expected maxRequests=5, got %d", cb.maxRequests)
		}

		// Test that circuit opens after configured threshold
		testError := errors.New("test error")
		for i := 0; i < 9; i++ {
			err := cb.Execute(func() error { return testError })
			if err != testError {
				t.Errorf("Expected test error, got %v", err)
			}
			if cb.GetState() != CircuitBreakerClosed {
				t.Errorf("Circuit should still be closed after %d failures", i+1)
			}
		}

		// 10th failure should open the circuit
		err := cb.Execute(func() error { return testError })
		if err != testError {
			t.Errorf("Expected test error, got %v", err)
		}

		if cb.GetState() != CircuitBreakerOpen {
			t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
		}
	})
}
458
maintnotifications/config.go
Normal file
458
maintnotifications/config.go
Normal file
@@ -0,0 +1,458 @@
package maintnotifications

import (
	"context"
	"net"
	"runtime"
	"strings"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
	"github.com/redis/go-redis/v9/internal/util"
)

// Mode represents the maintenance notifications mode
type Mode string

// Constants for maintenance push notifications modes
const (
	ModeDisabled Mode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
	ModeEnabled  Mode = "enabled"  // Client forcefully sends command, interrupts connection on error
	ModeAuto     Mode = "auto"     // Client tries to send command, disables feature on error
)

// IsValid returns true if the maintenance notifications mode is valid
func (m Mode) IsValid() bool {
	switch m {
	case ModeDisabled, ModeEnabled, ModeAuto:
		return true
	default:
		return false
	}
}

// String returns the string representation of the mode
func (m Mode) String() string {
	return string(m)
}

// EndpointType represents the type of endpoint to request in MOVING notifications
type EndpointType string

// Constants for endpoint types
const (
	EndpointTypeAuto         EndpointType = "auto"          // Auto-detect based on connection
	EndpointTypeInternalIP   EndpointType = "internal-ip"   // Internal IP address
	EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
	EndpointTypeExternalIP   EndpointType = "external-ip"   // External IP address
	EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
	EndpointTypeNone         EndpointType = "none"          // No endpoint (reconnect with current config)
)

// IsValid returns true if the endpoint type is valid
func (e EndpointType) IsValid() bool {
	switch e {
	case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
		EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
		return true
	default:
		return false
	}
}

// String returns the string representation of the endpoint type
func (e EndpointType) String() string {
	return string(e)
}

// Config provides configuration options for maintenance notifications
type Config struct {
	// Mode controls how client maintenance notifications are handled.
	// Valid values: ModeDisabled, ModeEnabled, ModeAuto
	// Default: ModeAuto
	Mode Mode

	// EndpointType specifies the type of endpoint to request in MOVING notifications.
	// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
	// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
	// Default: EndpointTypeAuto
	EndpointType EndpointType

	// RelaxedTimeout is the concrete timeout value to use during
	// MIGRATING/FAILING_OVER states to accommodate increased latency.
	// This applies to both read and write timeouts.
	// Default: 10 seconds
	RelaxedTimeout time.Duration

	// HandoffTimeout is the maximum time to wait for connection handoff to complete.
	// If handoff takes longer than this, the old connection will be forcibly closed.
	// Default: 15 seconds (matches server-side eviction timeout)
	HandoffTimeout time.Duration

	// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
	// Workers are created on-demand and automatically cleaned up when idle.
	// If zero, defaults to min(PoolSize/2, max(10, PoolSize/3)) to handle bursts effectively.
	// If explicitly set, enforces minimum of PoolSize/2
	//
	// Default: min(PoolSize/2, max(10, PoolSize/3)), Minimum when set: PoolSize/2
	MaxWorkers int

	// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
	// If the queue is full, new handoff requests will be rejected.
	// Scales with both worker count and pool size for better burst handling.
	//
	// Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize
	// When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize
	HandoffQueueSize int

	// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
	// after a handoff completes. This provides additional resilience during cluster transitions.
	// Default: 2 * RelaxedTimeout
	PostHandoffRelaxedDuration time.Duration

	// Circuit breaker configuration for endpoint failure handling

	// CircuitBreakerFailureThreshold is the number of failures before opening the circuit.
	// Default: 5
	CircuitBreakerFailureThreshold int

	// CircuitBreakerResetTimeout is how long to wait before testing if the endpoint recovered.
	// Default: 60 seconds
	CircuitBreakerResetTimeout time.Duration

	// CircuitBreakerMaxRequests is the maximum number of requests allowed in half-open state.
	// Default: 3
	CircuitBreakerMaxRequests int

	// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
	// After this many retries, the connection will be removed from the pool.
	// Default: 3
	MaxHandoffRetries int
}

func (c *Config) IsEnabled() bool {
	return c != nil && c.Mode != ModeDisabled
}

// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
	return &Config{
		Mode:                       ModeAuto,         // Enable by default for Redis Cloud
		EndpointType:               EndpointTypeAuto, // Auto-detect based on connection
		RelaxedTimeout:             10 * time.Second,
		HandoffTimeout:             15 * time.Second,
		MaxWorkers:                 0, // Auto-calculated based on pool size
		HandoffQueueSize:           0, // Auto-calculated based on max workers
		PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout

		// Circuit breaker configuration
		CircuitBreakerFailureThreshold: 5,
		CircuitBreakerResetTimeout:     60 * time.Second,
		CircuitBreakerMaxRequests:      3,

		// Connection Handoff Configuration
		MaxHandoffRetries: 3,
	}
}

// Validate checks if the configuration is valid.
func (c *Config) Validate() error {
	if c.RelaxedTimeout <= 0 {
		return ErrInvalidRelaxedTimeout
	}
	if c.HandoffTimeout <= 0 {
		return ErrInvalidHandoffTimeout
	}
	// Validate worker configuration
	// Allow 0 for auto-calculation, but negative values are invalid
	if c.MaxWorkers < 0 {
		return ErrInvalidHandoffWorkers
	}
	// HandoffQueueSize validation - allow 0 for auto-calculation
	if c.HandoffQueueSize < 0 {
		return ErrInvalidHandoffQueueSize
	}
	if c.PostHandoffRelaxedDuration < 0 {
		return ErrInvalidPostHandoffRelaxedDuration
	}

	// Circuit breaker validation
	if c.CircuitBreakerFailureThreshold < 1 {
		return ErrInvalidCircuitBreakerFailureThreshold
	}
	if c.CircuitBreakerResetTimeout < 0 {
		return ErrInvalidCircuitBreakerResetTimeout
	}
	if c.CircuitBreakerMaxRequests < 1 {
		return ErrInvalidCircuitBreakerMaxRequests
	}

	// Validate Mode (maintenance notifications mode)
	if !c.Mode.IsValid() {
		return ErrInvalidMaintNotifications
	}

	// Validate EndpointType
	if !c.EndpointType.IsValid() {
		return ErrInvalidEndpointType
	}

	// Validate configuration fields
	if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
		return ErrInvalidHandoffRetries
	}

	return nil
}

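A hedged sketch (not part of the diff) of how a caller might assemble and check a custom Config using the fields and errors defined above; the field values and pool size are illustrative.

```go
cfg := &Config{
	Mode:         ModeEnabled,
	EndpointType: EndpointTypeExternalFQDN,
	// Remaining zero-value fields are filled in by ApplyDefaults*.
}
cfg = cfg.ApplyDefaultsWithPoolSize(100) // assumed pool size, for illustration
if err := cfg.Validate(); err != nil {
	// e.g. ErrInvalidMaintNotifications or ErrInvalidEndpointType
	panic(err)
}
```
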
// ApplyDefaults applies default values to any zero-value fields in the configuration.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaults() *Config {
	return c.ApplyDefaultsWithPoolSize(0)
}

// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
// using the provided pool size to calculate worker defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
	return c.ApplyDefaultsWithPoolConfig(poolSize, 0)
}

// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration,
// using the provided pool size and max active connections to calculate worker and queue defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config {
	if c == nil {
		return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
	}

	defaults := DefaultConfig()
	result := &Config{}

	// Apply defaults for enum fields (empty/zero means not set)
	result.Mode = defaults.Mode
	if c.Mode != "" {
		result.Mode = c.Mode
	}

	result.EndpointType = defaults.EndpointType
	if c.EndpointType != "" {
		result.EndpointType = c.EndpointType
	}

	// Apply defaults for duration fields (zero means not set)
	result.RelaxedTimeout = defaults.RelaxedTimeout
	if c.RelaxedTimeout > 0 {
		result.RelaxedTimeout = c.RelaxedTimeout
	}

	result.HandoffTimeout = defaults.HandoffTimeout
	if c.HandoffTimeout > 0 {
		result.HandoffTimeout = c.HandoffTimeout
	}

	// Copy worker configuration
	result.MaxWorkers = c.MaxWorkers

	// Apply worker defaults based on pool size
	result.applyWorkerDefaults(poolSize)

	// Apply queue size defaults with new scaling approach
	// Default: max(20x workers, PoolSize), capped by maxActiveConns or 5x pool size
	workerBasedSize := result.MaxWorkers * 20
	poolBasedSize := poolSize
	result.HandoffQueueSize = util.Max(workerBasedSize, poolBasedSize)
	if c.HandoffQueueSize > 0 {
		// When explicitly set: enforce minimum of 200
		result.HandoffQueueSize = util.Max(200, c.HandoffQueueSize)
	}

	// Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size
	var queueCap int
	if maxActiveConns > 0 {
		queueCap = maxActiveConns + 1
		// Ensure queue cap is at least 2 for very small maxActiveConns
		if queueCap < 2 {
			queueCap = 2
		}
	} else {
		queueCap = poolSize * 5
	}
	result.HandoffQueueSize = util.Min(result.HandoffQueueSize, queueCap)

	// Ensure minimum queue size of 2 (fallback for very small pools)
	if result.HandoffQueueSize < 2 {
		result.HandoffQueueSize = 2
	}

	result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
	if c.PostHandoffRelaxedDuration > 0 {
		result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
	}

	// Apply defaults for configuration fields
	result.MaxHandoffRetries = defaults.MaxHandoffRetries
	if c.MaxHandoffRetries > 0 {
		result.MaxHandoffRetries = c.MaxHandoffRetries
	}

	// Circuit breaker configuration
	result.CircuitBreakerFailureThreshold = defaults.CircuitBreakerFailureThreshold
	if c.CircuitBreakerFailureThreshold > 0 {
		result.CircuitBreakerFailureThreshold = c.CircuitBreakerFailureThreshold
	}

	result.CircuitBreakerResetTimeout = defaults.CircuitBreakerResetTimeout
	if c.CircuitBreakerResetTimeout > 0 {
		result.CircuitBreakerResetTimeout = c.CircuitBreakerResetTimeout
	}

	result.CircuitBreakerMaxRequests = defaults.CircuitBreakerMaxRequests
	if c.CircuitBreakerMaxRequests > 0 {
		result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests
	}

	if internal.LogLevel.DebugOrAbove() {
		internal.Logger.Printf(context.Background(), logs.DebugLoggingEnabled())
		internal.Logger.Printf(context.Background(), logs.ConfigDebug(result))
	}
	return result
}
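
To make the scaling above concrete, here is a small worked example (not part of the diff); the numbers follow directly from the formulas in this function and in applyWorkerDefaults below, and the pool size is arbitrary.

```go
// poolSize = 100, maxActiveConns = 0, nothing set explicitly:
//   MaxWorkers       = min(100/2, max(10, 100/3)) = min(50, 33) = 33
//   HandoffQueueSize = max(20*33, 100) = 660, capped at 5*100 = 500
cfg := (&Config{}).ApplyDefaultsWithPoolConfig(100, 0)
fmt.Println(cfg.MaxWorkers, cfg.HandoffQueueSize) // 33 500 (assumes "fmt" is imported)
```
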
// Clone creates a deep copy of the configuration.
func (c *Config) Clone() *Config {
	if c == nil {
		return DefaultConfig()
	}

	return &Config{
		Mode:                       c.Mode,
		EndpointType:               c.EndpointType,
		RelaxedTimeout:             c.RelaxedTimeout,
		HandoffTimeout:             c.HandoffTimeout,
		MaxWorkers:                 c.MaxWorkers,
		HandoffQueueSize:           c.HandoffQueueSize,
		PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,

		// Circuit breaker configuration
		CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold,
		CircuitBreakerResetTimeout:     c.CircuitBreakerResetTimeout,
		CircuitBreakerMaxRequests:      c.CircuitBreakerMaxRequests,

		// Configuration fields
		MaxHandoffRetries: c.MaxHandoffRetries,
	}
}

// applyWorkerDefaults calculates and applies worker defaults based on pool size
func (c *Config) applyWorkerDefaults(poolSize int) {
	// Calculate defaults based on pool size
	if poolSize <= 0 {
		poolSize = 10 * runtime.GOMAXPROCS(0)
	}

	// When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach
	originalMaxWorkers := c.MaxWorkers
	c.MaxWorkers = util.Min(poolSize/2, util.Max(10, poolSize/3))
	if originalMaxWorkers != 0 {
		// When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers
		c.MaxWorkers = util.Max(poolSize/2, originalMaxWorkers)
	}

	// Ensure minimum of 1 worker (fallback for very small pools)
	if c.MaxWorkers < 1 {
		c.MaxWorkers = 1
	}
}

// DetectEndpointType automatically detects the appropriate endpoint type
// based on the connection address and TLS configuration.
//
// For IP addresses:
//   - If TLS is enabled: requests FQDN for proper certificate validation
//   - If TLS is disabled: requests IP for better performance
//
// For hostnames:
//   - If TLS is enabled: always requests FQDN for proper certificate validation
//   - If TLS is disabled: requests IP for better performance
//
// Internal vs External detection:
//   - For IPs: uses private IP range detection
//   - For hostnames: uses heuristics based on common internal naming patterns
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
	// Extract host from "host:port" format
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		host = addr // Assume no port
	}

	// Check if the host is an IP address or hostname
	ip := net.ParseIP(host)
	isIPAddress := ip != nil
	var endpointType EndpointType

	if isIPAddress {
		// Address is an IP - determine if it's private or public
		isPrivate := ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()

		if tlsEnabled {
			// TLS with IP addresses - still prefer FQDN for certificate validation
			if isPrivate {
				endpointType = EndpointTypeInternalFQDN
			} else {
				endpointType = EndpointTypeExternalFQDN
			}
		} else {
			// No TLS - can use IP addresses directly
			if isPrivate {
				endpointType = EndpointTypeInternalIP
			} else {
				endpointType = EndpointTypeExternalIP
			}
		}
	} else {
		// Address is a hostname
		isInternalHostname := isInternalHostname(host)
		if isInternalHostname {
			endpointType = EndpointTypeInternalFQDN
		} else {
			endpointType = EndpointTypeExternalFQDN
		}
	}

	return endpointType
}

// isInternalHostname determines if a hostname appears to be internal/private.
// This is a heuristic based on common naming patterns.
func isInternalHostname(hostname string) bool {
	// Convert to lowercase for comparison
	hostname = strings.ToLower(hostname)

	// Common internal hostname patterns
	internalPatterns := []string{
		"localhost",
		".local",
		".internal",
		".corp",
		".lan",
		".intranet",
		".private",
	}

	// Check for exact match or suffix match
	for _, pattern := range internalPatterns {
		if hostname == pattern || strings.HasSuffix(hostname, pattern) {
			return true
		}
	}

	// Check for RFC 1918 style hostnames (e.g., redis-1, db-server, etc.)
	// If hostname doesn't contain dots, it's likely internal
	if !strings.Contains(hostname, ".") {
		return true
	}

	// Default to external for fully qualified domain names
	return false
}
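
A few illustrative calls (not part of the diff; the addresses are made up and the results follow from the logic above):

```go
_ = DetectEndpointType("10.0.0.5:6379", false)        // EndpointTypeInternalIP (private IP, no TLS)
_ = DetectEndpointType("203.0.113.7:6379", true)      // EndpointTypeExternalFQDN (public IP, TLS prefers FQDN)
_ = DetectEndpointType("redis-node-1:6379", false)    // EndpointTypeInternalFQDN (no dots => treated as internal hostname)
_ = DetectEndpointType("cache.example.com:6379", true) // EndpointTypeExternalFQDN
```
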
481
maintnotifications/config_test.go
Normal file
481
maintnotifications/config_test.go
Normal file
@@ -0,0 +1,481 @@
package maintnotifications

import (
	"context"
	"net"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/internal/util"
)

func TestConfig(t *testing.T) {
	t.Run("DefaultConfig", func(t *testing.T) {
		config := DefaultConfig()

		// MaxWorkers should be 0 in default config (auto-calculated)
		if config.MaxWorkers != 0 {
			t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers)
		}

		// HandoffQueueSize should be 0 in default config (auto-calculated)
		if config.HandoffQueueSize != 0 {
			t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize)
		}

		if config.RelaxedTimeout != 10*time.Second {
			t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout)
		}

		// Test configuration fields have proper defaults
		if config.MaxHandoffRetries != 3 {
			t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries)
		}

		// Circuit breaker defaults
		if config.CircuitBreakerFailureThreshold != 5 {
			t.Errorf("Expected CircuitBreakerFailureThreshold=5, got %d", config.CircuitBreakerFailureThreshold)
		}
		if config.CircuitBreakerResetTimeout != 60*time.Second {
			t.Errorf("Expected CircuitBreakerResetTimeout=60s, got %v", config.CircuitBreakerResetTimeout)
		}
		if config.CircuitBreakerMaxRequests != 3 {
			t.Errorf("Expected CircuitBreakerMaxRequests=3, got %d", config.CircuitBreakerMaxRequests)
		}

		if config.HandoffTimeout != 15*time.Second {
			t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout)
		}

		if config.PostHandoffRelaxedDuration != 0 {
			t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration)
		}

		// Test that defaults are applied correctly
		configWithDefaults := config.ApplyDefaultsWithPoolSize(100)
		if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second {
			t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration)
		}
	})

	t.Run("ConfigValidation", func(t *testing.T) {
		// Valid config with applied defaults
		config := DefaultConfig().ApplyDefaults()
		if err := config.Validate(); err != nil {
			t.Errorf("Default config with applied defaults should be valid: %v", err)
		}

		// Invalid worker configuration (negative MaxWorkers)
		config = &Config{
			RelaxedTimeout:             30 * time.Second,
			HandoffTimeout:             15 * time.Second,
			MaxWorkers:                 -1, // This should be invalid
			HandoffQueueSize:           100,
			PostHandoffRelaxedDuration: 10 * time.Second,
			MaxHandoffRetries:          3, // Add required field
		}
		if err := config.Validate(); err != ErrInvalidHandoffWorkers {
			t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err)
		}

		// Invalid HandoffQueueSize
		config = DefaultConfig().ApplyDefaults()
		config.HandoffQueueSize = -1
		if err := config.Validate(); err != ErrInvalidHandoffQueueSize {
			t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err)
		}

		// Invalid PostHandoffRelaxedDuration
		config = DefaultConfig().ApplyDefaults()
		config.PostHandoffRelaxedDuration = -1 * time.Second
		if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration {
			t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err)
		}
	})

	t.Run("ConfigClone", func(t *testing.T) {
		original := DefaultConfig()
		original.MaxWorkers = 20
		original.HandoffQueueSize = 200

		cloned := original.Clone()

		if cloned.MaxWorkers != 20 {
			t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers)
		}

		if cloned.HandoffQueueSize != 200 {
			t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize)
		}

		// Modify original to ensure clone is independent
		original.MaxWorkers = 2
		if cloned.MaxWorkers != 20 {
			t.Error("Clone should be independent of original")
		}
	})
}

func TestApplyDefaults(t *testing.T) {
	t.Run("NilConfig", func(t *testing.T) {
		var config *Config
		result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

		// With nil config, should get default config with auto-calculated workers
		if result.MaxWorkers <= 0 {
			t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers)
		}

		// HandoffQueueSize should be auto-calculated with hybrid scaling
		workerBasedSize := result.MaxWorkers * 20
		poolSize := 100 // Default pool size used in ApplyDefaults
		poolBasedSize := poolSize
		expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
		expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
		if result.HandoffQueueSize != expectedQueueSize {
			t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
				expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
		}
	})

	t.Run("PartialConfig", func(t *testing.T) {
		config := &Config{
			MaxWorkers: 60, // Set this field explicitly (> poolSize/2 = 50)
			// Leave other fields as zero values
		}

		result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

		// Should keep the explicitly set values when > poolSize/2
		if result.MaxWorkers != 60 {
			t.Errorf("Expected MaxWorkers to be 60 (explicitly set), got %d", result.MaxWorkers)
		}

		// Should apply default for unset fields (auto-calculated queue size with hybrid scaling)
		workerBasedSize := result.MaxWorkers * 20
		poolSize := 100 // Default pool size used in ApplyDefaults
		poolBasedSize := poolSize
		expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
		expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
		if result.HandoffQueueSize != expectedQueueSize {
			t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
				expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
		}

		// Test explicit queue size capping by 5x pool size
		configWithLargeQueue := &Config{
			MaxWorkers:       5,
			HandoffQueueSize: 1000, // Much larger than 5x pool size
		}

		resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size
		expectedCap := 20 * 5                                              // 5x pool size = 100
		if resultCapped.HandoffQueueSize != expectedCap {
			t.Errorf("Expected HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedCap, resultCapped.HandoffQueueSize)
		}

		// Test explicit queue size minimum enforcement
		configWithSmallQueue := &Config{
			MaxWorkers:       5,
			HandoffQueueSize: 10, // Below minimum of 200
		}

		resultMinimum := configWithSmallQueue.ApplyDefaultsWithPoolSize(100) // Large pool size
		if resultMinimum.HandoffQueueSize != 200 {
			t.Errorf("Expected HandoffQueueSize to be enforced minimum (200), got %d", resultMinimum.HandoffQueueSize)
		}

		// Test that large explicit values are capped by 5x pool size
		configWithVeryLargeQueue := &Config{
			MaxWorkers:       5,
			HandoffQueueSize: 1000, // Much larger than 5x pool size
		}

		resultVeryLarge := configWithVeryLargeQueue.ApplyDefaultsWithPoolSize(100) // Pool size 100
		expectedVeryLargeCap := 100 * 5                                            // 5x pool size = 500
		if resultVeryLarge.HandoffQueueSize != expectedVeryLargeCap {
			t.Errorf("Expected very large HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedVeryLargeCap, resultVeryLarge.HandoffQueueSize)
		}

		if result.RelaxedTimeout != 10*time.Second {
			t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
		}

		if result.HandoffTimeout != 15*time.Second {
			t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout)
		}
	})

	t.Run("ZeroValues", func(t *testing.T) {
		config := &Config{
			MaxWorkers:       0, // Zero value should get auto-calculated defaults
			HandoffQueueSize: 0, // Zero value should get default
			RelaxedTimeout:   0, // Zero value should get default
		}

		result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

		// Zero values should get auto-calculated defaults
		if result.MaxWorkers <= 0 {
			t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers)
		}

		// HandoffQueueSize should be auto-calculated with hybrid scaling
		workerBasedSize := result.MaxWorkers * 20
		poolSize := 100 // Default pool size used in ApplyDefaults
		poolBasedSize := poolSize
		expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
		expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
		if result.HandoffQueueSize != expectedQueueSize {
			t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
				expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
		}

		if result.RelaxedTimeout != 10*time.Second {
			t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
		}
	})
}

func TestProcessorWithConfig(t *testing.T) {
	t.Run("ProcessorUsesConfigValues", func(t *testing.T) {
		config := &Config{
			MaxWorkers:       5,
			HandoffQueueSize: 50,
			RelaxedTimeout:   10 * time.Second,
			HandoffTimeout:   5 * time.Second,
		}

		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return &mockNetConn{addr: addr}, nil
		}

		processor := NewPoolHook(baseDialer, "tcp", config, nil)
		defer processor.Shutdown(context.Background())

		// The processor should be created successfully with custom config
		if processor == nil {
			t.Error("Processor should be created with custom config")
		}
	})

	t.Run("ProcessorWithPartialConfig", func(t *testing.T) {
		config := &Config{
			MaxWorkers: 7, // Only set worker field
			// Other fields will get defaults
		}

		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return &mockNetConn{addr: addr}, nil
		}

		processor := NewPoolHook(baseDialer, "tcp", config, nil)
		defer processor.Shutdown(context.Background())

		// Should work with partial config (defaults applied)
		if processor == nil {
			t.Error("Processor should be created with partial config")
		}
	})

	t.Run("ProcessorWithNilConfig", func(t *testing.T) {
		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return &mockNetConn{addr: addr}, nil
		}

		processor := NewPoolHook(baseDialer, "tcp", nil, nil)
		defer processor.Shutdown(context.Background())

		// Should use default config when nil is passed
		if processor == nil {
			t.Error("Processor should be created with nil config (using defaults)")
		}
	})
}

func TestIntegrationWithApplyDefaults(t *testing.T) {
	t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) {
		// Create a partial config with only some fields set
		partialConfig := &Config{
			MaxWorkers: 15, // Custom value (>= 10 to test preservation)
			// Other fields left as zero values - should get defaults
		}

		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return &mockNetConn{addr: addr}, nil
		}

		// Create processor - should apply defaults to missing fields
		processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil)
		defer processor.Shutdown(context.Background())

		// Processor should be created successfully
		if processor == nil {
			t.Error("Processor should be created with partial config")
		}

		// Test that the ApplyDefaults method worked correctly by creating the same config
		// and applying defaults manually
		expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

		// Should preserve custom values (when >= poolSize/2)
		if expectedConfig.MaxWorkers != 50 { // max(poolSize/2, 15) = max(50, 15) = 50
			t.Errorf("Expected MaxWorkers to be 50, got %d", expectedConfig.MaxWorkers)
		}

		// Should apply defaults for missing fields (auto-calculated queue size with hybrid scaling)
		workerBasedSize := expectedConfig.MaxWorkers * 20
		poolSize := 100 // Default pool size used in ApplyDefaults
		poolBasedSize := poolSize
		expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
		expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
		if expectedConfig.HandoffQueueSize != expectedQueueSize {
			t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
				expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, expectedConfig.HandoffQueueSize)
		}

		// Test that queue size is always capped by 5x pool size
		if expectedConfig.HandoffQueueSize > poolSize*5 {
			t.Errorf("HandoffQueueSize (%d) should never exceed 5x pool size (%d)",
				expectedConfig.HandoffQueueSize, poolSize*5)
		}

		if expectedConfig.RelaxedTimeout != 10*time.Second {
			t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout)
		}

		if expectedConfig.HandoffTimeout != 15*time.Second {
			t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout)
		}

		if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second {
			t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration)
		}
	})
}

func TestEnhancedConfigValidation(t *testing.T) {
	t.Run("ValidateFields", func(t *testing.T) {
		config := DefaultConfig()
		config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100

		// Should pass validation with default values
		if err := config.Validate(); err != nil {
			t.Errorf("Default config should be valid, got error: %v", err)
		}

		// Test invalid MaxHandoffRetries
		config.MaxHandoffRetries = 0
		if err := config.Validate(); err == nil {
			t.Error("Expected validation error for MaxHandoffRetries = 0")
		}
		config.MaxHandoffRetries = 11
		if err := config.Validate(); err == nil {
			t.Error("Expected validation error for MaxHandoffRetries = 11")
		}
		config.MaxHandoffRetries = 3 // Reset to valid value

		// Test circuit breaker validation
		config.CircuitBreakerFailureThreshold = 0
		if err := config.Validate(); err != ErrInvalidCircuitBreakerFailureThreshold {
			t.Errorf("Expected ErrInvalidCircuitBreakerFailureThreshold, got %v", err)
		}
		config.CircuitBreakerFailureThreshold = 5 // Reset to valid value

		config.CircuitBreakerResetTimeout = -1 * time.Second
		if err := config.Validate(); err != ErrInvalidCircuitBreakerResetTimeout {
			t.Errorf("Expected ErrInvalidCircuitBreakerResetTimeout, got %v", err)
		}
		config.CircuitBreakerResetTimeout = 60 * time.Second // Reset to valid value

		config.CircuitBreakerMaxRequests = 0
		if err := config.Validate(); err != ErrInvalidCircuitBreakerMaxRequests {
			t.Errorf("Expected ErrInvalidCircuitBreakerMaxRequests, got %v", err)
		}
		config.CircuitBreakerMaxRequests = 3 // Reset to valid value

		// Should pass validation again
		if err := config.Validate(); err != nil {
			t.Errorf("Config should be valid after reset, got error: %v", err)
		}
	})
}

func TestConfigClone(t *testing.T) {
	original := DefaultConfig()
	original.MaxHandoffRetries = 7
	original.HandoffTimeout = 8 * time.Second

	cloned := original.Clone()

	// Test that values are copied
	if cloned.MaxHandoffRetries != 7 {
		t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries)
	}
	if cloned.HandoffTimeout != 8*time.Second {
		t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout)
	}

	// Test that modifying clone doesn't affect original
	cloned.MaxHandoffRetries = 10
	if original.MaxHandoffRetries != 7 {
		t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries)
	}
}

func TestMaxWorkersLogic(t *testing.T) {
	t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) {
		testCases := []struct {
			poolSize        int
			expectedWorkers int
			description     string
		}{
			{6, 3, "Small pool: min(6/2, max(10, 6/3)) = min(3, max(10, 2)) = min(3, 10) = 3"},
			{15, 7, "Medium pool: min(15/2, max(10, 15/3)) = min(7, max(10, 5)) = min(7, 10) = 7"},
			{30, 10, "Large pool: min(30/2, max(10, 30/3)) = min(15, max(10, 10)) = min(15, 10) = 10"},
			{60, 20, "Very large pool: min(60/2, max(10, 60/3)) = min(30, max(10, 20)) = min(30, 20) = 20"},
			{120, 40, "Huge pool: min(120/2, max(10, 120/3)) = min(60, max(10, 40)) = min(60, 40) = 40"},
		}

		for _, tc := range testCases {
			config := &Config{} // MaxWorkers = 0 (not set)
			result := config.ApplyDefaultsWithPoolSize(tc.poolSize)

			if result.MaxWorkers != tc.expectedWorkers {
				t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)",
					tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description)
			}
		}
	})

	t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) {
		testCases := []struct {
			setValue        int
			expectedWorkers int
			description     string
		}{
			{1, 50, "Set 1: max(poolSize/2, 1) = max(50, 1) = 50 (enforced minimum)"},
			{5, 50, "Set 5: max(poolSize/2, 5) = max(50, 5) = 50 (enforced minimum)"},
			{8, 50, "Set 8: max(poolSize/2, 8) = max(50, 8) = 50 (enforced minimum)"},
			{10, 50, "Set 10: max(poolSize/2, 10) = max(50, 10) = 50 (enforced minimum)"},
			{15, 50, "Set 15: max(poolSize/2, 15) = max(50, 15) = 50 (enforced minimum)"},
			{60, 60, "Set 60: max(poolSize/2, 60) = max(50, 60) = 60 (respects user choice)"},
		}

		for _, tc := range testCases {
			config := &Config{
				MaxWorkers: tc.setValue, // Explicitly set
			}
			result := config.ApplyDefaultsWithPoolSize(100) // Pool size doesn't affect explicit values

			if result.MaxWorkers != tc.expectedWorkers {
				t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)",
					tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description)
			}
		}
	})
}
30
maintnotifications/e2e/.gitignore
vendored
Normal file
30
maintnotifications/e2e/.gitignore
vendored
Normal file
@@ -0,0 +1,30 @@
# E2E test artifacts
*.log
*.out
test-results/
coverage/
profiles/

# Test data
test-data/
temp/
*.tmp

# CI artifacts
artifacts/
reports/

# Redis data files (if running local Redis for testing)
dump.rdb
appendonly.aof
redis.conf.local

# Performance test results
*.prof
*.trace
benchmarks/

# Docker compose files for local testing
docker-compose.override.yml
.env.local
infra/
363
maintnotifications/e2e/DATABASE_MANAGEMENT.md
Normal file
363
maintnotifications/e2e/DATABASE_MANAGEMENT.md
Normal file
@@ -0,0 +1,363 @@
# Database Management with Fault Injector

This document describes how to use the fault injector's database management endpoints to create and delete Redis databases during E2E testing.

## Overview

The fault injector now supports two new endpoints for database management:

1. **CREATE_DATABASE** - Create a new Redis database with custom configuration
2. **DELETE_DATABASE** - Delete an existing Redis database

These endpoints are useful for E2E tests that need to dynamically create and destroy databases as part of their test scenarios.

## Action Types

### CREATE_DATABASE

Creates a new Redis database with the specified configuration.

**Parameters:**
- `cluster_index` (int): The index of the cluster where the database should be created
- `database_config` (object): The database configuration (see structure below)

**Raises:**
- `CreateDatabaseException`: When database creation fails

### DELETE_DATABASE

Deletes an existing Redis database.

**Parameters:**
- `cluster_index` (int): The index of the cluster containing the database
- `bdb_id` (int): The database ID to delete

**Raises:**
- `DeleteDatabaseException`: When database deletion fails

## Database Configuration Structure

The `database_config` object supports the following fields:

```go
type DatabaseConfig struct {
	Name                         string                 `json:"name"`
	Port                         int                    `json:"port"`
	MemorySize                   int64                  `json:"memory_size"`
	Replication                  bool                   `json:"replication"`
	EvictionPolicy               string                 `json:"eviction_policy"`
	Sharding                     bool                   `json:"sharding"`
	AutoUpgrade                  bool                   `json:"auto_upgrade"`
	ShardsCount                  int                    `json:"shards_count"`
	ModuleList                   []DatabaseModule       `json:"module_list,omitempty"`
	OSSCluster                   bool                   `json:"oss_cluster"`
	OSSClusterAPIPreferredIPType string                 `json:"oss_cluster_api_preferred_ip_type,omitempty"`
	ProxyPolicy                  string                 `json:"proxy_policy,omitempty"`
	ShardsPlacement              string                 `json:"shards_placement,omitempty"`
	ShardKeyRegex                []ShardKeyRegexPattern `json:"shard_key_regex,omitempty"`
}

type DatabaseModule struct {
	ModuleArgs string `json:"module_args"`
	ModuleName string `json:"module_name"`
}

type ShardKeyRegexPattern struct {
	Regex string `json:"regex"`
}
```

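If you want to sanity-check the JSON a given `DatabaseConfig` produces before sending it to the fault injector, a minimal sketch using only the standard library (the field values are arbitrary):

```go
cfg := DatabaseConfig{Name: "sanity-db", Port: 12000, MemorySize: 268435456, ShardsCount: 1}
b, err := json.MarshalIndent(cfg, "", "  ") // encoding/json
if err != nil {
	log.Fatal(err)
}
fmt.Println(string(b)) // field names match the json tags shown above
```
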
### Example Configuration

#### Simple Database

```json
{
  "name": "simple-db",
  "port": 12000,
  "memory_size": 268435456,
  "replication": false,
  "eviction_policy": "noeviction",
  "sharding": false,
  "auto_upgrade": true,
  "shards_count": 1,
  "oss_cluster": false
}
```

#### Clustered Database with Modules

```json
{
  "name": "ioredis-cluster",
  "port": 11112,
  "memory_size": 1273741824,
  "replication": true,
  "eviction_policy": "noeviction",
  "sharding": true,
  "auto_upgrade": true,
  "shards_count": 3,
  "module_list": [
    {
      "module_args": "",
      "module_name": "ReJSON"
    },
    {
      "module_args": "",
      "module_name": "search"
    },
    {
      "module_args": "",
      "module_name": "timeseries"
    },
    {
      "module_args": "",
      "module_name": "bf"
    }
  ],
  "oss_cluster": true,
  "oss_cluster_api_preferred_ip_type": "external",
  "proxy_policy": "all-master-shards",
  "shards_placement": "sparse",
  "shard_key_regex": [
    {
      "regex": ".*\\{(?<tag>.*)\\}.*"
    },
    {
      "regex": "(?<tag>.*)"
    }
  ]
}
```

## Usage Examples

### Example 1: Create a Simple Database

```go
ctx := context.Background()
faultInjector := NewFaultInjectorClient("http://127.0.0.1:20324")

dbConfig := DatabaseConfig{
	Name:           "test-db",
	Port:           12000,
	MemorySize:     268435456, // 256MB
	Replication:    false,
	EvictionPolicy: "noeviction",
	Sharding:       false,
	AutoUpgrade:    true,
	ShardsCount:    1,
	OSSCluster:     false,
}

resp, err := faultInjector.CreateDatabase(ctx, 0, dbConfig)
if err != nil {
	log.Fatalf("Failed to create database: %v", err)
}

// Wait for creation to complete
status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
	WithMaxWaitTime(5*time.Minute))
if err != nil {
	log.Fatalf("Failed to wait for action: %v", err)
}

if status.Status == StatusSuccess {
	log.Println("Database created successfully!")
}
```

### Example 2: Create a Database with Modules

```go
dbConfig := DatabaseConfig{
	Name:           "modules-db",
	Port:           12001,
	MemorySize:     536870912, // 512MB
	Replication:    true,
	EvictionPolicy: "noeviction",
	Sharding:       true,
	AutoUpgrade:    true,
	ShardsCount:    3,
	ModuleList: []DatabaseModule{
		{ModuleArgs: "", ModuleName: "ReJSON"},
		{ModuleArgs: "", ModuleName: "search"},
	},
	OSSCluster:                   true,
	OSSClusterAPIPreferredIPType: "external",
	ProxyPolicy:                  "all-master-shards",
	ShardsPlacement:              "sparse",
}

resp, err := faultInjector.CreateDatabase(ctx, 0, dbConfig)
// ... handle response
```

### Example 3: Create Database Using a Map

```go
dbConfigMap := map[string]interface{}{
	"name":            "map-db",
	"port":            12002,
	"memory_size":     268435456,
	"replication":     false,
	"eviction_policy": "volatile-lru",
	"sharding":        false,
	"auto_upgrade":    true,
	"shards_count":    1,
	"oss_cluster":     false,
}

resp, err := faultInjector.CreateDatabaseFromMap(ctx, 0, dbConfigMap)
// ... handle response
```

### Example 4: Delete a Database

```go
clusterIndex := 0
bdbID := 1

resp, err := faultInjector.DeleteDatabase(ctx, clusterIndex, bdbID)
if err != nil {
	log.Fatalf("Failed to delete database: %v", err)
}

status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
	WithMaxWaitTime(2*time.Minute))
if err != nil {
	log.Fatalf("Failed to wait for action: %v", err)
}

if status.Status == StatusSuccess {
	log.Println("Database deleted successfully!")
}
```

### Example 5: Complete Lifecycle (Create and Delete)

```go
// Create database
dbConfig := DatabaseConfig{
	Name:           "temp-db",
	Port:           13000,
	MemorySize:     268435456,
	Replication:    false,
	EvictionPolicy: "noeviction",
	Sharding:       false,
	AutoUpgrade:    true,
	ShardsCount:    1,
	OSSCluster:     false,
}

createResp, err := faultInjector.CreateDatabase(ctx, 0, dbConfig)
if err != nil {
	log.Fatalf("Failed to create database: %v", err)
}

createStatus, err := faultInjector.WaitForAction(ctx, createResp.ActionID,
	WithMaxWaitTime(5*time.Minute))
if err != nil || createStatus.Status != StatusSuccess {
	log.Fatalf("Database creation failed")
}

// Extract bdb_id from output
var bdbID int
if id, ok := createStatus.Output["bdb_id"].(float64); ok {
	bdbID = int(id)
}

// Use the database for testing...
time.Sleep(10 * time.Second)

// Delete the database
deleteResp, err := faultInjector.DeleteDatabase(ctx, 0, bdbID)
if err != nil {
	log.Fatalf("Failed to delete database: %v", err)
}

deleteStatus, err := faultInjector.WaitForAction(ctx, deleteResp.ActionID,
	WithMaxWaitTime(2*time.Minute))
if err != nil || deleteStatus.Status != StatusSuccess {
	log.Fatalf("Database deletion failed")
}

log.Println("Database lifecycle completed successfully!")
```

## Available Methods
|
||||
|
||||
The `FaultInjectorClient` provides the following methods for database management:
|
||||
|
||||
### CreateDatabase
|
||||
|
||||
```go
|
||||
func (c *FaultInjectorClient) CreateDatabase(
|
||||
ctx context.Context,
|
||||
clusterIndex int,
|
||||
databaseConfig DatabaseConfig,
|
||||
) (*ActionResponse, error)
|
||||
```
|
||||
|
||||
Creates a new database using a structured `DatabaseConfig` object.
|
||||
|
||||
### CreateDatabaseFromMap
|
||||
|
||||
```go
|
||||
func (c *FaultInjectorClient) CreateDatabaseFromMap(
|
||||
ctx context.Context,
|
||||
clusterIndex int,
|
||||
databaseConfig map[string]interface{},
|
||||
) (*ActionResponse, error)
|
||||
```
|
||||
|
||||
Creates a new database using a flexible map configuration. Useful when you need to pass custom or dynamic configurations.
|
||||
|
||||
### DeleteDatabase
|
||||
|
||||
```go
|
||||
func (c *FaultInjectorClient) DeleteDatabase(
|
||||
ctx context.Context,
|
||||
clusterIndex int,
|
||||
bdbID int,
|
||||
) (*ActionResponse, error)
|
||||
```
|
||||
|
||||
Deletes an existing database by its ID.
|
||||
|
||||
## Testing

To run the database management E2E tests:

```bash
# Run all database management tests
go test -tags=e2e -v ./maintnotifications/e2e/ -run TestDatabase

# Run specific test
go test -tags=e2e -v ./maintnotifications/e2e/ -run TestDatabaseLifecycle
```

## Notes

- Database creation can take several minutes depending on the configuration
- Always use `WaitForAction` to ensure the operation completes before proceeding (see the sketch after this list)
- The `bdb_id` returned in the creation output should be used for deletion
- Deleting a non-existent database will result in a failed action status
- Memory sizes are specified in bytes (e.g., 268435456 = 256MB)
- Port numbers should be unique and not conflict with existing databases
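
Because creation can take several minutes, it is often useful to widen the wait window and slow down polling. Here is a minimal sketch using the `WithMaxWaitTime` and `WithPollInterval` options of the fault injector client (it assumes `ctx`, `faultInjector`, and a `dbConfig` like the ones in the examples above):

```go
resp, err := faultInjector.CreateDatabase(ctx, 0, dbConfig)
if err != nil {
	log.Fatalf("Failed to create database: %v", err)
}

// Poll every 2 seconds for up to 10 minutes instead of the defaults.
status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
	WithMaxWaitTime(10*time.Minute),
	WithPollInterval(2*time.Second))
if err != nil {
	log.Fatalf("Failed to wait for action: %v", err)
}

if status.Status != StatusSuccess {
	log.Fatalf("Database creation did not succeed: %s", status.Status)
}
```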

## Common Eviction Policies

- `noeviction` - Return errors when memory limit is reached
- `allkeys-lru` - Evict any key using LRU algorithm
- `volatile-lru` - Evict keys with TTL using LRU algorithm
- `allkeys-random` - Evict random keys
- `volatile-random` - Evict random keys with TTL
- `volatile-ttl` - Evict keys with TTL, shortest TTL first

## Common Proxy Policies

- `all-master-shards` - Route to all master shards
- `all-nodes` - Route to all nodes
- `single-shard` - Route to a single shard
156
maintnotifications/e2e/README_SCENARIOS.md
Normal file
156
maintnotifications/e2e/README_SCENARIOS.md
Normal file
@@ -0,0 +1,156 @@
# E2E Test Scenarios for Push Notifications

This directory contains comprehensive end-to-end test scenarios for Redis push notifications and maintenance notifications functionality. Each scenario tests a different aspect of the system under various conditions.

## ⚠️ **Important Note**
**Maintenance notifications are currently supported only in standalone Redis clients.** Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support maintenance notifications functionality.

## Introduction

To run these tests you need a fault injector service; review the client in this package and feel free to implement a
fault injector of your choice. The tests are tailored for Redis Enterprise, but they can be adapted to other Redis
distributions where a fault injector is available.

Once the fault injector service is up and running, you can execute the tests with the `run-e2e-tests.sh` script.
Three environment variables must be set before running the tests (see the example below):

- `REDIS_ENDPOINTS_CONFIG_PATH`: Path to the Redis endpoints configuration file
- `FAULT_INJECTION_API_URL`: URL of the fault injector server
- `E2E_SCENARIO_TESTS`: Set to `true` to enable scenario tests
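
For example (the paths and URL below are placeholders; adjust them to your environment):

```bash
export REDIS_ENDPOINTS_CONFIG_PATH=./maintnotifications/e2e/examples/endpoints.json
export FAULT_INJECTION_API_URL=http://127.0.0.1:20324
export E2E_SCENARIO_TESTS=true

./scripts/run-e2e-tests.sh
```
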
## Test Scenarios Overview
|
||||
|
||||
### 1. Basic Push Notifications (`scenario_push_notifications_test.go`)
|
||||
**Original template scenario**
|
||||
- **Purpose**: Basic functionality test for Redis Enterprise push notifications
|
||||
- **Features Tested**: FAILING_OVER, FAILED_OVER, MIGRATING, MIGRATED, MOVING notifications
|
||||
- **Configuration**: Standard enterprise cluster setup
|
||||
- **Duration**: ~10 minutes
|
||||
- **Key Validations**:
|
||||
- All notification types received
|
||||
- Timeout behavior (relaxed/unrelaxed)
|
||||
- Handoff success rates
|
||||
- Connection pool management
|
||||
|
||||
### 2. Endpoint Types Scenario (`scenario_endpoint_types_test.go`)
|
||||
**Different endpoint resolution strategies**
|
||||
- **Purpose**: Test push notifications with different endpoint types
|
||||
- **Features Tested**: ExternalIP, InternalIP, InternalFQDN, ExternalFQDN endpoint types
|
||||
- **Configuration**: Standard setup with varying endpoint types
|
||||
- **Duration**: ~5 minutes
|
||||
- **Key Validations**:
|
||||
- Functionality with each endpoint type
|
||||
- Proper endpoint resolution
|
||||
- Notification delivery consistency
|
||||
- Handoff behavior per endpoint type
|
||||
|
||||
### 3. Database Management Scenario (`scenario_database_management_test.go`)
|
||||
**Dynamic database creation and deletion**
|
||||
- **Purpose**: Test database lifecycle management via fault injector
|
||||
- **Features Tested**: CREATE_DATABASE, DELETE_DATABASE endpoints
|
||||
- **Configuration**: Various database configurations (simple, with modules, clustered)
|
||||
- **Duration**: ~10 minutes
|
||||
- **Key Validations**:
|
||||
- Database creation with different configurations
|
||||
- Database creation with Redis modules (ReJSON, search, timeseries, bf)
|
||||
- Database deletion
|
||||
- Complete lifecycle (create → use → delete)
|
||||
- Configuration validation
|
||||
|
||||
See [DATABASE_MANAGEMENT.md](DATABASE_MANAGEMENT.md) for detailed documentation on database management endpoints.
|
||||
|
||||
### 4. Timeout Configurations Scenario (`scenario_timeout_configs_test.go`)
|
||||
**Various timeout strategies**
|
||||
- **Purpose**: Test different timeout configurations and their impact
|
||||
- **Features Tested**: Conservative, Aggressive, HighLatency timeouts
|
||||
- **Configuration**:
|
||||
- Conservative: 60s handoff, 20s relaxed, 5s post-handoff
|
||||
- Aggressive: 5s handoff, 3s relaxed, 1s post-handoff
|
||||
- HighLatency: 90s handoff, 30s relaxed, 10m post-handoff
|
||||
- **Duration**: ~10 minutes (3 sub-tests)
|
||||
- **Key Validations**:
|
||||
- Timeout behavior matches configuration
|
||||
- Recovery times appropriate for each strategy
|
||||
- Error rates correlate with timeout aggressiveness
|
||||
|
||||
### 5. TLS Configurations Scenario (`scenario_tls_configs_test.go`)
|
||||
**Security and encryption testing framework**
|
||||
- **Purpose**: Test push notifications with different TLS configurations
|
||||
- **Features Tested**: NoTLS, TLSInsecure, TLSSecure, TLSMinimal, TLSStrict
|
||||
- **Configuration**: Framework for testing various TLS settings (TLS config handled at connection level)
|
||||
- **Duration**: ~10 minutes (multiple sub-tests)
|
||||
- **Key Validations**:
|
||||
- Functionality with each TLS configuration
|
||||
- Performance impact of encryption
|
||||
- Certificate handling (where applicable)
|
||||
- Security compliance
|
||||
- **Note**: TLS configuration is handled at the Redis connection config level, not client options level
|
||||
|
||||
### 6. Stress Test Scenario (`scenario_stress_test.go`)
|
||||
**Extreme load and concurrent operations**
|
||||
- **Purpose**: Test system limits and behavior under extreme stress
|
||||
- **Features Tested**: Maximum concurrent operations, multiple clients
|
||||
- **Configuration**:
|
||||
- 4 clients with 150 pool size each
|
||||
- 200 max connections per client
|
||||
- 50 workers, 1000 queue size
|
||||
- Concurrent failover/migration actions
|
||||
- **Duration**: ~15 minutes
|
||||
- **Key Validations**:
|
||||
- System stability under extreme load
|
||||
- Error rates within stress limits (<20%)
|
||||
- Resource utilization and limits
|
||||
- Concurrent fault injection handling
|
||||
|
||||
|
||||
## Running the Scenarios

### Prerequisites
- Set environment variable: `E2E_SCENARIO_TESTS=true`
- Redis Enterprise cluster available
- Fault injection service available
- Appropriate network access and permissions
- **Note**: Tests use standalone Redis clients only (cluster clients not supported)

### Individual Scenario Execution
```bash
# Run a specific scenario
E2E_SCENARIO_TESTS=true go test -v ./maintnotifications/e2e -run TestEndpointTypesPushNotifications

# Run with timeout
E2E_SCENARIO_TESTS=true go test -v -timeout 30m ./maintnotifications/e2e -run TestStressPushNotifications
```

### All Scenarios Execution
```bash
./scripts/run-e2e-tests.sh
```

## Expected Outcomes

### Success Criteria
- All notifications received and processed correctly
- Error rates within acceptable limits for each scenario
- No notification processing errors
- Proper timeout behavior
- Successful handoffs
- Connection pool management within limits

### Performance Benchmarks
- **Basic**: >1000 operations, <1% errors
- **Stress**: >10000 operations, <20% errors
- **Others**: Functionality over performance

## Troubleshooting

### Common Issues
1. **Enterprise cluster not available**: Most scenarios require Redis Enterprise
2. **Fault injector unavailable**: Some scenarios need fault injection service
3. **Network timeouts**: Increase test timeouts for slow networks
4. **TLS certificate issues**: Some TLS scenarios may fail without proper certs
5. **Resource limits**: Stress scenarios may hit system limits

### Debug Options
- Enable detailed logging in scenarios
- Use `dump = true` to see full log analysis
- Check pool statistics for connection issues
- Monitor client resources during stress tests
137
maintnotifications/e2e/command_runner_test.go
Normal file
137
maintnotifications/e2e/command_runner_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type CommandRunnerStats struct {
|
||||
Operations int64
|
||||
Errors int64
|
||||
TimeoutErrors int64
|
||||
ErrorsList []error
|
||||
}
|
||||
|
||||
// CommandRunner provides utilities for running commands during tests
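//
// Typical use (illustrative sketch):
//
//	runner, _ := NewCommandRunner(client)
//	go runner.FireCommandsUntilStop(ctx)
//	// ... inject faults, wait for notifications ...
//	runner.Stop()
//	stats := runner.GetStats()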
|
||||
type CommandRunner struct {
|
||||
client redis.UniversalClient
|
||||
stopCh chan struct{}
|
||||
operationCount atomic.Int64
|
||||
errorCount atomic.Int64
|
||||
timeoutErrors atomic.Int64
|
||||
errors []error
|
||||
errorsMutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewCommandRunner creates a new command runner
|
||||
func NewCommandRunner(client redis.UniversalClient) (*CommandRunner, func()) {
|
||||
stopCh := make(chan struct{})
|
||||
return &CommandRunner{
|
||||
client: client,
|
||||
stopCh: stopCh,
|
||||
errors: make([]error, 0),
|
||||
}, func() {
|
||||
stopCh <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (cr *CommandRunner) Stop() {
|
||||
select {
|
||||
case cr.stopCh <- struct{}{}:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (cr *CommandRunner) Close() {
|
||||
close(cr.stopCh)
|
||||
}
|
||||
|
||||
// FireCommandsUntilStop runs commands continuously until stop signal
|
||||
func (cr *CommandRunner) FireCommandsUntilStop(ctx context.Context) {
|
||||
fmt.Printf("[CR] Starting command runner...\n")
|
||||
defer fmt.Printf("[CR] Command runner stopped\n")
|
||||
// High frequency for timeout testing
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
counter := 0
|
||||
for {
|
||||
select {
|
||||
case <-cr.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
poolSize := cr.client.PoolStats().IdleConns
|
||||
if poolSize == 0 {
|
||||
poolSize = 1
|
||||
}
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < int(poolSize); i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("timeout-test-key-%d-%d", counter, i)
|
||||
value := fmt.Sprintf("timeout-test-value-%d-%d", counter, i)
|
||||
|
||||
// Use a short timeout context for individual operations
|
||||
opCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
err := cr.client.Set(opCtx, key, value, time.Minute).Err()
|
||||
cancel()
|
||||
|
||||
cr.operationCount.Add(1)
|
||||
if err != nil {
|
||||
if err == redis.ErrClosed || strings.Contains(err.Error(), "client is closed") {
|
||||
select {
|
||||
case <-cr.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
cr.errorCount.Add(1)
|
||||
|
||||
// Check if it's a timeout error
|
||||
if isTimeoutError(err) {
|
||||
cr.timeoutErrors.Add(1)
|
||||
}
|
||||
|
||||
cr.errorsMutex.Lock()
|
||||
cr.errors = append(cr.errors, err)
|
||||
cr.errorsMutex.Unlock()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
counter++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats returns operation statistics
|
||||
func (cr *CommandRunner) GetStats() CommandRunnerStats {
|
||||
cr.errorsMutex.Lock()
|
||||
defer cr.errorsMutex.Unlock()
|
||||
|
||||
errorList := make([]error, len(cr.errors))
|
||||
copy(errorList, cr.errors)
|
||||
|
||||
stats := CommandRunnerStats{
|
||||
Operations: cr.operationCount.Load(),
|
||||
Errors: cr.errorCount.Load(),
|
||||
TimeoutErrors: cr.timeoutErrors.Load(),
|
||||
ErrorsList: errorList,
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
1111
maintnotifications/e2e/config_parser_test.go
Normal file
1111
maintnotifications/e2e/config_parser_test.go
Normal file
File diff suppressed because it is too large
Load Diff
21
maintnotifications/e2e/doc.go
Normal file
21
maintnotifications/e2e/doc.go
Normal file
@@ -0,0 +1,21 @@
// Package e2e provides end-to-end testing scenarios for the maintenance notifications system.
//
// This package contains comprehensive test scenarios that validate the maintenance notifications
// functionality in realistic environments. The tests are designed to work with Redis Enterprise
// clusters and require specific environment configuration.
//
// Environment Variables:
//   - E2E_SCENARIO_TESTS: Set to "true" to enable scenario tests
//   - REDIS_ENDPOINTS_CONFIG_PATH: Path to endpoints configuration file
//   - FAULT_INJECTION_API_URL: URL for fault injection API (optional)
//
// Test Scenarios:
//   - Basic Push Notifications: Core functionality testing
//   - Endpoint Types: Different endpoint resolution strategies
//   - Timeout Configurations: Various timeout strategies
//   - TLS Configurations: Different TLS setups
//   - Stress Testing: Extreme load and concurrent operations
//
// Note: Maintenance notifications are currently supported only in standalone Redis clients.
// Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support this functionality.
package e2e
110
maintnotifications/e2e/examples/endpoints.json
Normal file
110
maintnotifications/e2e/examples/endpoints.json
Normal file
@@ -0,0 +1,110 @@
|
||||
{
|
||||
"standalone0": {
|
||||
"password": "foobared",
|
||||
"tls": false,
|
||||
"endpoints": [
|
||||
"redis://localhost:6379"
|
||||
]
|
||||
},
|
||||
"standalone0-tls": {
|
||||
"username": "default",
|
||||
"password": "foobared",
|
||||
"tls": true,
|
||||
"certificatesLocation": "redis1-2-5-8-sentinel/work/tls",
|
||||
"endpoints": [
|
||||
"rediss://localhost:6390"
|
||||
]
|
||||
},
|
||||
"standalone0-acl": {
|
||||
"username": "acljedis",
|
||||
"password": "fizzbuzz",
|
||||
"tls": false,
|
||||
"endpoints": [
|
||||
"redis://localhost:6379"
|
||||
]
|
||||
},
|
||||
"standalone0-acl-tls": {
|
||||
"username": "acljedis",
|
||||
"password": "fizzbuzz",
|
||||
"tls": true,
|
||||
"certificatesLocation": "redis1-2-5-8-sentinel/work/tls",
|
||||
"endpoints": [
|
||||
"rediss://localhost:6390"
|
||||
]
|
||||
},
|
||||
"cluster0": {
|
||||
"username": "default",
|
||||
"password": "foobared",
|
||||
"tls": false,
|
||||
"endpoints": [
|
||||
"redis://localhost:7001",
|
||||
"redis://localhost:7002",
|
||||
"redis://localhost:7003",
|
||||
"redis://localhost:7004",
|
||||
"redis://localhost:7005",
|
||||
"redis://localhost:7006"
|
||||
]
|
||||
},
|
||||
"cluster0-tls": {
|
||||
"username": "default",
|
||||
"password": "foobared",
|
||||
"tls": true,
|
||||
"certificatesLocation": "redis1-2-5-8-sentinel/work/tls",
|
||||
"endpoints": [
|
||||
"rediss://localhost:7011",
|
||||
"rediss://localhost:7012",
|
||||
"rediss://localhost:7013",
|
||||
"rediss://localhost:7014",
|
||||
"rediss://localhost:7015",
|
||||
"rediss://localhost:7016"
|
||||
]
|
||||
},
|
||||
"sentinel0": {
|
||||
"username": "default",
|
||||
"password": "foobared",
|
||||
"tls": false,
|
||||
"endpoints": [
|
||||
"redis://localhost:26379",
|
||||
"redis://localhost:26380",
|
||||
"redis://localhost:26381"
|
||||
]
|
||||
},
|
||||
"modules-docker": {
|
||||
"tls": false,
|
||||
"endpoints": [
|
||||
"redis://localhost:6479"
|
||||
]
|
||||
},
|
||||
"enterprise-cluster": {
|
||||
"bdb_id": 1,
|
||||
"username": "default",
|
||||
"password": "enterprise-password",
|
||||
"tls": true,
|
||||
"raw_endpoints": [
|
||||
{
|
||||
"addr": ["10.0.0.1"],
|
||||
"addr_type": "ipv4",
|
||||
"dns_name": "redis-enterprise-cluster.example.com",
|
||||
"oss_cluster_api_preferred_endpoint_type": "internal",
|
||||
"oss_cluster_api_preferred_ip_type": "ipv4",
|
||||
"port": 12000,
|
||||
"proxy_policy": "single",
|
||||
"uid": "endpoint-1"
|
||||
},
|
||||
{
|
||||
"addr": ["10.0.0.2"],
|
||||
"addr_type": "ipv4",
|
||||
"dns_name": "redis-enterprise-cluster-2.example.com",
|
||||
"oss_cluster_api_preferred_endpoint_type": "internal",
|
||||
"oss_cluster_api_preferred_ip_type": "ipv4",
|
||||
"port": 12000,
|
||||
"proxy_policy": "single",
|
||||
"uid": "endpoint-2"
|
||||
}
|
||||
],
|
||||
"endpoints": [
|
||||
"rediss://redis-enterprise-cluster.example.com:12000",
|
||||
"rediss://redis-enterprise-cluster-2.example.com:12000"
|
||||
]
|
||||
}
|
||||
}
|
||||
644
maintnotifications/e2e/fault_injector_test.go
Normal file
644
maintnotifications/e2e/fault_injector_test.go
Normal file
@@ -0,0 +1,644 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ActionType represents the type of fault injection action
|
||||
type ActionType string
|
||||
|
||||
const (
|
||||
// Redis cluster actions
|
||||
ActionClusterFailover ActionType = "cluster_failover"
|
||||
ActionClusterReshard ActionType = "cluster_reshard"
|
||||
ActionClusterAddNode ActionType = "cluster_add_node"
|
||||
ActionClusterRemoveNode ActionType = "cluster_remove_node"
|
||||
ActionClusterMigrate ActionType = "cluster_migrate"
|
||||
|
||||
// Node-level actions
|
||||
ActionNodeRestart ActionType = "node_restart"
|
||||
ActionNodeStop ActionType = "node_stop"
|
||||
ActionNodeStart ActionType = "node_start"
|
||||
ActionNodeKill ActionType = "node_kill"
|
||||
|
||||
// Network simulation actions
|
||||
ActionNetworkPartition ActionType = "network_partition"
|
||||
ActionNetworkLatency ActionType = "network_latency"
|
||||
ActionNetworkPacketLoss ActionType = "network_packet_loss"
|
||||
ActionNetworkBandwidth ActionType = "network_bandwidth"
|
||||
ActionNetworkRestore ActionType = "network_restore"
|
||||
|
||||
// Redis configuration actions
|
||||
ActionConfigChange ActionType = "config_change"
|
||||
ActionMaintenanceMode ActionType = "maintenance_mode"
|
||||
ActionSlotMigration ActionType = "slot_migration"
|
||||
|
||||
// Sequence and complex actions
|
||||
ActionSequence ActionType = "sequence_of_actions"
|
||||
ActionExecuteCommand ActionType = "execute_command"
|
||||
|
||||
// Database management actions
|
||||
ActionDeleteDatabase ActionType = "delete_database"
|
||||
ActionCreateDatabase ActionType = "create_database"
|
||||
)
|
||||
|
||||
// ActionStatus represents the status of an action
|
||||
type ActionStatus string
|
||||
|
||||
const (
|
||||
StatusPending ActionStatus = "pending"
|
||||
StatusRunning ActionStatus = "running"
|
||||
StatusFinished ActionStatus = "finished"
|
||||
StatusFailed ActionStatus = "failed"
|
||||
StatusSuccess ActionStatus = "success"
|
||||
StatusCancelled ActionStatus = "cancelled"
|
||||
)
|
||||
|
||||
// ActionRequest represents a request to trigger an action
|
||||
type ActionRequest struct {
|
||||
Type ActionType `json:"type"`
|
||||
Parameters map[string]interface{} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// ActionResponse represents the response from triggering an action
|
||||
type ActionResponse struct {
|
||||
ActionID string `json:"action_id"`
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ActionStatusResponse represents the status of an action
|
||||
type ActionStatusResponse struct {
|
||||
ActionID string `json:"action_id"`
|
||||
Status ActionStatus `json:"status"`
|
||||
Error interface{} `json:"error,omitempty"`
|
||||
Output map[string]interface{} `json:"output,omitempty"`
|
||||
Progress float64 `json:"progress,omitempty"`
|
||||
StartTime time.Time `json:"start_time,omitempty"`
|
||||
EndTime time.Time `json:"end_time,omitempty"`
|
||||
}
|
||||
|
||||
// SequenceAction represents an action in a sequence
|
||||
type SequenceAction struct {
|
||||
Type ActionType `json:"type"`
|
||||
Parameters map[string]interface{} `json:"params,omitempty"`
|
||||
Delay time.Duration `json:"delay,omitempty"`
|
||||
}
|
||||
|
||||
// FaultInjectorClient provides programmatic control over test infrastructure
|
||||
type FaultInjectorClient struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewFaultInjectorClient creates a new fault injector client
|
||||
func NewFaultInjectorClient(baseURL string) *FaultInjectorClient {
|
||||
return &FaultInjectorClient{
|
||||
baseURL: strings.TrimSuffix(baseURL, "/"),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetBaseURL returns the base URL of the fault injector server
|
||||
func (c *FaultInjectorClient) GetBaseURL() string {
|
||||
return c.baseURL
|
||||
}
|
||||
|
||||
// ListActions lists all available actions
|
||||
func (c *FaultInjectorClient) ListActions(ctx context.Context) ([]ActionType, error) {
|
||||
var actions []ActionType
|
||||
err := c.request(ctx, "GET", "/actions", nil, &actions)
|
||||
return actions, err
|
||||
}
|
||||
|
||||
// TriggerAction triggers a specific action
|
||||
func (c *FaultInjectorClient) TriggerAction(ctx context.Context, action ActionRequest) (*ActionResponse, error) {
|
||||
var response ActionResponse
|
||||
fmt.Printf("[FI] Triggering action: %+v\n", action)
|
||||
err := c.request(ctx, "POST", "/action", action, &response)
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func (c *FaultInjectorClient) TriggerSequence(ctx context.Context, bdbID int, actions []SequenceAction) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionSequence,
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": bdbID,
|
||||
"actions": actions,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// GetActionStatus gets the status of a specific action
|
||||
func (c *FaultInjectorClient) GetActionStatus(ctx context.Context, actionID string) (*ActionStatusResponse, error) {
|
||||
var status ActionStatusResponse
|
||||
err := c.request(ctx, "GET", fmt.Sprintf("/action/%s", actionID), nil, &status)
|
||||
return &status, err
|
||||
}
|
||||
|
||||
// WaitForAction waits for an action to complete
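//
// A typical call (illustrative; faultInjector is a *FaultInjectorClient and
// resp a previously returned *ActionResponse) polls until the action finishes
// or the wait window elapses:
//
//	status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
//		WithMaxWaitTime(5*time.Minute),
//		WithPollInterval(2*time.Second))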
|
||||
func (c *FaultInjectorClient) WaitForAction(ctx context.Context, actionID string, options ...WaitOption) (*ActionStatusResponse, error) {
|
||||
config := &waitConfig{
|
||||
pollInterval: 1 * time.Second,
|
||||
maxWaitTime: 60 * time.Second,
|
||||
}
|
||||
|
||||
for _, opt := range options {
|
||||
opt(config)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(config.maxWaitTime)
|
||||
ticker := time.NewTicker(config.pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(time.Until(deadline)):
|
||||
return nil, fmt.Errorf("timeout waiting for action %s after %v", actionID, config.maxWaitTime)
|
||||
case <-ticker.C:
|
||||
status, err := c.GetActionStatus(ctx, actionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get action status: %w", err)
|
||||
}
|
||||
|
||||
switch status.Status {
|
||||
case StatusFinished, StatusSuccess, StatusFailed, StatusCancelled:
|
||||
return status, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cluster Management Actions
|
||||
|
||||
// TriggerClusterFailover triggers a cluster failover
|
||||
func (c *FaultInjectorClient) TriggerClusterFailover(ctx context.Context, nodeID string, force bool) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionClusterFailover,
|
||||
Parameters: map[string]interface{}{
|
||||
"node_id": nodeID,
|
||||
"force": force,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// TriggerClusterReshard triggers cluster resharding
|
||||
func (c *FaultInjectorClient) TriggerClusterReshard(ctx context.Context, slots []int, sourceNode, targetNode string) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionClusterReshard,
|
||||
Parameters: map[string]interface{}{
|
||||
"slots": slots,
|
||||
"source_node": sourceNode,
|
||||
"target_node": targetNode,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// TriggerSlotMigration triggers migration of specific slots
|
||||
func (c *FaultInjectorClient) TriggerSlotMigration(ctx context.Context, startSlot, endSlot int, sourceNode, targetNode string) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionSlotMigration,
|
||||
Parameters: map[string]interface{}{
|
||||
"start_slot": startSlot,
|
||||
"end_slot": endSlot,
|
||||
"source_node": sourceNode,
|
||||
"target_node": targetNode,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Node Management Actions
|
||||
|
||||
// RestartNode restarts a specific Redis node
|
||||
func (c *FaultInjectorClient) RestartNode(ctx context.Context, nodeID string, graceful bool) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionNodeRestart,
|
||||
Parameters: map[string]interface{}{
|
||||
"node_id": nodeID,
|
||||
"graceful": graceful,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// StopNode stops a specific Redis node
|
||||
func (c *FaultInjectorClient) StopNode(ctx context.Context, nodeID string, graceful bool) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionNodeStop,
|
||||
Parameters: map[string]interface{}{
|
||||
"node_id": nodeID,
|
||||
"graceful": graceful,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// StartNode starts a specific Redis node
|
||||
func (c *FaultInjectorClient) StartNode(ctx context.Context, nodeID string) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionNodeStart,
|
||||
Parameters: map[string]interface{}{
|
||||
"node_id": nodeID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// KillNode forcefully kills a Redis node
|
||||
func (c *FaultInjectorClient) KillNode(ctx context.Context, nodeID string) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionNodeKill,
|
||||
Parameters: map[string]interface{}{
|
||||
"node_id": nodeID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Network Simulation Actions
|
||||
|
||||
// SimulateNetworkPartition simulates a network partition
|
||||
func (c *FaultInjectorClient) SimulateNetworkPartition(ctx context.Context, nodes []string, duration time.Duration) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionNetworkPartition,
|
||||
Parameters: map[string]interface{}{
|
||||
"nodes": nodes,
|
||||
"duration": duration.String(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// SimulateNetworkLatency adds network latency
|
||||
func (c *FaultInjectorClient) SimulateNetworkLatency(ctx context.Context, nodes []string, latency time.Duration, jitter time.Duration) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionNetworkLatency,
|
||||
Parameters: map[string]interface{}{
|
||||
"nodes": nodes,
|
||||
"latency": latency.String(),
|
||||
"jitter": jitter.String(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// SimulatePacketLoss simulates packet loss
|
||||
func (c *FaultInjectorClient) SimulatePacketLoss(ctx context.Context, nodes []string, lossPercent float64) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionNetworkPacketLoss,
|
||||
Parameters: map[string]interface{}{
|
||||
"nodes": nodes,
|
||||
"loss_percent": lossPercent,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// LimitBandwidth limits network bandwidth
|
||||
func (c *FaultInjectorClient) LimitBandwidth(ctx context.Context, nodes []string, bandwidth string) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionNetworkBandwidth,
|
||||
Parameters: map[string]interface{}{
|
||||
"nodes": nodes,
|
||||
"bandwidth": bandwidth,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// RestoreNetwork restores normal network conditions
|
||||
func (c *FaultInjectorClient) RestoreNetwork(ctx context.Context, nodes []string) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionNetworkRestore,
|
||||
Parameters: map[string]interface{}{
|
||||
"nodes": nodes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Configuration Actions
|
||||
|
||||
// ChangeConfig changes Redis configuration
|
||||
func (c *FaultInjectorClient) ChangeConfig(ctx context.Context, nodeID string, config map[string]string) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionConfigChange,
|
||||
Parameters: map[string]interface{}{
|
||||
"node_id": nodeID,
|
||||
"config": config,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// EnableMaintenanceMode enables maintenance mode
|
||||
func (c *FaultInjectorClient) EnableMaintenanceMode(ctx context.Context, nodeID string) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionMaintenanceMode,
|
||||
Parameters: map[string]interface{}{
|
||||
"node_id": nodeID,
|
||||
"enabled": true,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// DisableMaintenanceMode disables maintenance mode
|
||||
func (c *FaultInjectorClient) DisableMaintenanceMode(ctx context.Context, nodeID string) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionMaintenanceMode,
|
||||
Parameters: map[string]interface{}{
|
||||
"node_id": nodeID,
|
||||
"enabled": false,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Database Management Actions
|
||||
|
||||
// DatabaseConfig represents the configuration for creating a database
|
||||
type DatabaseConfig struct {
|
||||
Name string `json:"name"`
|
||||
Port int `json:"port"`
|
||||
MemorySize int64 `json:"memory_size"`
|
||||
Replication bool `json:"replication"`
|
||||
EvictionPolicy string `json:"eviction_policy"`
|
||||
Sharding bool `json:"sharding"`
|
||||
AutoUpgrade bool `json:"auto_upgrade"`
|
||||
ShardsCount int `json:"shards_count"`
|
||||
ModuleList []DatabaseModule `json:"module_list,omitempty"`
|
||||
OSSCluster bool `json:"oss_cluster"`
|
||||
OSSClusterAPIPreferredIPType string `json:"oss_cluster_api_preferred_ip_type,omitempty"`
|
||||
ProxyPolicy string `json:"proxy_policy,omitempty"`
|
||||
ShardsPlacement string `json:"shards_placement,omitempty"`
|
||||
ShardKeyRegex []ShardKeyRegexPattern `json:"shard_key_regex,omitempty"`
|
||||
}
|
||||
|
||||
// DatabaseModule represents a Redis module configuration
|
||||
type DatabaseModule struct {
|
||||
ModuleArgs string `json:"module_args"`
|
||||
ModuleName string `json:"module_name"`
|
||||
}
|
||||
|
||||
// ShardKeyRegexPattern represents a shard key regex pattern
|
||||
type ShardKeyRegexPattern struct {
|
||||
Regex string `json:"regex"`
|
||||
}
|
||||
|
||||
// DeleteDatabase deletes a database
|
||||
// Parameters:
|
||||
// - clusterIndex: The index of the cluster
|
||||
// - bdbID: The database ID to delete
|
||||
func (c *FaultInjectorClient) DeleteDatabase(ctx context.Context, clusterIndex int, bdbID int) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionDeleteDatabase,
|
||||
Parameters: map[string]interface{}{
|
||||
"cluster_index": clusterIndex,
|
||||
"bdb_id": bdbID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// CreateDatabase creates a new database
|
||||
// Parameters:
|
||||
// - clusterIndex: The index of the cluster
|
||||
// - databaseConfig: The database configuration
|
||||
func (c *FaultInjectorClient) CreateDatabase(ctx context.Context, clusterIndex int, databaseConfig DatabaseConfig) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionCreateDatabase,
|
||||
Parameters: map[string]interface{}{
|
||||
"cluster_index": clusterIndex,
|
||||
"database_config": databaseConfig,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// CreateDatabaseFromMap creates a new database using a map for configuration
|
||||
// This is useful when you want to pass a raw configuration map
|
||||
// Parameters:
|
||||
// - clusterIndex: The index of the cluster
|
||||
// - databaseConfig: The database configuration as a map
|
||||
func (c *FaultInjectorClient) CreateDatabaseFromMap(ctx context.Context, clusterIndex int, databaseConfig map[string]interface{}) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionCreateDatabase,
|
||||
Parameters: map[string]interface{}{
|
||||
"cluster_index": clusterIndex,
|
||||
"database_config": databaseConfig,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Complex Actions
|
||||
|
||||
// ExecuteSequence executes a sequence of actions
|
||||
func (c *FaultInjectorClient) ExecuteSequence(ctx context.Context, actions []SequenceAction) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionSequence,
|
||||
Parameters: map[string]interface{}{
|
||||
"actions": actions,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ExecuteCommand executes a custom command
|
||||
func (c *FaultInjectorClient) ExecuteCommand(ctx context.Context, nodeID, command string) (*ActionResponse, error) {
|
||||
return c.TriggerAction(ctx, ActionRequest{
|
||||
Type: ActionExecuteCommand,
|
||||
Parameters: map[string]interface{}{
|
||||
"node_id": nodeID,
|
||||
"command": command,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Convenience Methods
|
||||
|
||||
// SimulateClusterUpgrade simulates a complete cluster upgrade scenario
|
||||
func (c *FaultInjectorClient) SimulateClusterUpgrade(ctx context.Context, nodes []string) (*ActionResponse, error) {
|
||||
actions := make([]SequenceAction, 0, len(nodes)*2)
|
||||
|
||||
// Rolling restart of all nodes
|
||||
for i, nodeID := range nodes {
|
||||
actions = append(actions, SequenceAction{
|
||||
Type: ActionNodeRestart,
|
||||
Parameters: map[string]interface{}{
|
||||
"node_id": nodeID,
|
||||
"graceful": true,
|
||||
},
|
||||
Delay: time.Duration(i*10) * time.Second, // Stagger restarts
|
||||
})
|
||||
}
|
||||
|
||||
return c.ExecuteSequence(ctx, actions)
|
||||
}
|
||||
|
||||
// SimulateNetworkIssues simulates various network issues
|
||||
func (c *FaultInjectorClient) SimulateNetworkIssues(ctx context.Context, nodes []string) (*ActionResponse, error) {
|
||||
actions := []SequenceAction{
|
||||
{
|
||||
Type: ActionNetworkLatency,
|
||||
Parameters: map[string]interface{}{
|
||||
"nodes": nodes,
|
||||
"latency": "100ms",
|
||||
"jitter": "20ms",
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: ActionNetworkPacketLoss,
|
||||
Parameters: map[string]interface{}{
|
||||
"nodes": nodes,
|
||||
"loss_percent": 2.0,
|
||||
},
|
||||
Delay: 30 * time.Second,
|
||||
},
|
||||
{
|
||||
Type: ActionNetworkRestore,
|
||||
Parameters: map[string]interface{}{
|
||||
"nodes": nodes,
|
||||
},
|
||||
Delay: 60 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
return c.ExecuteSequence(ctx, actions)
|
||||
}
|
||||
|
||||
// Helper types and functions
|
||||
|
||||
type waitConfig struct {
|
||||
pollInterval time.Duration
|
||||
maxWaitTime time.Duration
|
||||
}
|
||||
|
||||
type WaitOption func(*waitConfig)
|
||||
|
||||
// WithPollInterval sets the polling interval for waiting
|
||||
func WithPollInterval(interval time.Duration) WaitOption {
|
||||
return func(c *waitConfig) {
|
||||
c.pollInterval = interval
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxWaitTime sets the maximum wait time
|
||||
func WithMaxWaitTime(maxWait time.Duration) WaitOption {
|
||||
return func(c *waitConfig) {
|
||||
c.maxWaitTime = maxWait
|
||||
}
|
||||
}
|
||||
|
||||
// Internal HTTP request method
|
||||
func (c *FaultInjectorClient) request(ctx context.Context, method, path string, body interface{}, result interface{}) error {
|
||||
url := c.baseURL + path
|
||||
|
||||
var reqBody io.Reader
|
||||
if body != nil {
|
||||
jsonData, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
reqBody = bytes.NewReader(jsonData)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
if err := json.Unmarshal(respBody, result); err != nil {
|
||||
// This can happen when the API changes: the action status output is sometimes
// a map and sometimes plain JSON. Since there is no stable response structure,
// handle both forms here.
|
||||
if result, ok := result.(*ActionStatusResponse); ok {
|
||||
mapResult := map[string]interface{}{}
|
||||
err = json.Unmarshal(respBody, &mapResult)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to unmarshal response:", string(respBody))
|
||||
panic(err)
|
||||
}
|
||||
result.Error = mapResult["error"]
|
||||
result.Output = map[string]interface{}{"result": mapResult["output"]}
|
||||
if status, ok := mapResult["status"].(string); ok {
|
||||
result.Status = ActionStatus(status)
|
||||
}
|
||||
if result.Status == StatusSuccess || result.Status == StatusFailed || result.Status == StatusCancelled {
|
||||
result.EndTime = time.Now()
|
||||
}
|
||||
if progress, ok := mapResult["progress"].(float64); ok {
|
||||
result.Progress = progress
|
||||
}
|
||||
if actionID, ok := mapResult["action_id"].(string); ok {
|
||||
result.ActionID = actionID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
fmt.Println("Failed to unmarshal response:", string(respBody))
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Utility functions for common scenarios
|
||||
|
||||
// GetClusterNodes returns a list of cluster node IDs
|
||||
func GetClusterNodes() []string {
|
||||
// TODO Implement
|
||||
// This would typically be configured via environment or discovery
|
||||
return []string{"node-1", "node-2", "node-3", "node-4", "node-5", "node-6"}
|
||||
}
|
||||
|
||||
// GetMasterNodes returns a list of master node IDs
|
||||
func GetMasterNodes() []string {
|
||||
// TODO Implement
|
||||
return []string{"node-1", "node-2", "node-3"}
|
||||
}
|
||||
|
||||
// GetSlaveNodes returns a list of slave node IDs
|
||||
func GetSlaveNodes() []string {
|
||||
// TODO Implement
|
||||
return []string{"node-4", "node-5", "node-6"}
|
||||
}
|
||||
|
||||
// ParseNodeID extracts node ID from various formats
|
||||
func ParseNodeID(nodeAddr string) string {
|
||||
// Extract node ID from address like "redis-node-1:7001" -> "node-1"
|
||||
parts := strings.Split(nodeAddr, ":")
|
||||
if len(parts) > 0 {
|
||||
addr := parts[0]
|
||||
if strings.Contains(addr, "redis-") {
|
||||
return strings.TrimPrefix(addr, "redis-")
|
||||
}
|
||||
return addr
|
||||
}
|
||||
return nodeAddr
|
||||
}
|
||||
|
||||
// FormatSlotRange formats a slot range for Redis commands
|
||||
func FormatSlotRange(start, end int) string {
|
||||
if start == end {
|
||||
return strconv.Itoa(start)
|
||||
}
|
||||
return fmt.Sprintf("%d-%d", start, end)
|
||||
}
|
||||
434
maintnotifications/e2e/logcollector_test.go
Normal file
434
maintnotifications/e2e/logcollector_test.go
Normal file
@@ -0,0 +1,434 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
)
|
||||
|
||||
// logs is a slice of strings that provides additional functionality
|
||||
// for filtering and analysis
|
||||
type logs []string
|
||||
|
||||
func (l logs) Contains(searchString string) bool {
|
||||
for _, log := range l {
|
||||
if log == searchString {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (l logs) GetCount() int {
|
||||
return len(l)
|
||||
}
|
||||
|
||||
func (l logs) GetCountThatContain(searchString string) int {
|
||||
count := 0
|
||||
for _, log := range l {
|
||||
if strings.Contains(log, searchString) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (l logs) GetLogsFiltered(filter func(string) bool) []string {
|
||||
filteredLogs := make([]string, 0, len(l))
|
||||
for _, log := range l {
|
||||
if filter(log) {
|
||||
filteredLogs = append(filteredLogs, log)
|
||||
}
|
||||
}
|
||||
return filteredLogs
|
||||
}
|
||||
|
||||
func (l logs) GetTimedOutLogs() logs {
|
||||
return l.GetLogsFiltered(isTimeout)
|
||||
}
|
||||
|
||||
func (l logs) GetLogsPerConn(connID uint64) logs {
|
||||
return l.GetLogsFiltered(func(log string) bool {
|
||||
return strings.Contains(log, fmt.Sprintf("conn[%d]", connID))
|
||||
})
|
||||
}
|
||||
|
||||
func (l logs) GetAnalysis() *LogAnalisis {
|
||||
return NewLogAnalysis(l)
|
||||
}
|
||||
|
||||
// TestLogCollector is a simple logger that captures logs for analysis
|
||||
// It is thread safe and can be used to capture logs from multiple clients
|
||||
// It uses type logs to provide additional functionality like filtering
|
||||
// and analysis
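//
// Typical wiring in a test (illustrative sketch; compare main_test.go below):
//
//	collector := NewTestLogCollector()
//	redis.SetLogger(collector)
//	// ... run the scenario under test ...
//	analysis := collector.GetAnalysis()
//	analysis.Print(t)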
|
||||
type TestLogCollector struct {
|
||||
l logs
|
||||
doPrint bool
|
||||
matchFuncs []*MatchFunc
|
||||
matchFuncsMutex sync.Mutex
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) DontPrint() {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
tlc.doPrint = false
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) DoPrint() {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
tlc.l = make([]string, 0)
|
||||
tlc.doPrint = true
|
||||
}
|
||||
|
||||
// MatchFunc is a slice of functions that check the logs for a specific condition
|
||||
// use in WaitForLogMatchFunc
|
||||
type MatchFunc struct {
|
||||
completed atomic.Bool
|
||||
F func(lstring string) bool
|
||||
matches []string
|
||||
found chan struct{} // channel to notify when match is found, will be closed
|
||||
done func()
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) Printf(_ context.Context, format string, v ...interface{}) {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
lstr := fmt.Sprintf(format, v...)
|
||||
if len(tlc.matchFuncs) > 0 {
|
||||
go func(lstr string) {
|
||||
for _, matchFunc := range tlc.matchFuncs {
|
||||
if matchFunc.F(lstr) {
|
||||
matchFunc.matches = append(matchFunc.matches, lstr)
|
||||
matchFunc.done()
|
||||
return
|
||||
}
|
||||
}
|
||||
}(lstr)
|
||||
}
|
||||
if tlc.doPrint {
|
||||
fmt.Println(lstr)
|
||||
}
|
||||
tlc.l = append(tlc.l, fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) WaitForLogContaining(searchString string, timeout time.Duration) bool {
|
||||
timeoutCh := time.After(timeout)
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCh:
|
||||
return false
|
||||
case <-ticker.C:
|
||||
if tlc.Contains(searchString) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) MatchOrWaitForLogMatchFunc(mf func(string) bool, timeout time.Duration) (string, bool) {
|
||||
if logs := tlc.GetLogsFiltered(mf); len(logs) > 0 {
|
||||
return logs[0], true
|
||||
}
|
||||
return tlc.WaitForLogMatchFunc(mf, timeout)
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) WaitForLogMatchFunc(mf func(string) bool, timeout time.Duration) (string, bool) {
|
||||
matchFunc := &MatchFunc{
|
||||
completed: atomic.Bool{},
|
||||
F: mf,
|
||||
found: make(chan struct{}),
|
||||
matches: make([]string, 0),
|
||||
}
|
||||
matchFunc.done = func() {
|
||||
if !matchFunc.completed.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
close(matchFunc.found)
|
||||
tlc.matchFuncsMutex.Lock()
|
||||
defer tlc.matchFuncsMutex.Unlock()
|
||||
for i, mf := range tlc.matchFuncs {
|
||||
if mf == matchFunc {
|
||||
tlc.matchFuncs = append(tlc.matchFuncs[:i], tlc.matchFuncs[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tlc.matchFuncsMutex.Lock()
|
||||
tlc.matchFuncs = append(tlc.matchFuncs, matchFunc)
|
||||
tlc.matchFuncsMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-matchFunc.found:
|
||||
return matchFunc.matches[0], true
|
||||
case <-time.After(timeout):
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) GetLogs() logs {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
return tlc.l
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) DumpLogs() {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
fmt.Println("Dumping logs:")
|
||||
fmt.Println("===================================================")
|
||||
for _, log := range tlc.l {
|
||||
fmt.Println(log)
|
||||
}
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) ClearLogs() {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
tlc.l = make([]string, 0)
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) Contains(searchString string) bool {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
return tlc.l.Contains(searchString)
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) MatchContainsAll(searchStrings []string) []string {
|
||||
// match a log that contains all
|
||||
return tlc.GetLogsFiltered(func(log string) bool {
|
||||
for _, searchString := range searchStrings {
|
||||
if !strings.Contains(log, searchString) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) GetLogCount() int {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
return tlc.l.GetCount()
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) GetLogCountThatContain(searchString string) int {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
return tlc.l.GetCountThatContain(searchString)
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) GetLogsFiltered(filter func(string) bool) logs {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
return tlc.l.GetLogsFiltered(filter)
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) GetTimedOutLogs() []string {
|
||||
return tlc.GetLogsFiltered(isTimeout)
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) GetLogsPerConn(connID uint64) logs {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
return tlc.l.GetLogsPerConn(connID)
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) GetAnalysisForConn(connID uint64) *LogAnalisis {
|
||||
return NewLogAnalysis(tlc.GetLogsPerConn(connID))
|
||||
}
|
||||
|
||||
func NewTestLogCollector() *TestLogCollector {
|
||||
return &TestLogCollector{
|
||||
l: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) GetAnalysis() *LogAnalisis {
|
||||
return NewLogAnalysis(tlc.GetLogs())
|
||||
}
|
||||
|
||||
func (tlc *TestLogCollector) Clear() {
|
||||
tlc.mu.Lock()
|
||||
defer tlc.mu.Unlock()
|
||||
tlc.matchFuncs = make([]*MatchFunc, 0)
|
||||
tlc.l = make([]string, 0)
|
||||
}
|
||||
|
||||
// LogAnalisis provides analysis of logs captured by TestLogCollector
type LogAnalisis struct {
    logs                    []string
    TimeoutErrorsCount      int64
    RelaxedTimeoutCount     int64
    RelaxedPostHandoffCount int64
    UnrelaxedTimeoutCount   int64
    UnrelaxedAfterMoving    int64
    ConnectionCount         int64
    connLogs                map[uint64][]string
    connIds                 map[uint64]bool

    TotalNotifications int64
    MovingCount        int64
    MigratingCount     int64
    MigratedCount      int64
    FailingOverCount   int64
    FailedOverCount    int64
    UnexpectedCount    int64

    TotalHandoffCount             int64
    FailedHandoffCount            int64
    SucceededHandoffCount         int64
    TotalHandoffRetries           int64
    TotalHandoffToCurrentEndpoint int64
}

func NewLogAnalysis(logs []string) *LogAnalisis {
    la := &LogAnalisis{
        logs:     logs,
        connLogs: make(map[uint64][]string),
        connIds:  make(map[uint64]bool),
    }
    la.Analyze()
    return la
}

func (la *LogAnalisis) Analyze() {
    hasMoving := false
    for _, log := range la.logs {
        if isTimeout(log) {
            la.TimeoutErrorsCount++
        }
        if strings.Contains(log, "MOVING") {
            hasMoving = true
        }
        if strings.Contains(log, logs2.RelaxedTimeoutDueToNotificationMessage) {
            la.RelaxedTimeoutCount++
        }
        if strings.Contains(log, logs2.ApplyingRelaxedTimeoutDueToPostHandoffMessage) {
            la.RelaxedTimeoutCount++
            la.RelaxedPostHandoffCount++
        }
        if strings.Contains(log, logs2.UnrelaxedTimeoutMessage) {
            la.UnrelaxedTimeoutCount++
        }
        if strings.Contains(log, logs2.UnrelaxedTimeoutAfterDeadlineMessage) {
            if hasMoving {
                la.UnrelaxedAfterMoving++
            } else {
                fmt.Printf("Unrelaxed after deadline but no MOVING: %s\n", log)
            }
        }

        if strings.Contains(log, logs2.ProcessingNotificationMessage) {
            la.TotalNotifications++

            switch {
            case notificationType(log, "MOVING"):
                la.MovingCount++
            case notificationType(log, "MIGRATING"):
                la.MigratingCount++
            case notificationType(log, "MIGRATED"):
                la.MigratedCount++
            case notificationType(log, "FAILING_OVER"):
                la.FailingOverCount++
            case notificationType(log, "FAILED_OVER"):
                la.FailedOverCount++
            default:
                fmt.Printf("[ERROR] Unexpected notification: %s\n", log)
                la.UnexpectedCount++
            }
        }

        if strings.Contains(log, "conn[") {
            connID := extractConnID(log)
            if _, ok := la.connIds[connID]; !ok {
                la.connIds[connID] = true
                la.ConnectionCount++
            }
            la.connLogs[connID] = append(la.connLogs[connID], log)
        }

        if strings.Contains(log, logs2.SchedulingHandoffToCurrentEndpointMessage) {
            la.TotalHandoffToCurrentEndpoint++
        }

        if strings.Contains(log, logs2.HandoffSuccessMessage) {
            la.SucceededHandoffCount++
        }
        if strings.Contains(log, logs2.HandoffFailedMessage) {
            la.FailedHandoffCount++
        }
        if strings.Contains(log, logs2.HandoffStartedMessage) {
            la.TotalHandoffCount++
        }
        if strings.Contains(log, logs2.HandoffRetryAttemptMessage) {
            la.TotalHandoffRetries++
        }
    }
}

func (la *LogAnalisis) Print(t *testing.T) {
    t.Logf("Log Analysis results for %d logs and %d connections:", len(la.logs), len(la.connIds))
    t.Logf("Connection Count: %d", la.ConnectionCount)
    t.Logf("-------------")
    t.Logf("-Timeout Analysis-")
    t.Logf("-------------")
    t.Logf("Timeout Errors: %d", la.TimeoutErrorsCount)
    t.Logf("Relaxed Timeout Count: %d", la.RelaxedTimeoutCount)
    t.Logf(" - Relaxed Timeout After Post-Handoff: %d", la.RelaxedPostHandoffCount)
    t.Logf("Unrelaxed Timeout Count: %d", la.UnrelaxedTimeoutCount)
    t.Logf(" - Unrelaxed Timeout After Moving: %d", la.UnrelaxedAfterMoving)
    t.Logf("-------------")
    t.Logf("-Handoff Analysis-")
    t.Logf("-------------")
    t.Logf("Total Handoffs: %d", la.TotalHandoffCount)
    t.Logf(" - Succeeded: %d", la.SucceededHandoffCount)
    t.Logf(" - Failed: %d", la.FailedHandoffCount)
    t.Logf(" - Retries: %d", la.TotalHandoffRetries)
    t.Logf(" - Handoffs to current endpoint: %d", la.TotalHandoffToCurrentEndpoint)
    t.Logf("-------------")
    t.Logf("-Notification Analysis-")
    t.Logf("-------------")
    t.Logf("Total Notifications: %d", la.TotalNotifications)
    t.Logf(" - MOVING: %d", la.MovingCount)
    t.Logf(" - MIGRATING: %d", la.MigratingCount)
    t.Logf(" - MIGRATED: %d", la.MigratedCount)
    t.Logf(" - FAILING_OVER: %d", la.FailingOverCount)
    t.Logf(" - FAILED_OVER: %d", la.FailedOverCount)
    t.Logf(" - Unexpected: %d", la.UnexpectedCount)
    t.Logf("-------------")
    t.Logf("Log Analysis completed successfully")
}

func extractConnID(log string) uint64 {
    logParts := strings.Split(log, "conn[")
    if len(logParts) < 2 {
        return 0
    }
    connIDStr := strings.Split(logParts[1], "]")[0]
    connID, err := strconv.ParseUint(connIDStr, 10, 64)
    if err != nil {
        return 0
    }
    return connID
}

func notificationType(log string, nt string) bool {
    return strings.Contains(log, nt)
}
func connID(log string, connID uint64) bool {
    return strings.Contains(log, fmt.Sprintf("conn[%d]", connID))
}
func seqID(log string, seqID int64) bool {
    return strings.Contains(log, fmt.Sprintf("seqID[%d]", seqID))
}
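A minimal sketch (not part of the diff) of how the matcher helpers above are combined in the scenario tests later in this change: a predicate built from notificationType, connID and seqID is handed to the collector's MatchOrWaitForLogMatchFunc, which returns the first matching log line or times out. The wrapper name waitForProcessedNotification is illustrative only.

func waitForProcessedNotification(nt string, cID uint64, sID int64, timeout time.Duration) (string, bool) {
    // match only lines that mark a processed push notification of the given
    // type, on the given connection, with the given sequence ID
    return logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
        return strings.Contains(s, logs2.ProcessingNotificationMessage) &&
            notificationType(s, nt) && connID(s, cID) && seqID(s, sID)
    }, timeout)
}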
39  maintnotifications/e2e/main_test.go  Normal file
@@ -0,0 +1,39 @@
package e2e

import (
    "log"
    "os"
    "testing"

    "github.com/redis/go-redis/v9"
    "github.com/redis/go-redis/v9/logging"
)

// Global log collector
var logCollector *TestLogCollector

// Global fault injector client
var faultInjector *FaultInjectorClient

func TestMain(m *testing.M) {
    var err error
    if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
        log.Println("Skipping scenario tests, E2E_SCENARIO_TESTS is not set")
        return
    }

    faultInjector, err = CreateTestFaultInjector()
    if err != nil {
        panic("Failed to create fault injector: " + err.Error())
    }
    // use log collector to capture logs from redis clients
    logCollector = NewTestLogCollector()
    redis.SetLogger(logCollector)
    redis.SetLogLevel(logging.LogLevelDebug)

    logCollector.Clear()
    defer logCollector.Clear()
    log.Println("Running scenario tests...")
    status := m.Run()
    os.Exit(status)
}
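A minimal sketch of the skeleton every scenario test below follows, assuming it lives in package e2e next to the globals initialized in TestMain; the test name is illustrative and the context/os/testing/time imports are omitted for brevity.

func TestScenarioSkeleton(t *testing.T) {
    if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
        t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
    }
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
    defer cancel()

    // logCollector and faultInjector are the globals set up in TestMain;
    // real tests clear the collector, trigger fault-injector actions and
    // then assert on the collected logs.
    logCollector.Clear()
    defer logCollector.Clear()
    _ = ctx
    _ = faultInjector
}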
435  maintnotifications/e2e/notiftracker_test.go  Normal file
@@ -0,0 +1,435 @@
package e2e

import (
    "context"
    "fmt"
    "sync"
    "sync/atomic"
    "testing"
    "time"

    "github.com/redis/go-redis/v9"
    "github.com/redis/go-redis/v9/internal/pool"
    "github.com/redis/go-redis/v9/maintnotifications"
    "github.com/redis/go-redis/v9/push"
)

// DiagnosticsEvent represents a notification event
// it may be a push notification or an error when processing
// push notifications
type DiagnosticsEvent struct {
    // is this pre or post hook
    Type   string `json:"type"`
    ConnID uint64 `json:"connID"`
    SeqID  int64  `json:"seqID"`

    Error error `json:"error"`

    Pre       bool                   `json:"pre"`
    Timestamp time.Time              `json:"timestamp"`
    Details   map[string]interface{} `json:"details"`
}

// TrackingNotificationsHook is a notification hook that tracks notifications
type TrackingNotificationsHook struct {
    // unique connection count
    connectionCount atomic.Int64

    // timeouts
    relaxedTimeoutCount   atomic.Int64
    unrelaxedTimeoutCount atomic.Int64

    notificationProcessingErrors atomic.Int64

    // notification types
    totalNotifications          atomic.Int64
    migratingCount              atomic.Int64
    migratedCount               atomic.Int64
    failingOverCount            atomic.Int64
    failedOverCount             atomic.Int64
    movingCount                 atomic.Int64
    unexpectedNotificationCount atomic.Int64

    diagnosticsLog []DiagnosticsEvent
    connIds        map[uint64]bool
    connLogs       map[uint64][]DiagnosticsEvent
    mutex          sync.RWMutex
}

// NewTrackingNotificationsHook creates a new notification hook with counters
func NewTrackingNotificationsHook() *TrackingNotificationsHook {
    return &TrackingNotificationsHook{
        diagnosticsLog: make([]DiagnosticsEvent, 0),
        connIds:        make(map[uint64]bool),
        connLogs:       make(map[uint64][]DiagnosticsEvent),
    }
}

// it is not reusable, but just to keep it consistent
// with the log collector
func (tnh *TrackingNotificationsHook) Clear() {
    tnh.mutex.Lock()
    defer tnh.mutex.Unlock()
    tnh.diagnosticsLog = make([]DiagnosticsEvent, 0)
    tnh.connIds = make(map[uint64]bool)
    tnh.connLogs = make(map[uint64][]DiagnosticsEvent)
    tnh.relaxedTimeoutCount.Store(0)
    tnh.unrelaxedTimeoutCount.Store(0)
    tnh.notificationProcessingErrors.Store(0)
    tnh.totalNotifications.Store(0)
    tnh.migratingCount.Store(0)
    tnh.migratedCount.Store(0)
    tnh.failingOverCount.Store(0)
}

// wait for notification in prehook
func (tnh *TrackingNotificationsHook) FindOrWaitForNotification(notificationType string, timeout time.Duration) (notification []interface{}, found bool) {
    if notification, found := tnh.FindNotification(notificationType); found {
        return notification, true
    }

    // wait for notification
    timeoutCh := time.After(timeout)
    ticker := time.NewTicker(100 * time.Millisecond)
    defer ticker.Stop()
    for {
        select {
        case <-timeoutCh:
            return nil, false
        case <-ticker.C:
            if notification, found := tnh.FindNotification(notificationType); found {
                return notification, true
            }
        }
    }
}

func (tnh *TrackingNotificationsHook) FindNotification(notificationType string) (notification []interface{}, found bool) {
    tnh.mutex.RLock()
    defer tnh.mutex.RUnlock()
    for _, event := range tnh.diagnosticsLog {
        if event.Type == notificationType {
            return event.Details["notification"].([]interface{}), true
        }
    }
    return nil, false
}

// PreHook captures timeout-related events before processing
func (tnh *TrackingNotificationsHook) PreHook(_ context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
    tnh.increaseNotificationCount(notificationType)
    tnh.storeDiagnosticsEvent(notificationType, notification, notificationCtx)
    tnh.increaseRelaxedTimeoutCount(notificationType)
    return notification, true
}

func (tnh *TrackingNotificationsHook) getConnID(notificationCtx push.NotificationHandlerContext) uint64 {
    if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
        return conn.GetID()
    }
    return 0
}

func (tnh *TrackingNotificationsHook) getSeqID(notification []interface{}) int64 {
    if len(notification) < 2 {
        return 0
    }
    seqID, ok := notification[1].(int64)
    if !ok {
        return 0
    }
    return seqID
}

// PostHook captures the result after processing push notification
func (tnh *TrackingNotificationsHook) PostHook(_ context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, err error) {
    if err != nil {
        event := DiagnosticsEvent{
            Type:      notificationType + "_ERROR",
            ConnID:    tnh.getConnID(notificationCtx),
            SeqID:     tnh.getSeqID(notification),
            Error:     err,
            Timestamp: time.Now(),
            Details: map[string]interface{}{
                "notification": notification,
                "context":      "post-hook",
            },
        }

        tnh.notificationProcessingErrors.Add(1)
        tnh.mutex.Lock()
        tnh.diagnosticsLog = append(tnh.diagnosticsLog, event)
        tnh.mutex.Unlock()
    }
}

func (tnh *TrackingNotificationsHook) storeDiagnosticsEvent(notificationType string, notification []interface{}, notificationCtx push.NotificationHandlerContext) {
    connID := tnh.getConnID(notificationCtx)
    event := DiagnosticsEvent{
        Type:      notificationType,
        ConnID:    connID,
        SeqID:     tnh.getSeqID(notification),
        Pre:       true,
        Timestamp: time.Now(),
        Details: map[string]interface{}{
            "notification": notification,
            "context":      "pre-hook",
        },
    }

    tnh.mutex.Lock()
    if v, ok := tnh.connIds[connID]; !ok || !v {
        tnh.connIds[connID] = true
        tnh.connectionCount.Add(1)
    }
    tnh.connLogs[connID] = append(tnh.connLogs[connID], event)
    tnh.diagnosticsLog = append(tnh.diagnosticsLog, event)
    tnh.mutex.Unlock()
}

// GetRelaxedTimeoutCount returns the count of relaxed timeout events
func (tnh *TrackingNotificationsHook) GetRelaxedTimeoutCount() int64 {
    return tnh.relaxedTimeoutCount.Load()
}

// GetUnrelaxedTimeoutCount returns the count of unrelaxed timeout events
func (tnh *TrackingNotificationsHook) GetUnrelaxedTimeoutCount() int64 {
    return tnh.unrelaxedTimeoutCount.Load()
}

// GetNotificationProcessingErrors returns the count of notification processing errors
func (tnh *TrackingNotificationsHook) GetNotificationProcessingErrors() int64 {
    return tnh.notificationProcessingErrors.Load()
}

// GetTotalNotifications returns the total number of notifications processed
func (tnh *TrackingNotificationsHook) GetTotalNotifications() int64 {
    return tnh.totalNotifications.Load()
}

// GetConnectionCount returns the current connection count
func (tnh *TrackingNotificationsHook) GetConnectionCount() int64 {
    return tnh.connectionCount.Load()
}

// GetMovingCount returns the count of MOVING notifications
func (tnh *TrackingNotificationsHook) GetMovingCount() int64 {
    return tnh.movingCount.Load()
}

// GetDiagnosticsLog returns a copy of the diagnostics log
func (tnh *TrackingNotificationsHook) GetDiagnosticsLog() []DiagnosticsEvent {
    tnh.mutex.RLock()
    defer tnh.mutex.RUnlock()

    logCopy := make([]DiagnosticsEvent, len(tnh.diagnosticsLog))
    copy(logCopy, tnh.diagnosticsLog)
    return logCopy
}

func (tnh *TrackingNotificationsHook) increaseNotificationCount(notificationType string) {
    tnh.totalNotifications.Add(1)
    switch notificationType {
    case "MOVING":
        tnh.movingCount.Add(1)
    case "MIGRATING":
        tnh.migratingCount.Add(1)
    case "MIGRATED":
        tnh.migratedCount.Add(1)
    case "FAILING_OVER":
        tnh.failingOverCount.Add(1)
    case "FAILED_OVER":
        tnh.failedOverCount.Add(1)
    default:
        tnh.unexpectedNotificationCount.Add(1)
    }
}

func (tnh *TrackingNotificationsHook) increaseRelaxedTimeoutCount(notificationType string) {
    switch notificationType {
    case "MIGRATING", "FAILING_OVER":
        tnh.relaxedTimeoutCount.Add(1)
    case "MIGRATED", "FAILED_OVER":
        tnh.unrelaxedTimeoutCount.Add(1)
    }
}

// setupNotificationHook sets up tracking for both regular and cluster clients with notification hooks
func setupNotificationHook(client redis.UniversalClient, hook maintnotifications.NotificationHook) {
    if clusterClient, ok := client.(*redis.ClusterClient); ok {
        setupClusterClientNotificationHook(clusterClient, hook)
    } else if regularClient, ok := client.(*redis.Client); ok {
        setupRegularClientNotificationHook(regularClient, hook)
    }
}

// setupNotificationHooks sets up tracking for both regular and cluster clients with notification hooks
func setupNotificationHooks(client redis.UniversalClient, hooks ...maintnotifications.NotificationHook) {
    for _, hook := range hooks {
        setupNotificationHook(client, hook)
    }
}

// setupRegularClientNotificationHook sets up notification hook for regular clients
func setupRegularClientNotificationHook(client *redis.Client, hook maintnotifications.NotificationHook) {
    maintnotificationsManager := client.GetMaintNotificationsManager()
    if maintnotificationsManager != nil {
        maintnotificationsManager.AddNotificationHook(hook)
    } else {
        fmt.Printf("[TNH] Warning: Maintenance notifications manager not available for tracking\n")
    }
}

// setupClusterClientNotificationHook sets up notification hook for cluster clients
func setupClusterClientNotificationHook(client *redis.ClusterClient, hook maintnotifications.NotificationHook) {
    ctx := context.Background()

    // Register hook on existing nodes
    err := client.ForEachShard(ctx, func(ctx context.Context, nodeClient *redis.Client) error {
        maintnotificationsManager := nodeClient.GetMaintNotificationsManager()
        if maintnotificationsManager != nil {
            maintnotificationsManager.AddNotificationHook(hook)
        } else {
            fmt.Printf("[TNH] Warning: Maintenance notifications manager not available for tracking on node: %s\n", nodeClient.Options().Addr)
        }
        return nil
    })

    if err != nil {
        fmt.Printf("[TNH] Warning: Failed to register timeout tracking hooks on existing cluster nodes: %v\n", err)
    }

    // Register hook on new nodes
    client.OnNewNode(func(nodeClient *redis.Client) {
        maintnotificationsManager := nodeClient.GetMaintNotificationsManager()
        if maintnotificationsManager != nil {
            maintnotificationsManager.AddNotificationHook(hook)
        } else {
            fmt.Printf("[TNH] Warning: Maintenance notifications manager not available for tracking on new node: %s\n", nodeClient.Options().Addr)
        }
    })
}

// filterPushNotificationLogs filters the diagnostics log for push notification events
func filterPushNotificationLogs(diagnosticsLog []DiagnosticsEvent) []DiagnosticsEvent {
    var pushNotificationLogs []DiagnosticsEvent

    for _, log := range diagnosticsLog {
        switch log.Type {
        case "MOVING", "MIGRATING", "MIGRATED":
            pushNotificationLogs = append(pushNotificationLogs, log)
        }
    }

    return pushNotificationLogs
}

func (tnh *TrackingNotificationsHook) GetAnalysis() *DiagnosticsAnalysis {
    return NewDiagnosticsAnalysis(tnh.GetDiagnosticsLog())
}

func (tnh *TrackingNotificationsHook) GetDiagnosticsLogForConn(connID uint64) []DiagnosticsEvent {
    tnh.mutex.RLock()
    defer tnh.mutex.RUnlock()

    var connLogs []DiagnosticsEvent
    for _, log := range tnh.diagnosticsLog {
        if log.ConnID == connID {
            connLogs = append(connLogs, log)
        }
    }
    return connLogs
}

func (tnh *TrackingNotificationsHook) GetAnalysisForConn(connID uint64) *DiagnosticsAnalysis {
    return NewDiagnosticsAnalysis(tnh.GetDiagnosticsLogForConn(connID))
}

type DiagnosticsAnalysis struct {
    RelaxedTimeoutCount          int64
    UnrelaxedTimeoutCount        int64
    NotificationProcessingErrors int64
    ConnectionCount              int64
    MovingCount                  int64
    MigratingCount               int64
    MigratedCount                int64
    FailingOverCount             int64
    FailedOverCount              int64
    UnexpectedNotificationCount  int64
    TotalNotifications           int64
    diagnosticsLog               []DiagnosticsEvent
    connLogs                     map[uint64][]DiagnosticsEvent
    connIds                      map[uint64]bool
}

func NewDiagnosticsAnalysis(diagnosticsLog []DiagnosticsEvent) *DiagnosticsAnalysis {
    da := &DiagnosticsAnalysis{
        diagnosticsLog: diagnosticsLog,
        connLogs:       make(map[uint64][]DiagnosticsEvent),
        connIds:        make(map[uint64]bool),
    }

    da.Analyze()
    return da
}

func (da *DiagnosticsAnalysis) Analyze() {
    for _, log := range da.diagnosticsLog {
        da.TotalNotifications++
        switch log.Type {
        case "MOVING":
            da.MovingCount++
        case "MIGRATING":
            da.MigratingCount++
        case "MIGRATED":
            da.MigratedCount++
        case "FAILING_OVER":
            da.FailingOverCount++
        case "FAILED_OVER":
            da.FailedOverCount++
        default:
            da.UnexpectedNotificationCount++
        }
        if log.Error != nil {
            fmt.Printf("[ERROR] Notification processing error: %v\n", log.Error)
            fmt.Printf("[ERROR] Notification: %v\n", log.Details["notification"])
            fmt.Printf("[ERROR] Context: %v\n", log.Details["context"])
            da.NotificationProcessingErrors++
        }
        if log.Type == "MIGRATING" || log.Type == "FAILING_OVER" {
            da.RelaxedTimeoutCount++
        } else if log.Type == "MIGRATED" || log.Type == "FAILED_OVER" {
            da.UnrelaxedTimeoutCount++
        }
        if log.ConnID != 0 {
            if v, ok := da.connIds[log.ConnID]; !ok || !v {
                da.connIds[log.ConnID] = true
                da.connLogs[log.ConnID] = make([]DiagnosticsEvent, 0)
                da.ConnectionCount++
            }
            da.connLogs[log.ConnID] = append(da.connLogs[log.ConnID], log)
        }
    }
}

func (a *DiagnosticsAnalysis) Print(t *testing.T) {
    t.Logf("Notification Analysis results for %d events and %d connections:", len(a.diagnosticsLog), len(a.connIds))
    t.Logf("-------------")
    t.Logf("-Timeout Analysis based on type of notification-")
    t.Logf("Note: MIGRATED and FAILED_OVER notifications are not tracked by the hook, so they are not included in the relaxed/unrelaxed count")
    t.Logf("Note: The hook only tracks timeouts that occur after the notification is processed, so timeouts that occur during processing are not included")
    t.Logf("-------------")
    t.Logf(" - Relaxed Timeout Count: %d", a.RelaxedTimeoutCount)
    t.Logf(" - Unrelaxed Timeout Count: %d", a.UnrelaxedTimeoutCount)
    t.Logf("-------------")
    t.Logf("-Notification Analysis-")
    t.Logf("-------------")
    t.Logf(" - MOVING: %d", a.MovingCount)
    t.Logf(" - MIGRATING: %d", a.MigratingCount)
    t.Logf(" - MIGRATED: %d", a.MigratedCount)
    t.Logf(" - FAILING_OVER: %d", a.FailingOverCount)
    t.Logf(" - FAILED_OVER: %d", a.FailedOverCount)
    t.Logf(" - Unexpected: %d", a.UnexpectedNotificationCount)
    t.Logf("-------------")
    t.Logf(" - Total Notifications: %d", a.TotalNotifications)
    t.Logf(" - Notification Processing Errors: %d", a.NotificationProcessingErrors)
    t.Logf(" - Connection Count: %d", a.ConnectionCount)
    t.Logf("-------------")
    t.Logf("Diagnostics Analysis completed successfully")
}
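A short usage sketch, assuming package e2e and a client plus *testing.T already in scope (as in the scenario tests below): register the tracking hook alongside the logging hook, run traffic while the fault injector acts, then read the aggregated counters.

    tracker := NewTrackingNotificationsHook()
    logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
    setupNotificationHooks(client, tracker, logger)

    // ... drive commands while a failover/migration is triggered ...

    analysis := tracker.GetAnalysis()
    if analysis.NotificationProcessingErrors > 0 {
        t.Errorf("notification processing errors: %d", analysis.NotificationProcessingErrors)
    }
    analysis.Print(t)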
374  maintnotifications/e2e/scenario_endpoint_types_test.go  Normal file
@@ -0,0 +1,374 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
)
|
||||
|
||||
// TestEndpointTypesPushNotifications tests push notifications with different endpoint types
|
||||
func TestEndpointTypesPushNotifications(t *testing.T) {
|
||||
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
|
||||
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var dump = true
|
||||
var errorsDetected = false
|
||||
|
||||
// Test different endpoint types
|
||||
endpointTypes := []struct {
|
||||
name string
|
||||
endpointType maintnotifications.EndpointType
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ExternalIP",
|
||||
endpointType: maintnotifications.EndpointTypeExternalIP,
|
||||
description: "External IP endpoint type for enterprise clusters",
|
||||
},
|
||||
{
|
||||
name: "ExternalFQDN",
|
||||
endpointType: maintnotifications.EndpointTypeExternalFQDN,
|
||||
description: "External FQDN endpoint type for DNS-based routing",
|
||||
},
|
||||
{
|
||||
name: "None",
|
||||
endpointType: maintnotifications.EndpointTypeNone,
|
||||
description: "No endpoint type - reconnect with current config",
|
||||
},
|
||||
}
|
||||
|
||||
defer func() {
|
||||
logCollector.Clear()
|
||||
}()
|
||||
|
||||
// Test each endpoint type with its own fresh database
|
||||
for _, endpointTest := range endpointTypes {
|
||||
t.Run(endpointTest.name, func(t *testing.T) {
|
||||
// Setup: Create fresh database and client factory for THIS endpoint type test
|
||||
bdbID, factory, cleanup := SetupTestDatabaseAndFactory(t, ctx, "standalone")
|
||||
defer cleanup()
|
||||
t.Logf("[ENDPOINT-TYPES-%s] Created test database with bdb_id: %d", endpointTest.name, bdbID)
|
||||
|
||||
// Create fault injector
|
||||
faultInjector, err := CreateTestFaultInjector()
|
||||
if err != nil {
|
||||
t.Fatalf("[ERROR] Failed to create fault injector: %v", err)
|
||||
}
|
||||
|
||||
// Get endpoint config from factory (now connected to new database)
|
||||
endpointConfig := factory.GetConfig()
|
||||
|
||||
defer func() {
|
||||
if dump {
|
||||
fmt.Println("Pool stats:")
|
||||
factory.PrintPoolStats(t)
|
||||
}
|
||||
}()
|
||||
// Clear logs between endpoint type tests
|
||||
logCollector.Clear()
|
||||
// reset errors detected flag
|
||||
errorsDetected = false
|
||||
// reset dump flag
|
||||
dump = true
|
||||
// redefine p and e for each test to get
|
||||
// proper test name in logs and proper test failures
|
||||
var p = func(format string, args ...interface{}) {
|
||||
printLog("ENDPOINT-TYPES", false, format, args...)
|
||||
}
|
||||
|
||||
var e = func(format string, args ...interface{}) {
|
||||
errorsDetected = true
|
||||
printLog("ENDPOINT-TYPES", true, format, args...)
|
||||
}
|
||||
|
||||
var ef = func(format string, args ...interface{}) {
|
||||
printLog("ENDPOINT-TYPES", true, format, args...)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
p("Testing endpoint type: %s - %s", endpointTest.name, endpointTest.description)
|
||||
|
||||
minIdleConns := 3
|
||||
poolSize := 8
|
||||
maxConnections := 12
|
||||
|
||||
// Create Redis client with specific endpoint type
|
||||
client, err := factory.Create(fmt.Sprintf("endpoint-test-%s", endpointTest.name), &CreateClientOptions{
|
||||
Protocol: 3, // RESP3 required for push notifications
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxActiveConns: maxConnections,
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeEnabled,
|
||||
HandoffTimeout: 30 * time.Second,
|
||||
RelaxedTimeout: 8 * time.Second,
|
||||
PostHandoffRelaxedDuration: 2 * time.Second,
|
||||
MaxWorkers: 15,
|
||||
EndpointType: endpointTest.endpointType, // Test specific endpoint type
|
||||
},
|
||||
ClientName: fmt.Sprintf("endpoint-test-%s", endpointTest.name),
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to create client for %s: %v", endpointTest.name, err)
|
||||
}
|
||||
|
||||
// Create timeout tracker
|
||||
tracker := NewTrackingNotificationsHook()
|
||||
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
|
||||
setupNotificationHooks(client, tracker, logger)
|
||||
defer func() {
|
||||
tracker.Clear()
|
||||
}()
|
||||
|
||||
// Verify initial connectivity
|
||||
err = client.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
ef("Failed to ping Redis with %s endpoint type: %v", endpointTest.name, err)
|
||||
}
|
||||
|
||||
p("Client connected successfully with %s endpoint type", endpointTest.name)
|
||||
|
||||
commandsRunner, _ := NewCommandRunner(client)
|
||||
defer func() {
|
||||
if dump {
|
||||
stats := commandsRunner.GetStats()
|
||||
p("%s endpoint stats: Operations: %d, Errors: %d, Timeout Errors: %d",
|
||||
endpointTest.name, stats.Operations, stats.Errors, stats.TimeoutErrors)
|
||||
}
|
||||
commandsRunner.Stop()
|
||||
}()
|
||||
|
||||
// Test failover with this endpoint type
|
||||
p("Testing failover with %s endpoint type on database [bdb_id:%s]...", endpointTest.name, endpointConfig.BdbID)
|
||||
failoverResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "failover",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to trigger failover action for %s: %v", endpointTest.name, err)
|
||||
}
|
||||
|
||||
// Start command traffic
|
||||
go func() {
|
||||
commandsRunner.FireCommandsUntilStop(ctx)
|
||||
}()
|
||||
|
||||
// Wait for failover to complete
|
||||
status, err := faultInjector.WaitForAction(ctx, failoverResp.ActionID,
|
||||
WithMaxWaitTime(240*time.Second),
|
||||
WithPollInterval(2*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
ef("[FI] Failover action failed for %s: %v", endpointTest.name, err)
|
||||
}
|
||||
p("[FI] Failover action completed for %s: %s %s", endpointTest.name, status.Status, actionOutputIfFailed(status))
|
||||
|
||||
// Wait for FAILING_OVER notification
|
||||
match, found := logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "FAILING_OVER")
|
||||
}, 3*time.Minute)
|
||||
if !found {
|
||||
ef("FAILING_OVER notification was not received for %s endpoint type", endpointTest.name)
|
||||
}
|
||||
failingOverData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("FAILING_OVER notification received for %s. %v", endpointTest.name, failingOverData)
|
||||
|
||||
// Wait for FAILED_OVER notification
|
||||
seqIDToObserve := int64(failingOverData["seqID"].(float64))
|
||||
connIDToObserve := uint64(failingOverData["connID"].(float64))
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return notificationType(s, "FAILED_OVER") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
|
||||
}, 3*time.Minute)
|
||||
if !found {
|
||||
ef("FAILED_OVER notification was not received for %s endpoint type", endpointTest.name)
|
||||
}
|
||||
failedOverData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("FAILED_OVER notification received for %s. %v", endpointTest.name, failedOverData)
|
||||
|
||||
// Test migration with this endpoint type
|
||||
p("Testing migration with %s endpoint type...", endpointTest.name)
|
||||
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "migrate",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to trigger migrate action for %s: %v", endpointTest.name, err)
|
||||
}
|
||||
|
||||
// Wait for migration to complete
|
||||
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
|
||||
WithMaxWaitTime(240*time.Second),
|
||||
WithPollInterval(2*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
ef("[FI] Migrate action failed for %s: %v", endpointTest.name, err)
|
||||
}
|
||||
p("[FI] Migrate action completed for %s: %s %s", endpointTest.name, status.Status, actionOutputIfFailed(status))
|
||||
|
||||
// Wait for MIGRATING notification
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
|
||||
}, 60*time.Second)
|
||||
if !found {
|
||||
ef("MIGRATING notification was not received for %s endpoint type", endpointTest.name)
|
||||
}
|
||||
migrateData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("MIGRATING notification received for %s: %v", endpointTest.name, migrateData)
|
||||
|
||||
// Wait for MIGRATED notification
|
||||
seqIDToObserve = int64(migrateData["seqID"].(float64))
|
||||
connIDToObserve = uint64(migrateData["connID"].(float64))
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return notificationType(s, "MIGRATED") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
|
||||
}, 3*time.Minute)
|
||||
if !found {
|
||||
ef("MIGRATED notification was not received for %s endpoint type", endpointTest.name)
|
||||
}
|
||||
migratedData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("MIGRATED notification received for %s. %v", endpointTest.name, migratedData)
|
||||
|
||||
// Complete migration with bind action
|
||||
bindResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "bind",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to trigger bind action for %s: %v", endpointTest.name, err)
|
||||
}
|
||||
|
||||
// Wait for MOVING notification
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "MOVING")
|
||||
}, 3*time.Minute)
|
||||
if !found {
|
||||
ef("MOVING notification was not received for %s endpoint type", endpointTest.name)
|
||||
}
|
||||
movingData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("MOVING notification received for %s. %v", endpointTest.name, movingData)
|
||||
|
||||
notification, ok := movingData["notification"].(string)
|
||||
if !ok {
|
||||
e("invalid notification message")
|
||||
}
|
||||
|
||||
notification = notification[:len(notification)-1]
|
||||
notificationParts := strings.Split(notification, " ")
|
||||
address := notificationParts[len(notificationParts)-1]
|
||||
|
||||
switch endpointTest.endpointType {
|
||||
case maintnotifications.EndpointTypeExternalFQDN:
|
||||
address = strings.Split(address, ":")[0]
|
||||
addressParts := strings.SplitN(address, ".", 2)
|
||||
if len(addressParts) != 2 {
|
||||
e("invalid address %s", address)
|
||||
} else {
|
||||
address = addressParts[1]
|
||||
}
|
||||
|
||||
var expectedAddress string
|
||||
hostParts := strings.SplitN(endpointConfig.Host, ".", 2)
|
||||
if len(hostParts) != 2 {
|
||||
e("invalid host %s", endpointConfig.Host)
|
||||
} else {
|
||||
expectedAddress = hostParts[1]
|
||||
}
|
||||
if address != expectedAddress {
|
||||
e("invalid fqdn, expected: %s, got: %s", expectedAddress, address)
|
||||
}
|
||||
|
||||
case maintnotifications.EndpointTypeExternalIP:
|
||||
address = strings.Split(address, ":")[0]
|
||||
ip := net.ParseIP(address)
|
||||
if ip == nil {
|
||||
e("invalid message format, expected valid IP, got: %s", address)
|
||||
}
|
||||
case maintnotifications.EndpointTypeNone:
|
||||
if address != internal.RedisNull {
|
||||
e("invalid endpoint type, expected: %s, got: %s", internal.RedisNull, address)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for bind to complete
|
||||
bindStatus, err := faultInjector.WaitForAction(ctx, bindResp.ActionID,
|
||||
WithMaxWaitTime(240*time.Second),
|
||||
WithPollInterval(2*time.Second))
|
||||
if err != nil {
|
||||
ef("Bind action failed for %s: %v", endpointTest.name, err)
|
||||
}
|
||||
p("Bind action completed for %s: %s %s", endpointTest.name, bindStatus.Status, actionOutputIfFailed(bindStatus))
|
||||
|
||||
// Continue traffic for analysis
|
||||
time.Sleep(30 * time.Second)
|
||||
commandsRunner.Stop()
|
||||
|
||||
// Analyze results for this endpoint type
|
||||
trackerAnalysis := tracker.GetAnalysis()
|
||||
if trackerAnalysis.NotificationProcessingErrors > 0 {
|
||||
e("Notification processing errors with %s endpoint type: %d", endpointTest.name, trackerAnalysis.NotificationProcessingErrors)
|
||||
}
|
||||
|
||||
if trackerAnalysis.UnexpectedNotificationCount > 0 {
|
||||
e("Unexpected notifications with %s endpoint type: %d", endpointTest.name, trackerAnalysis.UnexpectedNotificationCount)
|
||||
}
|
||||
|
||||
// Validate we received all expected notification types
|
||||
if trackerAnalysis.FailingOverCount == 0 {
|
||||
e("Expected FAILING_OVER notifications with %s endpoint type, got none", endpointTest.name)
|
||||
}
|
||||
if trackerAnalysis.FailedOverCount == 0 {
|
||||
e("Expected FAILED_OVER notifications with %s endpoint type, got none", endpointTest.name)
|
||||
}
|
||||
if trackerAnalysis.MigratingCount == 0 {
|
||||
e("Expected MIGRATING notifications with %s endpoint type, got none", endpointTest.name)
|
||||
}
|
||||
if trackerAnalysis.MigratedCount == 0 {
|
||||
e("Expected MIGRATED notifications with %s endpoint type, got none", endpointTest.name)
|
||||
}
|
||||
if trackerAnalysis.MovingCount == 0 {
|
||||
e("Expected MOVING notifications with %s endpoint type, got none", endpointTest.name)
|
||||
}
|
||||
|
||||
logAnalysis := logCollector.GetAnalysis()
|
||||
if logAnalysis.TotalHandoffCount == 0 {
|
||||
e("Expected at least one handoff with %s endpoint type, got none", endpointTest.name)
|
||||
}
|
||||
if logAnalysis.TotalHandoffCount != logAnalysis.SucceededHandoffCount {
|
||||
e("Expected all handoffs to succeed with %s endpoint type, got %d failed", endpointTest.name, logAnalysis.FailedHandoffCount)
|
||||
}
|
||||
|
||||
if errorsDetected {
|
||||
logCollector.DumpLogs()
|
||||
trackerAnalysis.Print(t)
|
||||
logCollector.Clear()
|
||||
tracker.Clear()
|
||||
ef("[FAIL] Errors detected with %s endpoint type", endpointTest.name)
|
||||
}
|
||||
p("Endpoint type %s test completed successfully", endpointTest.name)
|
||||
logCollector.GetAnalysis().Print(t)
|
||||
trackerAnalysis.Print(t)
|
||||
logCollector.Clear()
|
||||
tracker.Clear()
|
||||
})
|
||||
}
|
||||
|
||||
t.Log("All endpoint types tested successfully")
|
||||
}
|
||||
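The check above extracts the new endpoint from the formatted MOVING log line. A sketch of the same check done against the raw MOVING payload instead; the element layout (type, seqID, timeS, endpoint) is inferred from the assertions in these tests and should be treated as an assumption, and the helper name is illustrative.

func movingEndpoint(notification []interface{}) (string, bool) {
    // expected shape: ["MOVING", seqID, timeS, endpoint]
    if len(notification) != 4 {
        return "", false
    }
    endpoint, ok := notification[3].(string)
    return endpoint, ok
}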
513  maintnotifications/e2e/scenario_push_notifications_test.go  Normal file
@@ -0,0 +1,513 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
)
|
||||
|
||||
// TestPushNotifications tests Redis Enterprise push notifications (MOVING, MIGRATING, MIGRATED)
|
||||
func TestPushNotifications(t *testing.T) {
|
||||
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
|
||||
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Setup: Create fresh database and client factory for this test
|
||||
bdbID, factory, cleanup := SetupTestDatabaseAndFactory(t, ctx, "standalone")
|
||||
defer cleanup()
|
||||
t.Logf("[PUSH-NOTIFICATIONS] Created test database with bdb_id: %d", bdbID)
|
||||
|
||||
// Wait for database to be fully ready
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
var dump = true
|
||||
var seqIDToObserve int64
|
||||
var connIDToObserve uint64
|
||||
|
||||
var match string
|
||||
var found bool
|
||||
|
||||
var status *ActionStatusResponse
|
||||
var errorsDetected = false
|
||||
|
||||
var p = func(format string, args ...interface{}) {
|
||||
printLog("PUSH-NOTIFICATIONS", false, format, args...)
|
||||
}
|
||||
|
||||
var e = func(format string, args ...interface{}) {
|
||||
errorsDetected = true
|
||||
printLog("PUSH-NOTIFICATIONS", true, format, args...)
|
||||
}
|
||||
|
||||
var ef = func(format string, args ...interface{}) {
|
||||
printLog("PUSH-NOTIFICATIONS", true, format, args...)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
logCollector.ClearLogs()
|
||||
defer func() {
|
||||
logCollector.Clear()
|
||||
}()
|
||||
|
||||
// Get endpoint config from factory (now connected to new database)
|
||||
endpointConfig := factory.GetConfig()
|
||||
|
||||
// Create fault injector
|
||||
faultInjector, err := CreateTestFaultInjector()
|
||||
if err != nil {
|
||||
ef("Failed to create fault injector: %v", err)
|
||||
}
|
||||
|
||||
minIdleConns := 5
|
||||
poolSize := 10
|
||||
maxConnections := 15
|
||||
// Create Redis client with push notifications enabled
|
||||
client, err := factory.Create("push-notification-client", &CreateClientOptions{
|
||||
Protocol: 3, // RESP3 required for push notifications
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxActiveConns: maxConnections,
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeEnabled,
|
||||
HandoffTimeout: 40 * time.Second, // 40 seconds
|
||||
RelaxedTimeout: 10 * time.Second, // 10 seconds relaxed timeout
|
||||
PostHandoffRelaxedDuration: 2 * time.Second, // 2 seconds post-handoff relaxed duration
|
||||
MaxWorkers: 20,
|
||||
EndpointType: maintnotifications.EndpointTypeExternalIP, // Use external IP for enterprise
|
||||
},
|
||||
ClientName: "push-notification-test-client",
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
factory.DestroyAll()
|
||||
}()
|
||||
|
||||
// Create timeout tracker
|
||||
tracker := NewTrackingNotificationsHook()
|
||||
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
|
||||
setupNotificationHooks(client, tracker, logger)
|
||||
defer func() {
|
||||
tracker.Clear()
|
||||
}()
|
||||
|
||||
// Verify initial connectivity
|
||||
err = client.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
ef("Failed to ping Redis: %v", err)
|
||||
}
|
||||
|
||||
p("Client connected successfully, starting push notification test")
|
||||
|
||||
commandsRunner, _ := NewCommandRunner(client)
|
||||
defer func() {
|
||||
if dump {
|
||||
p("Command runner stats:")
|
||||
p("Operations: %d, Errors: %d, Timeout Errors: %d",
|
||||
commandsRunner.GetStats().Operations, commandsRunner.GetStats().Errors, commandsRunner.GetStats().TimeoutErrors)
|
||||
}
|
||||
p("Stopping command runner...")
|
||||
commandsRunner.Stop()
|
||||
}()
|
||||
|
||||
p("Starting FAILING_OVER / FAILED_OVER notifications test...")
|
||||
// Test: Trigger failover action to generate FAILING_OVER, FAILED_OVER notifications
|
||||
p("Triggering failover action to generate push notifications...")
|
||||
failoverResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "failover",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to trigger failover action: %v", err)
|
||||
}
|
||||
go func() {
|
||||
p("Waiting for FAILING_OVER notification")
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "FAILING_OVER")
|
||||
}, 3*time.Minute)
|
||||
commandsRunner.Stop()
|
||||
}()
|
||||
commandsRunner.FireCommandsUntilStop(ctx)
|
||||
if !found {
|
||||
ef("FAILING_OVER notification was not received within 3 minutes")
|
||||
}
|
||||
failingOverData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("FAILING_OVER notification received. %v", failingOverData)
|
||||
seqIDToObserve = int64(failingOverData["seqID"].(float64))
|
||||
connIDToObserve = uint64(failingOverData["connID"].(float64))
|
||||
go func() {
|
||||
p("Waiting for FAILED_OVER notification on conn %d with seqID %d...", connIDToObserve, seqIDToObserve+1)
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return notificationType(s, "FAILED_OVER") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
|
||||
}, 3*time.Minute)
|
||||
commandsRunner.Stop()
|
||||
}()
|
||||
commandsRunner.FireCommandsUntilStop(ctx)
|
||||
if !found {
|
||||
ef("FAILED_OVER notification was not received within 3 minutes")
|
||||
}
|
||||
failedOverData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("FAILED_OVER notification received. %v", failedOverData)
|
||||
|
||||
status, err = faultInjector.WaitForAction(ctx, failoverResp.ActionID,
|
||||
WithMaxWaitTime(240*time.Second),
|
||||
WithPollInterval(2*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
ef("[FI] Failover action failed: %v", err)
|
||||
}
|
||||
p("[FI] Failover action completed: %v %s", status.Status, actionOutputIfFailed(status))
|
||||
|
||||
p("FAILING_OVER / FAILED_OVER notifications test completed successfully")
|
||||
|
||||
// Test: Trigger migrate action to generate MOVING, MIGRATING, MIGRATED notifications
|
||||
p("Triggering migrate action to generate push notifications...")
|
||||
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "migrate",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to trigger migrate action: %v", err)
|
||||
}
|
||||
go func() {
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
|
||||
}, 60*time.Second)
|
||||
commandsRunner.Stop()
|
||||
}()
|
||||
commandsRunner.FireCommandsUntilStop(ctx)
|
||||
if !found {
|
||||
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
|
||||
WithMaxWaitTime(240*time.Second),
|
||||
WithPollInterval(2*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
ef("[FI] Migrate action failed: %v", err)
|
||||
}
|
||||
p("[FI] Migrate action completed: %s %s", status.Status, actionOutputIfFailed(status))
|
||||
ef("MIGRATING notification for migrate action was not received within 60 seconds")
|
||||
}
|
||||
migrateData := logs2.ExtractDataFromLogMessage(match)
|
||||
seqIDToObserve = int64(migrateData["seqID"].(float64))
|
||||
connIDToObserve = uint64(migrateData["connID"].(float64))
|
||||
p("MIGRATING notification received: seqID: %d, connID: %d", seqIDToObserve, connIDToObserve)
|
||||
|
||||
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
|
||||
WithMaxWaitTime(240*time.Second),
|
||||
WithPollInterval(2*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
ef("[FI] Migrate action failed: %v", err)
|
||||
}
|
||||
p("[FI] Migrate action completed: %s %s", status.Status, actionOutputIfFailed(status))
|
||||
|
||||
go func() {
|
||||
p("Waiting for MIGRATED notification on conn %d with seqID %d...", connIDToObserve, seqIDToObserve+1)
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return notificationType(s, "MIGRATED") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
|
||||
}, 3*time.Minute)
|
||||
commandsRunner.Stop()
|
||||
}()
|
||||
commandsRunner.FireCommandsUntilStop(ctx)
|
||||
if !found {
|
||||
ef("MIGRATED notification was not received within 3 minutes")
|
||||
}
|
||||
migratedData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("MIGRATED notification received. %v", migratedData)
|
||||
|
||||
p("MIGRATING / MIGRATED notifications test completed successfully")
|
||||
|
||||
// Trigger bind action to complete the migration process
|
||||
p("Triggering bind action to complete migration...")
|
||||
|
||||
bindResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "bind",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to trigger bind action: %v", err)
|
||||
}
|
||||
|
||||
// start a second client but don't execute any commands on it yet
|
||||
p("Starting a second client to observe notification during moving...")
|
||||
client2, err := factory.Create("push-notification-client-2", &CreateClientOptions{
|
||||
Protocol: 3, // RESP3 required for push notifications
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxActiveConns: maxConnections,
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeEnabled,
|
||||
HandoffTimeout: 40 * time.Second, // 40 seconds
|
||||
RelaxedTimeout: 30 * time.Minute, // 30 minutes relaxed timeout for second client
|
||||
PostHandoffRelaxedDuration: 2 * time.Second, // 2 seconds post-handoff relaxed duration
|
||||
MaxWorkers: 20,
|
||||
EndpointType: maintnotifications.EndpointTypeExternalIP, // Use external IP for enterprise
|
||||
},
|
||||
ClientName: "push-notification-test-client-2",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
ef("failed to create client: %v", err)
|
||||
}
|
||||
// setup tracking for second client
|
||||
tracker2 := NewTrackingNotificationsHook()
|
||||
logger2 := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
|
||||
setupNotificationHooks(client2, tracker2, logger2)
|
||||
commandsRunner2, _ := NewCommandRunner(client2)
|
||||
p("Second client created")
|
||||
|
||||
// Use a channel to communicate errors from the goroutine
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("goroutine panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
p("Waiting for MOVING notification on first client")
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "MOVING")
|
||||
}, 3*time.Minute)
|
||||
commandsRunner.Stop()
|
||||
if !found {
|
||||
errChan <- fmt.Errorf("MOVING notification was not received within 3 minutes ON A FIRST CLIENT")
|
||||
return
|
||||
}
|
||||
|
||||
// once moving is received, start a second client commands runner
|
||||
p("Starting commands on second client")
|
||||
go commandsRunner2.FireCommandsUntilStop(ctx)
|
||||
defer func() {
|
||||
// stop the second runner
|
||||
commandsRunner2.Stop()
|
||||
// destroy the second client
|
||||
factory.Destroy("push-notification-client-2")
|
||||
}()
|
||||
|
||||
p("Waiting for MOVING notification on second client")
|
||||
matchNotif, fnd := tracker2.FindOrWaitForNotification("MOVING", 3*time.Minute)
|
||||
if !fnd {
|
||||
errChan <- fmt.Errorf("MOVING notification was not received within 3 minutes ON A SECOND CLIENT")
|
||||
return
|
||||
} else {
|
||||
p("MOVING notification received on second client %v", matchNotif)
|
||||
}
|
||||
|
||||
// Signal success
|
||||
errChan <- nil
|
||||
}()
|
||||
commandsRunner.FireCommandsUntilStop(ctx)
|
||||
// wait for moving on first client
|
||||
// once the commandRunner stops, the wait on the
// logCollector match has completed and we can proceed
|
||||
if !found {
|
||||
ef("MOVING notification was not received within 3 minutes")
|
||||
}
|
||||
movingData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("MOVING notification received. %v", movingData)
|
||||
seqIDToObserve = int64(movingData["seqID"].(float64))
|
||||
connIDToObserve = uint64(movingData["connID"].(float64))
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
// start a third client to observe the MOVING notification while the endpoint is moving
|
||||
p("Starting a third client to observe notification during moving...")
|
||||
client3, err := factory.Create("push-notification-client-2", &CreateClientOptions{
|
||||
Protocol: 3, // RESP3 required for push notifications
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxActiveConns: maxConnections,
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeEnabled,
|
||||
HandoffTimeout: 40 * time.Second, // 40 seconds
|
||||
RelaxedTimeout: 30 * time.Minute, // 30 minutes relaxed timeout for third client
|
||||
PostHandoffRelaxedDuration: 2 * time.Second, // 2 seconds post-handoff relaxed duration
|
||||
MaxWorkers: 20,
|
||||
EndpointType: maintnotifications.EndpointTypeExternalIP, // Use external IP for enterprise
|
||||
},
|
||||
ClientName: "push-notification-test-client-3",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
ef("failed to create client: %v", err)
|
||||
}
|
||||
// setup tracking for third client
|
||||
tracker3 := NewTrackingNotificationsHook()
|
||||
logger3 := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
|
||||
setupNotificationHooks(client3, tracker3, logger3)
|
||||
commandsRunner3, _ := NewCommandRunner(client3)
|
||||
p("Third client created")
|
||||
go commandsRunner3.FireCommandsUntilStop(ctx)
|
||||
// wait for moving on third client
|
||||
movingNotification, found := tracker3.FindOrWaitForNotification("MOVING", 3*time.Minute)
|
||||
if !found {
|
||||
p("[NOTICE] MOVING notification was not received within 3 minutes ON A THIRD CLIENT")
|
||||
} else {
|
||||
p("MOVING notification received on third client. %v", movingNotification)
|
||||
if len(movingNotification) != 4 {
|
||||
p("[NOTICE] Invalid MOVING notification format: %s", movingNotification)
|
||||
}
|
||||
mNotifTimeS, ok := movingNotification[2].(int64)
|
||||
if !ok {
|
||||
p("[NOTICE] Invalid timeS in MOVING notification: %s", movingNotification)
|
||||
}
|
||||
// expect timeS to be less than 15
|
||||
if mNotifTimeS >= 15 {
|
||||
p("[NOTICE] Expected timeS < 15, got %d", mNotifTimeS)
|
||||
}
|
||||
}
|
||||
commandsRunner3.Stop()
|
||||
// Wait for the goroutine to complete and check for errors
|
||||
if err := <-errChan; err != nil {
|
||||
ef("Second client goroutine error: %v", err)
|
||||
}
|
||||
|
||||
// Wait for bind action to complete
|
||||
bindStatus, err := faultInjector.WaitForAction(ctx, bindResp.ActionID,
|
||||
WithMaxWaitTime(240*time.Second),
|
||||
WithPollInterval(2*time.Second))
|
||||
if err != nil {
|
||||
ef("Bind action failed: %v", err)
|
||||
}
|
||||
|
||||
p("Bind action completed: %s %s", bindStatus.Status, actionOutputIfFailed(bindStatus))
|
||||
|
||||
p("MOVING notification test completed successfully")
|
||||
|
||||
p("Executing commands and collecting logs for analysis... This will take 30 seconds...")
|
||||
go commandsRunner.FireCommandsUntilStop(ctx)
|
||||
time.Sleep(30 * time.Second)
|
||||
commandsRunner.Stop()
|
||||
allLogsAnalysis := logCollector.GetAnalysis()
|
||||
trackerAnalysis := tracker.GetAnalysis()
|
||||
|
||||
if allLogsAnalysis.TimeoutErrorsCount > 0 {
|
||||
e("Unexpected timeout errors: %d", allLogsAnalysis.TimeoutErrorsCount)
|
||||
}
|
||||
if trackerAnalysis.UnexpectedNotificationCount > 0 {
|
||||
e("Unexpected notifications: %d", trackerAnalysis.UnexpectedNotificationCount)
|
||||
}
|
||||
if trackerAnalysis.NotificationProcessingErrors > 0 {
|
||||
e("Notification processing errors: %d", trackerAnalysis.NotificationProcessingErrors)
|
||||
}
|
||||
if allLogsAnalysis.RelaxedTimeoutCount == 0 {
|
||||
e("Expected relaxed timeouts, got none")
|
||||
}
|
||||
if allLogsAnalysis.UnrelaxedTimeoutCount == 0 {
|
||||
e("Expected unrelaxed timeouts, got none")
|
||||
}
|
||||
if allLogsAnalysis.UnrelaxedAfterMoving == 0 {
|
||||
e("Expected unrelaxed timeouts after moving, got none")
|
||||
}
|
||||
if allLogsAnalysis.RelaxedPostHandoffCount == 0 {
|
||||
e("Expected relaxed timeouts after post-handoff, got none")
|
||||
}
|
||||
// validate number of connections we do not exceed max connections
|
||||
// we started three clients, so we expect 3x the connections
|
||||
if allLogsAnalysis.ConnectionCount > int64(maxConnections)*3 {
|
||||
e("Expected no more than %d connections, got %d", maxConnections*3, allLogsAnalysis.ConnectionCount)
|
||||
}
|
||||
|
||||
if allLogsAnalysis.ConnectionCount < int64(minIdleConns) {
|
||||
e("Expected at least %d connections, got %d", minIdleConns, allLogsAnalysis.ConnectionCount)
|
||||
}
|
||||
|
||||
// validate logs are present for all connections
|
||||
for connID := range trackerAnalysis.connIds {
|
||||
if len(allLogsAnalysis.connLogs[connID]) == 0 {
|
||||
e("No logs found for connection %d", connID)
|
||||
}
|
||||
}
|
||||
|
||||
// validate number of notifications in tracker matches number of notifications in logs
|
||||
// allow for more moving in the logs since we started a second client
|
||||
if trackerAnalysis.TotalNotifications > allLogsAnalysis.TotalNotifications {
|
||||
e("Expected %d or more notifications, got %d", trackerAnalysis.TotalNotifications, allLogsAnalysis.TotalNotifications)
|
||||
}
|
||||
|
||||
// and per type
|
||||
// allow for more moving in the logs since we started a second client
|
||||
if trackerAnalysis.MovingCount > allLogsAnalysis.MovingCount {
|
||||
e("Expected %d or more MOVING notifications, got %d", trackerAnalysis.MovingCount, allLogsAnalysis.MovingCount)
|
||||
}
|
||||
|
||||
if trackerAnalysis.MigratingCount != allLogsAnalysis.MigratingCount {
|
||||
e("Expected %d MIGRATING notifications, got %d", trackerAnalysis.MigratingCount, allLogsAnalysis.MigratingCount)
|
||||
}
|
||||
|
||||
if trackerAnalysis.MigratedCount != allLogsAnalysis.MigratedCount {
|
||||
e("Expected %d MIGRATED notifications, got %d", trackerAnalysis.MigratedCount, allLogsAnalysis.MigratedCount)
|
||||
}
|
||||
|
||||
if trackerAnalysis.FailingOverCount != allLogsAnalysis.FailingOverCount {
|
||||
e("Expected %d FAILING_OVER notifications, got %d", trackerAnalysis.FailingOverCount, allLogsAnalysis.FailingOverCount)
|
||||
}
|
||||
|
||||
if trackerAnalysis.FailedOverCount != allLogsAnalysis.FailedOverCount {
|
||||
e("Expected %d FAILED_OVER notifications, got %d", trackerAnalysis.FailedOverCount, allLogsAnalysis.FailedOverCount)
|
||||
}
|
||||
|
||||
if trackerAnalysis.UnexpectedNotificationCount != allLogsAnalysis.UnexpectedCount {
|
||||
e("Expected %d unexpected notifications, got %d", trackerAnalysis.UnexpectedNotificationCount, allLogsAnalysis.UnexpectedCount)
|
||||
}
|
||||
|
||||
// unrelaxed (and relaxed) after moving won't be tracked by the hook, so we have to exclude it
|
||||
if trackerAnalysis.UnrelaxedTimeoutCount != allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving {
|
||||
e("Expected %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount)
|
||||
}
|
||||
if trackerAnalysis.RelaxedTimeoutCount != allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount {
|
||||
e("Expected %d relaxed timeouts, got %d", trackerAnalysis.RelaxedTimeoutCount, allLogsAnalysis.RelaxedTimeoutCount)
|
||||
}
|
||||
|
||||
// validate all handoffs succeeded
|
||||
if allLogsAnalysis.FailedHandoffCount > 0 {
|
||||
e("Expected no failed handoffs, got %d", allLogsAnalysis.FailedHandoffCount)
|
||||
}
|
||||
if allLogsAnalysis.SucceededHandoffCount == 0 {
|
||||
e("Expected at least one successful handoff, got none")
|
||||
}
|
||||
if allLogsAnalysis.TotalHandoffCount != allLogsAnalysis.SucceededHandoffCount {
|
||||
e("Expected total handoffs to match successful handoffs, got %d != %d", allLogsAnalysis.TotalHandoffCount, allLogsAnalysis.SucceededHandoffCount)
|
||||
}
|
||||
|
||||
// no additional retries
|
||||
if allLogsAnalysis.TotalHandoffRetries != allLogsAnalysis.TotalHandoffCount {
|
||||
e("Expected no additional handoff retries, got %d", allLogsAnalysis.TotalHandoffRetries-allLogsAnalysis.TotalHandoffCount)
|
||||
}
|
||||
|
||||
if errorsDetected {
|
||||
logCollector.DumpLogs()
|
||||
trackerAnalysis.Print(t)
|
||||
logCollector.Clear()
|
||||
tracker.Clear()
|
||||
ef("[FAIL] Errors detected in push notification test")
|
||||
}
|
||||
|
||||
p("Analysis complete, no errors found")
|
||||
allLogsAnalysis.Print(t)
|
||||
trackerAnalysis.Print(t)
|
||||
p("Command runner stats:")
|
||||
p("Operations: %d, Errors: %d, Timeout Errors: %d",
|
||||
commandsRunner.GetStats().Operations, commandsRunner.GetStats().Errors, commandsRunner.GetStats().TimeoutErrors)
|
||||
|
||||
p("Push notification test completed successfully")
|
||||
}
|
||||
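The per-notification-type comparisons above could also be written as a table-driven loop. A minimal sketch, not part of the change, assuming the count fields are int64 (their exact type is not visible in this hunk):

    checks := []struct {
        name            string
        tracker, inLogs int64
    }{
        {"MIGRATING", trackerAnalysis.MigratingCount, allLogsAnalysis.MigratingCount},
        {"MIGRATED", trackerAnalysis.MigratedCount, allLogsAnalysis.MigratedCount},
        {"FAILING_OVER", trackerAnalysis.FailingOverCount, allLogsAnalysis.FailingOverCount},
        {"FAILED_OVER", trackerAnalysis.FailedOverCount, allLogsAnalysis.FailedOverCount},
    }
    for _, c := range checks {
        // same assertion as the individual if-blocks above
        if c.tracker != c.inLogs {
            e("Expected %d %s notifications, got %d", c.tracker, c.name, c.inLogs)
        }
    }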
311
maintnotifications/e2e/scenario_stress_test.go
Normal file
311
maintnotifications/e2e/scenario_stress_test.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
)
|
||||
|
||||
// TestStressPushNotifications tests push notifications under extreme stress conditions
|
||||
func TestStressPushNotifications(t *testing.T) {
|
||||
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
|
||||
t.Skip("[STRESS][SKIP] Scenario tests require E2E_SCENARIO_TESTS=true")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 35*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Setup: Create fresh database and client factory for this test
|
||||
bdbID, factory, cleanup := SetupTestDatabaseAndFactory(t, ctx, "standalone")
|
||||
defer cleanup()
|
||||
t.Logf("[STRESS] Created test database with bdb_id: %d", bdbID)
|
||||
|
||||
// Wait for database to be fully ready
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
var dump = true
|
||||
var errorsDetected = false
|
||||
|
||||
var p = func(format string, args ...interface{}) {
|
||||
printLog("STRESS", false, format, args...)
|
||||
}
|
||||
|
||||
var e = func(format string, args ...interface{}) {
|
||||
errorsDetected = true
|
||||
printLog("STRESS", true, format, args...)
|
||||
}
|
||||
|
||||
var ef = func(format string, args ...interface{}) {
|
||||
printLog("STRESS", true, format, args...)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
logCollector.ClearLogs()
|
||||
defer func() {
|
||||
logCollector.Clear()
|
||||
}()
|
||||
|
||||
// Get endpoint config from factory (now connected to new database)
|
||||
endpointConfig := factory.GetConfig()
|
||||
|
||||
// Create fault injector
|
||||
faultInjector, err := CreateTestFaultInjector()
|
||||
if err != nil {
|
||||
ef("Failed to create fault injector: %v", err)
|
||||
}
|
||||
|
||||
// Extreme stress configuration
|
||||
minIdleConns := 50
|
||||
poolSize := 150
|
||||
maxConnections := 200
|
||||
numClients := 4
|
||||
|
||||
var clients []redis.UniversalClient
|
||||
var trackers []*TrackingNotificationsHook
|
||||
var commandRunners []*CommandRunner
|
||||
|
||||
// Create multiple clients for extreme stress
|
||||
for i := 0; i < numClients; i++ {
|
||||
client, err := factory.Create(fmt.Sprintf("stress-client-%d", i), &CreateClientOptions{
|
||||
Protocol: 3, // RESP3 required for push notifications
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxActiveConns: maxConnections,
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeEnabled,
|
||||
HandoffTimeout: 60 * time.Second, // Longer timeout for stress
|
||||
RelaxedTimeout: 20 * time.Second, // Longer relaxed timeout
|
||||
PostHandoffRelaxedDuration: 5 * time.Second, // Longer post-handoff duration
|
||||
MaxWorkers: 50, // Maximum workers for stress
|
||||
HandoffQueueSize: 1000, // Large queue for stress
|
||||
EndpointType: maintnotifications.EndpointTypeExternalIP,
|
||||
},
|
||||
ClientName: fmt.Sprintf("stress-test-client-%d", i),
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to create stress client %d: %v", i, err)
|
||||
}
|
||||
clients = append(clients, client)
|
||||
|
||||
// Setup tracking for each client
|
||||
tracker := NewTrackingNotificationsHook()
|
||||
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelWarn)) // Minimal logging for stress
|
||||
setupNotificationHooks(client, tracker, logger)
|
||||
trackers = append(trackers, tracker)
|
||||
|
||||
// Create command runner for each client
|
||||
commandRunner, _ := NewCommandRunner(client)
|
||||
commandRunners = append(commandRunners, commandRunner)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if dump {
|
||||
p("Pool stats:")
|
||||
factory.PrintPoolStats(t)
|
||||
}
|
||||
for _, runner := range commandRunners {
|
||||
runner.Stop()
|
||||
}
|
||||
factory.DestroyAll()
|
||||
}()
|
||||
|
||||
// Verify initial connectivity for all clients
|
||||
for i, client := range clients {
|
||||
err = client.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
ef("Failed to ping Redis with stress client %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
p("All %d stress clients connected successfully", numClients)
|
||||
|
||||
// Start extreme traffic load on all clients
|
||||
var trafficWg sync.WaitGroup
|
||||
for i, runner := range commandRunners {
|
||||
trafficWg.Add(1)
|
||||
go func(clientID int, r *CommandRunner) {
|
||||
defer trafficWg.Done()
|
||||
p("Starting extreme traffic load on stress client %d", clientID)
|
||||
r.FireCommandsUntilStop(ctx)
|
||||
}(i, runner)
|
||||
}
|
||||
|
||||
// Wait for traffic to stabilize
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
// Trigger multiple concurrent fault injection actions
|
||||
var actionWg sync.WaitGroup
|
||||
var actionResults []string
|
||||
var actionMutex sync.Mutex
|
||||
|
||||
actions := []struct {
|
||||
name string
|
||||
action string
|
||||
delay time.Duration
|
||||
}{
|
||||
{"failover-1", "failover", 0},
|
||||
{"migrate-1", "migrate", 5 * time.Second},
|
||||
{"failover-2", "failover", 10 * time.Second},
|
||||
}
|
||||
|
||||
p("Starting %d concurrent fault injection actions under extreme stress...", len(actions))
|
||||
|
||||
for _, action := range actions {
|
||||
actionWg.Add(1)
|
||||
go func(actionName, actionType string, delay time.Duration) {
|
||||
defer actionWg.Done()
|
||||
|
||||
if delay > 0 {
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
p("Triggering %s action under extreme stress...", actionName)
|
||||
var resp *ActionResponse
|
||||
var err error
|
||||
|
||||
switch actionType {
|
||||
case "failover":
|
||||
resp, err = faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "failover",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
case "migrate":
|
||||
resp, err = faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "migrate",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
e("Failed to trigger %s action: %v", actionName, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Wait for action to complete
|
||||
status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
|
||||
WithMaxWaitTime(360*time.Second), // Longer wait time for stress
|
||||
WithPollInterval(2*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
e("[FI] %s action failed: %v", actionName, err)
|
||||
return
|
||||
}
|
||||
|
||||
actionMutex.Lock()
|
||||
actionResults = append(actionResults, fmt.Sprintf("%s: %s %s", actionName, status.Status, actionOutputIfFailed(status)))
|
||||
actionMutex.Unlock()
|
||||
|
||||
p("[FI] %s action completed: %s %s", actionName, status.Status, actionOutputIfFailed(status))
|
||||
}(action.name, action.action, action.delay)
|
||||
}
|
||||
|
||||
// Wait for all actions to complete
|
||||
actionWg.Wait()
|
||||
|
||||
// Continue stress for a bit longer
|
||||
p("All fault injection actions completed, continuing stress for 2 more minutes...")
|
||||
time.Sleep(2 * time.Minute)
|
||||
|
||||
// Stop all command runners
|
||||
for _, runner := range commandRunners {
|
||||
runner.Stop()
|
||||
}
|
||||
trafficWg.Wait()
|
||||
|
||||
// Analyze stress test results
|
||||
allLogsAnalysis := logCollector.GetAnalysis()
|
||||
totalOperations := int64(0)
|
||||
totalErrors := int64(0)
|
||||
totalTimeoutErrors := int64(0)
|
||||
|
||||
for i, runner := range commandRunners {
|
||||
stats := runner.GetStats()
|
||||
p("Stress client %d stats: Operations: %d, Errors: %d, Timeout Errors: %d",
|
||||
i, stats.Operations, stats.Errors, stats.TimeoutErrors)
|
||||
totalOperations += stats.Operations
|
||||
totalErrors += stats.Errors
|
||||
totalTimeoutErrors += stats.TimeoutErrors
|
||||
}
|
||||
|
||||
p("STRESS TEST RESULTS:")
|
||||
p("Total operations across all clients: %d", totalOperations)
|
||||
p("Total errors: %d (%.2f%%)", totalErrors, float64(totalErrors)/float64(totalOperations)*100)
|
||||
p("Total timeout errors: %d (%.2f%%)", totalTimeoutErrors, float64(totalTimeoutErrors)/float64(totalOperations)*100)
|
||||
p("Total connections used: %d", allLogsAnalysis.ConnectionCount)
|
||||
|
||||
// Print action results
|
||||
actionMutex.Lock()
|
||||
p("Fault injection action results:")
|
||||
for _, result := range actionResults {
|
||||
p(" %s", result)
|
||||
}
|
||||
actionMutex.Unlock()
|
||||
|
||||
// Validate stress test results
|
||||
if totalOperations < 1000 {
|
||||
e("Expected at least 1000 operations under stress, got %d", totalOperations)
|
||||
}
|
||||
|
||||
// Allow higher error rates under extreme stress (up to 20%)
|
||||
errorRate := float64(totalErrors) / float64(totalOperations) * 100
|
||||
if errorRate > 20.0 {
|
||||
e("Error rate too high under stress: %.2f%% (max allowed: 20%%)", errorRate)
|
||||
}
|
||||
|
||||
// Validate connection limits weren't exceeded
|
||||
expectedMaxConnections := int64(numClients * maxConnections)
|
||||
if allLogsAnalysis.ConnectionCount > expectedMaxConnections {
|
||||
e("Connection count exceeded limit: %d > %d", allLogsAnalysis.ConnectionCount, expectedMaxConnections)
|
||||
}
|
||||
|
||||
// Validate notifications were processed
|
||||
totalTrackerNotifications := int64(0)
|
||||
totalProcessingErrors := int64(0)
|
||||
for _, tracker := range trackers {
|
||||
analysis := tracker.GetAnalysis()
|
||||
totalTrackerNotifications += analysis.TotalNotifications
|
||||
totalProcessingErrors += analysis.NotificationProcessingErrors
|
||||
}
|
||||
|
||||
if totalProcessingErrors > totalTrackerNotifications/10 { // Allow up to 10% processing errors under stress
|
||||
e("Too many notification processing errors under stress: %d/%d", totalProcessingErrors, totalTrackerNotifications)
|
||||
}
|
||||
|
||||
if errorsDetected {
|
||||
ef("Errors detected under stress")
|
||||
logCollector.DumpLogs()
|
||||
for i, tracker := range trackers {
|
||||
p("=== Stress Client %d Analysis ===", i)
|
||||
tracker.GetAnalysis().Print(t)
|
||||
}
|
||||
logCollector.Clear()
|
||||
for _, tracker := range trackers {
|
||||
tracker.Clear()
|
||||
}
|
||||
}
|
||||
|
||||
dump = false
|
||||
p("[SUCCESS] Stress test completed successfully!")
|
||||
p("Processed %d operations across %d clients with %d connections",
|
||||
totalOperations, numClients, allLogsAnalysis.ConnectionCount)
|
||||
p("Error rate: %.2f%%, Notification processing errors: %d/%d",
|
||||
errorRate, totalProcessingErrors, totalTrackerNotifications)
|
||||
|
||||
// Print final analysis
|
||||
allLogsAnalysis.Print(t)
|
||||
for i, tracker := range trackers {
|
||||
p("=== Stress Client %d Analysis ===", i)
|
||||
tracker.GetAnalysis().Print(t)
|
||||
}
|
||||
}
|
||||
245
maintnotifications/e2e/scenario_template.go.example
Normal file
245
maintnotifications/e2e/scenario_template.go.example
Normal file
@@ -0,0 +1,245 @@
|
||||
|
||||
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/hitless"
|
||||
)
|
||||
|
||||
// TestScenarioTemplate is a template for writing scenario tests
|
||||
// Copy this file and rename it to scenario_your_test_name.go
|
||||
func TestScenarioTemplate(t *testing.T) {
|
||||
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
|
||||
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Step 1: Create client factory from configuration
|
||||
factory, err := CreateTestClientFactory("enterprise-cluster") // or "standalone0"
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client factory: %v", err)
|
||||
}
|
||||
defer factory.DestroyAll()
|
||||
|
||||
// Step 2: Create fault injector
|
||||
faultInjector, err := CreateTestFaultInjector()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create fault injector: %v", err)
|
||||
}
|
||||
|
||||
// Step 3: Create Redis client with hitless upgrades
|
||||
client, err := factory.Create("scenario-client", &CreateClientOptions{
|
||||
Protocol: 3,
|
||||
HitlessUpgradeConfig: &hitless.Config{
|
||||
Mode: hitless.MaintNotificationsEnabled,
|
||||
HandoffTimeout: 30000, // 30 seconds
|
||||
RelaxedTimeout: 10000, // 10 seconds
|
||||
MaxWorkers: 20,
|
||||
},
|
||||
ClientName: "scenario-test-client",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
// Step 4: Verify initial connectivity
|
||||
err = client.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to ping Redis: %v", err)
|
||||
}
|
||||
|
||||
t.Log("Initial setup completed successfully")
|
||||
|
||||
// Step 5: Start background operations (optional)
|
||||
stopCh := make(chan struct{})
|
||||
defer close(stopCh)
|
||||
|
||||
go func() {
|
||||
counter := 0
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
key := fmt.Sprintf("test-key-%d", counter)
|
||||
value := fmt.Sprintf("test-value-%d", counter)
|
||||
|
||||
err := client.Set(ctx, key, value, time.Minute).Err()
|
||||
if err != nil {
|
||||
t.Logf("Background operation failed: %v", err)
|
||||
}
|
||||
|
||||
counter++
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Step 6: Wait for baseline operations
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Step 7: Trigger fault injection scenario
|
||||
t.Log("Triggering fault injection scenario...")
|
||||
|
||||
// Example: Cluster failover
|
||||
// resp, err := faultInjector.TriggerClusterFailover(ctx, "node-1", false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("Failed to trigger failover: %v", err)
|
||||
// }
|
||||
|
||||
// Example: Network latency
|
||||
// nodes := []string{"localhost:7001", "localhost:7002"}
|
||||
// resp, err := faultInjector.SimulateNetworkLatency(ctx, nodes, 100*time.Millisecond, 20*time.Millisecond)
|
||||
// if err != nil {
|
||||
// t.Fatalf("Failed to simulate latency: %v", err)
|
||||
// }
|
||||
|
||||
// Example: Complex sequence
|
||||
// sequence := []SequenceAction{
|
||||
// {
|
||||
// Type: ActionNetworkLatency,
|
||||
// Parameters: map[string]interface{}{
|
||||
// "nodes": []string{"localhost:7001"},
|
||||
// "latency": "50ms",
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// Type: ActionClusterFailover,
|
||||
// Parameters: map[string]interface{}{
|
||||
// "node_id": "node-1",
|
||||
// "force": false,
|
||||
// },
|
||||
// Delay: 10 * time.Second,
|
||||
// },
|
||||
// }
|
||||
// resp, err := faultInjector.ExecuteSequence(ctx, sequence)
|
||||
// if err != nil {
|
||||
// t.Fatalf("Failed to execute sequence: %v", err)
|
||||
// }
|
||||
|
||||
// Step 8: Wait for fault injection to complete
|
||||
// status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
|
||||
// WithMaxWaitTime(240*time.Second),
|
||||
// WithPollInterval(2*time.Second))
|
||||
// if err != nil {
|
||||
// t.Fatalf("Fault injection failed: %v", err)
|
||||
// }
|
||||
// t.Logf("Fault injection completed: %s", status.Status)
|
||||
|
||||
// Step 9: Verify client remains operational during and after fault injection
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
err = client.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
t.Errorf("Client not responsive after fault injection: %v", err)
|
||||
}
|
||||
|
||||
// Step 10: Perform additional validation
|
||||
testKey := "validation-key"
|
||||
testValue := "validation-value"
|
||||
|
||||
err = client.Set(ctx, testKey, testValue, time.Minute).Err()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to set validation key: %v", err)
|
||||
}
|
||||
|
||||
retrievedValue, err := client.Get(ctx, testKey).Result()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get validation key: %v", err)
|
||||
} else if retrievedValue != testValue {
|
||||
t.Errorf("Validation failed: expected %s, got %s", testValue, retrievedValue)
|
||||
}
|
||||
|
||||
t.Log("Scenario test completed successfully")
|
||||
}
|
||||
|
||||
// Helper functions for common scenario patterns
|
||||
|
||||
func performContinuousOperations(ctx context.Context, client redis.UniversalClient, workerID int, stopCh <-chan struct{}, errorCh chan<- error) {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
counter := 0
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
key := fmt.Sprintf("worker_%d_key_%d", workerID, counter)
|
||||
value := fmt.Sprintf("value_%d", counter)
|
||||
|
||||
// Perform operation with timeout
|
||||
opCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
err := client.Set(opCtx, key, value, time.Minute).Err()
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
select {
|
||||
case errorCh <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
counter++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func validateClusterHealth(ctx context.Context, client redis.UniversalClient) error {
|
||||
// Basic connectivity test
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return fmt.Errorf("ping failed: %w", err)
|
||||
}
|
||||
|
||||
// Test basic operations
|
||||
testKey := "health-check-key"
|
||||
testValue := "health-check-value"
|
||||
|
||||
if err := client.Set(ctx, testKey, testValue, time.Minute).Err(); err != nil {
|
||||
return fmt.Errorf("set operation failed: %w", err)
|
||||
}
|
||||
|
||||
retrievedValue, err := client.Get(ctx, testKey).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get operation failed: %w", err)
|
||||
}
|
||||
|
||||
if retrievedValue != testValue {
|
||||
return fmt.Errorf("value mismatch: expected %s, got %s", testValue, retrievedValue)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
client.Del(ctx, testKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForStableOperations(ctx context.Context, client redis.UniversalClient, duration time.Duration) error {
|
||||
deadline := time.Now().Add(duration)
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
if err := validateClusterHealth(ctx, client); err != nil {
|
||||
return fmt.Errorf("cluster health check failed: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
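For reference, a copied template test could wire the helper functions above together roughly as follows. This is a sketch, not part of the change; the worker count and channel buffer size are arbitrary, and ctx, client and t come from the template's setup steps:

    // Start background workers using the template's helper.
    stopCh := make(chan struct{})
    errorCh := make(chan error, 100)
    for w := 0; w < 4; w++ {
        go performContinuousOperations(ctx, client, w, stopCh, errorCh)
    }

    // ... trigger fault injection here (see Step 7 above) ...

    // Require the cluster to stay healthy for a stabilization window.
    if err := waitForStableOperations(ctx, client, 30*time.Second); err != nil {
        t.Errorf("cluster did not stabilize: %v", err)
    }
    close(stopCh)

    // Drain any errors reported by the background workers.
drain:
    for {
        select {
        case err := <-errorCh:
            t.Logf("background operation error: %v", err)
        default:
            break drain
        }
    }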
357
maintnotifications/e2e/scenario_timeout_configs_test.go
Normal file
357
maintnotifications/e2e/scenario_timeout_configs_test.go
Normal file
@@ -0,0 +1,357 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
)
|
||||
|
||||
// TestTimeoutConfigurationsPushNotifications tests push notifications with different timeout configurations
|
||||
func TestTimeoutConfigurationsPushNotifications(t *testing.T) {
|
||||
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
|
||||
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var dump = true
|
||||
|
||||
var errorsDetected = false
|
||||
var p = func(format string, args ...interface{}) {
|
||||
printLog("TIMEOUT-CONFIGS", false, format, args...)
|
||||
}
|
||||
|
||||
var e = func(format string, args ...interface{}) {
|
||||
errorsDetected = true
|
||||
printLog("TIMEOUT-CONFIGS", true, format, args...)
|
||||
}
|
||||
|
||||
// Test different timeout configurations
|
||||
timeoutConfigs := []struct {
|
||||
name string
|
||||
handoffTimeout time.Duration
|
||||
relaxedTimeout time.Duration
|
||||
postHandoffRelaxedDuration time.Duration
|
||||
description string
|
||||
expectedBehavior string
|
||||
}{
|
||||
{
|
||||
name: "Conservative",
|
||||
handoffTimeout: 60 * time.Second,
|
||||
relaxedTimeout: 30 * time.Second,
|
||||
postHandoffRelaxedDuration: 2 * time.Minute,
|
||||
description: "Conservative timeouts for stable environments",
|
||||
expectedBehavior: "Longer timeouts, fewer timeout errors",
|
||||
},
|
||||
{
|
||||
name: "Aggressive",
|
||||
handoffTimeout: 5 * time.Second,
|
||||
relaxedTimeout: 3 * time.Second,
|
||||
postHandoffRelaxedDuration: 1 * time.Second,
|
||||
description: "Aggressive timeouts for fast failover",
|
||||
expectedBehavior: "Shorter timeouts, faster recovery",
|
||||
},
|
||||
{
|
||||
name: "HighLatency",
|
||||
handoffTimeout: 90 * time.Second,
|
||||
relaxedTimeout: 30 * time.Second,
|
||||
postHandoffRelaxedDuration: 10 * time.Minute,
|
||||
description: "High latency environment timeouts",
|
||||
expectedBehavior: "Very long timeouts for high latency networks",
|
||||
},
|
||||
}
|
||||
|
||||
logCollector.ClearLogs()
|
||||
defer func() {
|
||||
logCollector.Clear()
|
||||
}()
|
||||
|
||||
// Test each timeout configuration with its own fresh database
|
||||
for _, timeoutTest := range timeoutConfigs {
|
||||
t.Run(timeoutTest.name, func(t *testing.T) {
|
||||
// Setup: Create fresh database and client factory for THIS timeout config test
|
||||
bdbID, factory, cleanup := SetupTestDatabaseAndFactory(t, ctx, "standalone")
|
||||
defer cleanup()
|
||||
t.Logf("[TIMEOUT-CONFIGS-%s] Created test database with bdb_id: %d", timeoutTest.name, bdbID)
|
||||
|
||||
// Get endpoint config from factory (now connected to new database)
|
||||
endpointConfig := factory.GetConfig()
|
||||
|
||||
// Create fault injector
|
||||
faultInjector, err := CreateTestFaultInjector()
|
||||
if err != nil {
|
||||
t.Fatalf("[ERROR] Failed to create fault injector: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if dump {
|
||||
p("Pool stats:")
|
||||
factory.PrintPoolStats(t)
|
||||
}
|
||||
}()
|
||||
|
||||
errorsDetected = false
|
||||
var ef = func(format string, args ...interface{}) {
|
||||
printLog("TIMEOUT-CONFIGS", true, format, args...)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
p("Testing timeout configuration: %s - %s", timeoutTest.name, timeoutTest.description)
|
||||
p("Expected behavior: %s", timeoutTest.expectedBehavior)
|
||||
p("Handoff timeout: %v, Relaxed timeout: %v, Post-handoff duration: %v",
|
||||
timeoutTest.handoffTimeout, timeoutTest.relaxedTimeout, timeoutTest.postHandoffRelaxedDuration)
|
||||
|
||||
minIdleConns := 4
|
||||
poolSize := 10
|
||||
maxConnections := 15
|
||||
|
||||
// Create Redis client with specific timeout configuration
|
||||
client, err := factory.Create(fmt.Sprintf("timeout-test-%s", timeoutTest.name), &CreateClientOptions{
|
||||
Protocol: 3, // RESP3 required for push notifications
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxActiveConns: maxConnections,
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeEnabled,
|
||||
HandoffTimeout: timeoutTest.handoffTimeout,
|
||||
RelaxedTimeout: timeoutTest.relaxedTimeout,
|
||||
PostHandoffRelaxedDuration: timeoutTest.postHandoffRelaxedDuration,
|
||||
MaxWorkers: 20,
|
||||
EndpointType: maintnotifications.EndpointTypeExternalIP,
|
||||
},
|
||||
ClientName: fmt.Sprintf("timeout-test-%s", timeoutTest.name),
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to create client for %s: %v", timeoutTest.name, err)
|
||||
}
|
||||
|
||||
// Create timeout tracker
|
||||
tracker := NewTrackingNotificationsHook()
|
||||
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
|
||||
setupNotificationHooks(client, tracker, logger)
|
||||
defer func() {
|
||||
tracker.Clear()
|
||||
}()
|
||||
|
||||
// Verify initial connectivity
|
||||
err = client.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
ef("Failed to ping Redis with %s timeout config: %v", timeoutTest.name, err)
|
||||
}
|
||||
|
||||
p("Client connected successfully with %s timeout configuration", timeoutTest.name)
|
||||
|
||||
commandsRunner, _ := NewCommandRunner(client)
|
||||
defer func() {
|
||||
if dump {
|
||||
stats := commandsRunner.GetStats()
|
||||
p("%s timeout config stats: Operations: %d, Errors: %d, Timeout Errors: %d",
|
||||
timeoutTest.name, stats.Operations, stats.Errors, stats.TimeoutErrors)
|
||||
}
|
||||
commandsRunner.Stop()
|
||||
}()
|
||||
|
||||
// Start command traffic
|
||||
go func() {
|
||||
commandsRunner.FireCommandsUntilStop(ctx)
|
||||
}()
|
||||
|
||||
// Record start time for timeout analysis
|
||||
testStartTime := time.Now()
|
||||
|
||||
// Test failover with this timeout configuration
|
||||
p("Testing failover with %s timeout configuration...", timeoutTest.name)
|
||||
failoverResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "failover",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to trigger failover action for %s: %v", timeoutTest.name, err)
|
||||
}
|
||||
|
||||
// Wait for FAILING_OVER notification
|
||||
match, found := logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "FAILING_OVER")
|
||||
}, 3*time.Minute)
|
||||
if !found {
|
||||
ef("FAILING_OVER notification was not received for %s timeout config", timeoutTest.name)
|
||||
}
|
||||
failingOverData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("FAILING_OVER notification received for %s. %v", timeoutTest.name, failingOverData)
|
||||
|
||||
// Wait for FAILED_OVER notification
|
||||
seqIDToObserve := int64(failingOverData["seqID"].(float64))
|
||||
connIDToObserve := uint64(failingOverData["connID"].(float64))
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return notificationType(s, "FAILED_OVER") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
|
||||
}, 3*time.Minute)
|
||||
if !found {
|
||||
ef("FAILED_OVER notification was not received for %s timeout config", timeoutTest.name)
|
||||
}
|
||||
failedOverData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("FAILED_OVER notification received for %s. %v", timeoutTest.name, failedOverData)
|
||||
|
||||
// Wait for failover to complete
|
||||
status, err := faultInjector.WaitForAction(ctx, failoverResp.ActionID,
|
||||
WithMaxWaitTime(180*time.Second),
|
||||
WithPollInterval(2*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
ef("[FI] Failover action failed for %s: %v", timeoutTest.name, err)
|
||||
}
|
||||
p("[FI] Failover action completed for %s: %s %s", timeoutTest.name, status.Status, actionOutputIfFailed(status))
|
||||
|
||||
// Continue traffic to observe timeout behavior
|
||||
p("Continuing traffic for %v to observe timeout behavior...", timeoutTest.relaxedTimeout*2)
|
||||
time.Sleep(timeoutTest.relaxedTimeout * 2)
|
||||
|
||||
// Test migration to trigger more timeout scenarios
|
||||
p("Testing migration with %s timeout configuration...", timeoutTest.name)
|
||||
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "migrate",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to trigger migrate action for %s: %v", timeoutTest.name, err)
|
||||
}
|
||||
|
||||
// Wait for migration to complete
|
||||
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
|
||||
WithMaxWaitTime(240*time.Second),
|
||||
WithPollInterval(2*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
ef("[FI] Migrate action failed for %s: %v", timeoutTest.name, err)
|
||||
}
|
||||
|
||||
p("[FI] Migrate action completed for %s: %s %s", timeoutTest.name, status.Status, actionOutputIfFailed(status))
|
||||
|
||||
// Wait for MIGRATING notification
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
|
||||
}, 60*time.Second)
|
||||
if !found {
|
||||
ef("MIGRATING notification was not received for %s timeout config", timeoutTest.name)
|
||||
}
|
||||
migrateData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("MIGRATING notification received for %s: %v", timeoutTest.name, migrateData)
|
||||
|
||||
// do a bind action
|
||||
bindResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "bind",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to trigger bind action for %s: %v", timeoutTest.name, err)
|
||||
}
|
||||
status, err = faultInjector.WaitForAction(ctx, bindResp.ActionID,
|
||||
WithMaxWaitTime(240*time.Second),
|
||||
WithPollInterval(2*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
ef("[FI] Bind action failed for %s: %v", timeoutTest.name, err)
|
||||
}
|
||||
p("[FI] Bind action completed for %s: %s %s", timeoutTest.name, status.Status, actionOutputIfFailed(status))
|
||||
|
||||
// waiting for moving notification
|
||||
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "MOVING")
|
||||
}, 3*time.Minute)
|
||||
if !found {
|
||||
ef("MOVING notification was not received for %s timeout config", timeoutTest.name)
|
||||
}
|
||||
|
||||
movingData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("MOVING notification received for %s. %v", timeoutTest.name, movingData)
|
||||
|
||||
// Continue traffic for post-handoff timeout observation
|
||||
p("Continuing traffic for %v to observe post-handoff timeout behavior...", 1*time.Minute)
|
||||
time.Sleep(1 * time.Minute)
|
||||
|
||||
commandsRunner.Stop()
|
||||
testDuration := time.Since(testStartTime)
|
||||
|
||||
// Analyze timeout behavior
|
||||
trackerAnalysis := tracker.GetAnalysis()
|
||||
logAnalysis := logCollector.GetAnalysis()
|
||||
if trackerAnalysis.NotificationProcessingErrors > 0 {
|
||||
e("Notification processing errors with %s timeout config: %d", timeoutTest.name, trackerAnalysis.NotificationProcessingErrors)
|
||||
}
|
||||
|
||||
// Validate timeout-specific behavior
|
||||
switch timeoutTest.name {
|
||||
case "Conservative":
|
||||
if trackerAnalysis.UnrelaxedTimeoutCount > trackerAnalysis.RelaxedTimeoutCount {
|
||||
e("Conservative config should have more relaxed than unrelaxed timeouts, got relaxed=%d, unrelaxed=%d",
|
||||
trackerAnalysis.RelaxedTimeoutCount, trackerAnalysis.UnrelaxedTimeoutCount)
|
||||
}
|
||||
case "Aggressive":
|
||||
// Aggressive timeouts should complete faster
|
||||
if testDuration > 5*time.Minute {
|
||||
e("Aggressive config took too long: %v", testDuration)
|
||||
}
|
||||
if logAnalysis.TotalHandoffRetries > logAnalysis.TotalHandoffCount {
|
||||
e("Expect handoff retries since aggressive timeouts are shorter, got %d retries for %d handoffs",
|
||||
logAnalysis.TotalHandoffRetries, logAnalysis.TotalHandoffCount)
|
||||
}
|
||||
case "HighLatency":
|
||||
// High latency config should have very few unrelaxed after moving
|
||||
if logAnalysis.UnrelaxedAfterMoving > 2 {
|
||||
e("High latency config should have minimal unrelaxed timeouts after moving, got %d", logAnalysis.UnrelaxedAfterMoving)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate we received expected notifications
|
||||
if trackerAnalysis.FailingOverCount == 0 {
|
||||
e("Expected FAILING_OVER notifications with %s timeout config, got none", timeoutTest.name)
|
||||
}
|
||||
if trackerAnalysis.FailedOverCount == 0 {
|
||||
e("Expected FAILED_OVER notifications with %s timeout config, got none", timeoutTest.name)
|
||||
}
|
||||
if trackerAnalysis.MigratingCount == 0 {
|
||||
e("Expected MIGRATING notifications with %s timeout config, got none", timeoutTest.name)
|
||||
}
|
||||
|
||||
// Validate timeout counts are reasonable
|
||||
if trackerAnalysis.RelaxedTimeoutCount == 0 {
|
||||
e("Expected relaxed timeouts with %s config, got none", timeoutTest.name)
|
||||
}
|
||||
|
||||
if logAnalysis.SucceededHandoffCount == 0 {
|
||||
e("Expected successful handoffs with %s config, got none", timeoutTest.name)
|
||||
}
|
||||
|
||||
if errorsDetected {
|
||||
logCollector.DumpLogs()
|
||||
trackerAnalysis.Print(t)
|
||||
logCollector.Clear()
|
||||
tracker.Clear()
|
||||
ef("[FAIL] Errors detected with %s timeout config", timeoutTest.name)
|
||||
}
|
||||
p("Timeout configuration %s test completed successfully in %v", timeoutTest.name, testDuration)
|
||||
p("Command runner stats:")
|
||||
p("Operations: %d, Errors: %d, Timeout Errors: %d",
|
||||
commandsRunner.GetStats().Operations, commandsRunner.GetStats().Errors, commandsRunner.GetStats().TimeoutErrors)
|
||||
p("Relaxed timeouts: %d, Unrelaxed timeouts: %d", trackerAnalysis.RelaxedTimeoutCount, trackerAnalysis.UnrelaxedTimeoutCount)
|
||||
})
|
||||
|
||||
// Clear logs between timeout configuration tests
|
||||
logCollector.ClearLogs()
|
||||
}
|
||||
|
||||
p("All timeout configurations tested successfully")
|
||||
}
|
||||
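Spelled out as a literal, the "Conservative" case above corresponds to a config like the following. The values are the ones from the test table, not tuning recommendations:

    cfg := &maintnotifications.Config{
        Mode:                       maintnotifications.ModeEnabled,
        HandoffTimeout:             60 * time.Second,
        RelaxedTimeout:             30 * time.Second,
        PostHandoffRelaxedDuration: 2 * time.Minute,
        MaxWorkers:                 20,
        EndpointType:               maintnotifications.EndpointTypeExternalIP,
    }
    _ = cfg // passed to the client factory as MaintNotificationsConfig in the test above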
261
maintnotifications/e2e/scenario_tls_configs_test.go
Normal file
261
maintnotifications/e2e/scenario_tls_configs_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
)
|
||||
|
||||
// TODO ADD TLS CONFIGS
|
||||
// TestTLSConfigurationsPushNotifications tests push notifications with different TLS configurations
|
||||
func ТestTLSConfigurationsPushNotifications(t *testing.T) {
|
||||
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
|
||||
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var dump = true
|
||||
var errorsDetected = false
|
||||
var p = func(format string, args ...interface{}) {
|
||||
printLog("TLS-CONFIGS", false, format, args...)
|
||||
}
|
||||
|
||||
var e = func(format string, args ...interface{}) {
|
||||
errorsDetected = true
|
||||
printLog("TLS-CONFIGS", true, format, args...)
|
||||
}
|
||||
|
||||
// Test different TLS configurations
|
||||
// Note: TLS configuration is typically handled at the Redis connection config level
|
||||
// This scenario demonstrates the testing pattern for different TLS setups
|
||||
tlsConfigs := []struct {
|
||||
name string
|
||||
description string
|
||||
skipReason string
|
||||
}{
|
||||
{
|
||||
name: "NoTLS",
|
||||
description: "No TLS encryption (plain text)",
|
||||
},
|
||||
{
|
||||
name: "TLSInsecure",
|
||||
description: "TLS with insecure skip verify (testing only)",
|
||||
},
|
||||
{
|
||||
name: "TLSSecure",
|
||||
description: "Secure TLS with certificate verification",
|
||||
skipReason: "Requires valid certificates in test environment",
|
||||
},
|
||||
{
|
||||
name: "TLSMinimal",
|
||||
description: "TLS with minimal version requirements",
|
||||
},
|
||||
{
|
||||
name: "TLSStrict",
|
||||
description: "Strict TLS with TLS 1.3 and specific cipher suites",
|
||||
},
|
||||
}
|
||||
|
||||
logCollector.ClearLogs()
|
||||
defer func() {
|
||||
logCollector.Clear()
|
||||
}()
|
||||
|
||||
// Test each TLS configuration with its own fresh database
|
||||
for _, tlsTest := range tlsConfigs {
|
||||
t.Run(tlsTest.name, func(t *testing.T) {
|
||||
// Setup: Create fresh database and client factory for THIS TLS config test
|
||||
bdbID, factory, cleanup := SetupTestDatabaseAndFactory(t, ctx, "standalone")
|
||||
defer cleanup()
|
||||
t.Logf("[TLS-CONFIGS-%s] Created test database with bdb_id: %d", tlsTest.name, bdbID)
|
||||
|
||||
// Get endpoint config from factory (now connected to new database)
|
||||
endpointConfig := factory.GetConfig()
|
||||
|
||||
// Create fault injector
|
||||
faultInjector, err := CreateTestFaultInjector()
|
||||
if err != nil {
|
||||
t.Fatalf("[ERROR] Failed to create fault injector: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if dump {
|
||||
p("Pool stats:")
|
||||
factory.PrintPoolStats(t)
|
||||
}
|
||||
}()
|
||||
|
||||
errorsDetected = false
|
||||
var ef = func(format string, args ...interface{}) {
|
||||
printLog("TLS-CONFIGS", true, format, args...)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if tlsTest.skipReason != "" {
|
||||
t.Skipf("Skipping %s: %s", tlsTest.name, tlsTest.skipReason)
|
||||
}
|
||||
|
||||
p("Testing TLS configuration: %s - %s", tlsTest.name, tlsTest.description)
|
||||
|
||||
minIdleConns := 3
|
||||
poolSize := 8
|
||||
maxConnections := 12
|
||||
|
||||
// Create Redis client with specific TLS configuration
|
||||
// Note: TLS configuration is handled at the factory/connection level
|
||||
client, err := factory.Create(fmt.Sprintf("tls-test-%s", tlsTest.name), &CreateClientOptions{
|
||||
Protocol: 3, // RESP3 required for push notifications
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxActiveConns: maxConnections,
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeEnabled,
|
||||
HandoffTimeout: 30 * time.Second,
|
||||
RelaxedTimeout: 10 * time.Second,
|
||||
PostHandoffRelaxedDuration: 2 * time.Second,
|
||||
MaxWorkers: 15,
|
||||
EndpointType: maintnotifications.EndpointTypeExternalIP,
|
||||
},
|
||||
ClientName: fmt.Sprintf("tls-test-%s", tlsTest.name),
|
||||
})
|
||||
if err != nil {
|
||||
// Some TLS configurations might fail in test environments
|
||||
if tlsTest.name == "TLSSecure" || tlsTest.name == "TLSStrict" {
|
||||
t.Skipf("TLS configuration %s failed (expected in test environment): %v", tlsTest.name, err)
|
||||
}
|
||||
ef("Failed to create client for %s: %v", tlsTest.name, err)
|
||||
}
|
||||
|
||||
// Create timeout tracker
|
||||
tracker := NewTrackingNotificationsHook()
|
||||
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
|
||||
setupNotificationHooks(client, tracker, logger)
|
||||
defer func() {
|
||||
tracker.Clear()
|
||||
}()
|
||||
|
||||
// Verify initial connectivity
|
||||
err = client.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
if tlsTest.name == "TLSSecure" || tlsTest.name == "TLSStrict" {
|
||||
t.Skipf("TLS configuration %s ping failed (expected in test environment): %v", tlsTest.name, err)
|
||||
}
|
||||
ef("Failed to ping Redis with %s TLS config: %v", tlsTest.name, err)
|
||||
}
|
||||
|
||||
p("Client connected successfully with %s TLS configuration", tlsTest.name)
|
||||
|
||||
commandsRunner, _ := NewCommandRunner(client)
|
||||
defer func() {
|
||||
if dump {
|
||||
stats := commandsRunner.GetStats()
|
||||
p("%s TLS config stats: Operations: %d, Errors: %d, Timeout Errors: %d",
|
||||
tlsTest.name, stats.Operations, stats.Errors, stats.TimeoutErrors)
|
||||
}
|
||||
commandsRunner.Stop()
|
||||
}()
|
||||
|
||||
// Start command traffic
|
||||
go func() {
|
||||
commandsRunner.FireCommandsUntilStop(ctx)
|
||||
}()
|
||||
|
||||
// Test migration with this TLS configuration
|
||||
p("Testing migration with %s TLS configuration...", tlsTest.name)
|
||||
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
|
||||
Type: "migrate",
|
||||
Parameters: map[string]interface{}{
|
||||
"bdb_id": endpointConfig.BdbID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ef("Failed to trigger migrate action for %s: %v", tlsTest.name, err)
|
||||
}
|
||||
|
||||
// Wait for MIGRATING notification
|
||||
match, found := logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
|
||||
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
|
||||
}, 60*time.Second)
|
||||
if !found {
|
||||
ef("MIGRATING notification was not received for %s TLS config", tlsTest.name)
|
||||
}
|
||||
migrateData := logs2.ExtractDataFromLogMessage(match)
|
||||
p("MIGRATING notification received for %s: %v", tlsTest.name, migrateData)
|
||||
|
||||
// Wait for migration to complete
|
||||
status, err := faultInjector.WaitForAction(ctx, migrateResp.ActionID,
|
||||
WithMaxWaitTime(240*time.Second),
|
||||
WithPollInterval(2*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
ef("[FI] Migrate action failed for %s: %v", tlsTest.name, err)
|
||||
}
|
||||
p("[FI] Migrate action completed for %s: %s %s", tlsTest.name, status.Status, actionOutputIfFailed(status))
|
||||
|
||||
// Continue traffic for a bit to observe TLS behavior
|
||||
time.Sleep(5 * time.Second)
|
||||
commandsRunner.Stop()
|
||||
|
||||
// Analyze results for this TLS configuration
|
||||
trackerAnalysis := tracker.GetAnalysis()
|
||||
if trackerAnalysis.NotificationProcessingErrors > 0 {
|
||||
e("Notification processing errors with %s TLS config: %d", tlsTest.name, trackerAnalysis.NotificationProcessingErrors)
|
||||
}
|
||||
|
||||
if trackerAnalysis.UnexpectedNotificationCount > 0 {
|
||||
e("Unexpected notifications with %s TLS config: %d", tlsTest.name, trackerAnalysis.UnexpectedNotificationCount)
|
||||
}
|
||||
|
||||
// Validate we received expected notifications
|
||||
if trackerAnalysis.FailingOverCount == 0 {
|
||||
e("Expected FAILING_OVER notifications with %s TLS config, got none", tlsTest.name)
|
||||
}
|
||||
if trackerAnalysis.FailedOverCount == 0 {
|
||||
e("Expected FAILED_OVER notifications with %s TLS config, got none", tlsTest.name)
|
||||
}
|
||||
if trackerAnalysis.MigratingCount == 0 {
|
||||
e("Expected MIGRATING notifications with %s TLS config, got none", tlsTest.name)
|
||||
}
|
||||
|
||||
if errorsDetected {
|
||||
logCollector.DumpLogs()
|
||||
trackerAnalysis.Print(t)
|
||||
logCollector.Clear()
|
||||
tracker.Clear()
|
||||
ef("[FAIL] Errors detected with %s TLS config", tlsTest.name)
|
||||
}
|
||||
// TLS-specific validations
|
||||
stats := commandsRunner.GetStats()
|
||||
switch tlsTest.name {
|
||||
case "NoTLS":
|
||||
// Plain text should work fine
|
||||
p("Plain text connection processed %d operations", stats.Operations)
|
||||
case "TLSInsecure", "TLSMinimal":
|
||||
// Insecure TLS should work in test environments
|
||||
p("Insecure TLS connection processed %d operations", stats.Operations)
|
||||
if stats.Operations == 0 {
|
||||
e("Expected operations with %s TLS config, got none", tlsTest.name)
|
||||
}
|
||||
case "TLSStrict":
|
||||
// Strict TLS might have different performance characteristics
|
||||
p("Strict TLS connection processed %d operations", stats.Operations)
|
||||
}
|
||||
|
||||
p("TLS configuration %s test completed successfully", tlsTest.name)
|
||||
})
|
||||
|
||||
// Clear logs between TLS configuration tests
|
||||
logCollector.ClearLogs()
|
||||
}
|
||||
|
||||
p("All TLS configurations tested successfully")
|
||||
}
|
||||
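The named TLS cases above map roughly onto standard-library crypto/tls settings like the sketch below. How the test factory consumes a *tls.Config is not part of this change, so the wiring and the exact version choices are assumptions:

    // Assumes: import "crypto/tls"
    noTLS := (*tls.Config)(nil)                          // NoTLS: plain text connection
    insecure := &tls.Config{InsecureSkipVerify: true}    // TLSInsecure: skip verification (testing only)
    minimal := &tls.Config{MinVersion: tls.VersionTLS12} // TLSMinimal: low minimum version (assumed)
    strict := &tls.Config{MinVersion: tls.VersionTLS13}  // TLSStrict: TLS 1.3 only; 1.3 cipher suites are fixed by the stdlib
    _ = []*tls.Config{noTLS, insecure, minimal, strict}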
213
maintnotifications/e2e/scripts/run-e2e-tests.sh
Executable file
213
maintnotifications/e2e/scripts/run-e2e-tests.sh
Executable file
@@ -0,0 +1,213 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Maintenance Notifications E2E Tests Runner
|
||||
# This script sets up the environment and runs the maintnotifications upgrade E2E tests
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Script directory and repository root
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
|
||||
E2E_DIR="${REPO_ROOT}/maintnotifications/e2e"
|
||||
|
||||
# Configuration
|
||||
FAULT_INJECTOR_URL="http://127.0.0.1:20324"
|
||||
CONFIG_PATH="${REPO_ROOT}/maintnotifications/e2e/infra/cae-client-testing/endpoints.json"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Logging functions
|
||||
log_info() {
|
||||
echo -e "${BLUE}[INFO]${NC} $1" >&2
|
||||
}
|
||||
|
||||
log_success() {
|
||||
echo -e "${GREEN}[SUCCESS]${NC} $1" >&2
|
||||
}
|
||||
|
||||
log_warning() {
|
||||
echo -e "${YELLOW}[WARNING]${NC} $1" >&2
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1" >&2
|
||||
}
|
||||
|
||||
# Help function
|
||||
show_help() {
|
||||
cat << EOF
|
||||
Maintenance Notifications E2E Tests Runner
|
||||
|
||||
Usage: $0 [OPTIONS]
|
||||
|
||||
OPTIONS:
|
||||
-h, --help Show this help message
|
||||
-v, --verbose Enable verbose test output
|
||||
-t, --timeout DURATION Test timeout (default: 30m)
|
||||
-r, --run PATTERN Run only tests matching pattern
|
||||
--dry-run Show what would be executed without running
|
||||
--list List available tests
|
||||
--config PATH Override config path (default: infra/cae-client-testing/endpoints.json)
|
||||
--fault-injector URL Override fault injector URL (default: http://127.0.0.1:20324)
|
||||
|
||||
EXAMPLES:
|
||||
$0 # Run all E2E tests
|
||||
$0 -v # Run with verbose output
|
||||
$0 -r TestPushNotifications # Run only push notification tests
|
||||
$0 -t 45m # Run with 45 minute timeout
|
||||
$0 --dry-run # Show what would be executed
|
||||
$0 --list # List available tests
|
||||
|
||||
ENVIRONMENT:
|
||||
The script automatically sets up the required environment variables:
|
||||
- REDIS_ENDPOINTS_CONFIG_PATH: Path to Redis endpoints configuration
|
||||
- FAULT_INJECTION_API_URL: URL of the fault injector server
|
||||
- E2E_SCENARIO_TESTS: Enables scenario tests
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse command line arguments
|
||||
VERBOSE=""
|
||||
TIMEOUT="30m"
|
||||
RUN_PATTERN=""
|
||||
DRY_RUN=false
|
||||
LIST_TESTS=false
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
-v|--verbose)
|
||||
VERBOSE="-v"
|
||||
shift
|
||||
;;
|
||||
-t|--timeout)
|
||||
TIMEOUT="$2"
|
||||
shift 2
|
||||
;;
|
||||
-r|--run)
|
||||
RUN_PATTERN="$2"
|
||||
shift 2
|
||||
;;
|
||||
--dry-run)
|
||||
DRY_RUN=true
|
||||
shift
|
||||
;;
|
||||
--list)
|
||||
LIST_TESTS=true
|
||||
shift
|
||||
;;
|
||||
--config)
|
||||
CONFIG_PATH="$2"
|
||||
shift 2
|
||||
;;
|
||||
--fault-injector)
|
||||
FAULT_INJECTOR_URL="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
log_error "Unknown option: $1"
|
||||
show_help
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate configuration file exists
|
||||
if [[ ! -f "$CONFIG_PATH" ]]; then
|
||||
log_error "Configuration file not found: $CONFIG_PATH"
|
||||
log_info "Please ensure the endpoints.json file exists at the specified path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set up environment variables
|
||||
export REDIS_ENDPOINTS_CONFIG_PATH="$CONFIG_PATH"
|
||||
export FAULT_INJECTION_API_URL="$FAULT_INJECTOR_URL"
|
||||
export E2E_SCENARIO_TESTS="true"
|
||||
|
||||
# Build test command
|
||||
TEST_CMD="go test -json -tags=e2e"
|
||||
|
||||
if [[ -n "$TIMEOUT" ]]; then
|
||||
TEST_CMD="$TEST_CMD -timeout=$TIMEOUT"
|
||||
fi
|
||||
|
||||
# Note: -v flag is not compatible with -json output format
|
||||
# The -json format already provides verbose test information
|
||||
|
||||
if [[ -n "$RUN_PATTERN" ]]; then
|
||||
TEST_CMD="$TEST_CMD -run $RUN_PATTERN"
|
||||
fi
|
||||
|
||||
TEST_CMD="$TEST_CMD ./maintnotifications/e2e/ "
|
||||
|
||||
# List tests if requested
|
||||
if [[ "$LIST_TESTS" == true ]]; then
|
||||
log_info "Available E2E tests:"
|
||||
cd "$REPO_ROOT"
|
||||
go test -tags=e2e ./maintnotifications/e2e/ -list=. | grep -E "^Test" | sort
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Show configuration
|
||||
log_info "Maintenance notifications E2E Tests Configuration:"
|
||||
echo " Repository Root: $REPO_ROOT" >&2
|
||||
echo " E2E Directory: $E2E_DIR" >&2
|
||||
echo " Config Path: $CONFIG_PATH" >&2
|
||||
echo " Fault Injector URL: $FAULT_INJECTOR_URL" >&2
|
||||
echo " Test Timeout: $TIMEOUT" >&2
|
||||
if [[ -n "$RUN_PATTERN" ]]; then
|
||||
echo " Test Pattern: $RUN_PATTERN" >&2
|
||||
fi
|
||||
echo "" >&2
|
||||
|
||||
# Validate fault injector connectivity
|
||||
log_info "Checking fault injector connectivity..."
|
||||
if command -v curl >/dev/null 2>&1; then
|
||||
if curl -s --connect-timeout 5 "$FAULT_INJECTOR_URL/health" >/dev/null 2>&1; then
|
||||
log_success "Fault injector is accessible at $FAULT_INJECTOR_URL"
|
||||
else
|
||||
log_warning "Cannot connect to fault injector at $FAULT_INJECTOR_URL"
|
||||
log_warning "Tests may fail if fault injection is required"
|
||||
fi
|
||||
else
|
||||
log_warning "curl not available, skipping fault injector connectivity check"
|
||||
fi
|
||||
|
||||
# Show what would be executed in dry-run mode
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
log_info "Dry run mode - would execute:"
|
||||
echo " cd $REPO_ROOT" >&2
|
||||
echo " export REDIS_ENDPOINTS_CONFIG_PATH=\"$CONFIG_PATH\"" >&2
|
||||
echo " export FAULT_INJECTION_API_URL=\"$FAULT_INJECTOR_URL\"" >&2
|
||||
echo " export E2E_SCENARIO_TESTS=\"true\"" >&2
|
||||
echo " $TEST_CMD" >&2
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Change to repository root
|
||||
cd "$REPO_ROOT"
|
||||
|
||||
# Run the tests
|
||||
log_info "Starting E2E tests..."
|
||||
log_info "Command: $TEST_CMD"
|
||||
echo "" >&2
|
||||
|
||||
if eval "$TEST_CMD"; then
|
||||
echo "" >&2
|
||||
log_success "All E2E tests completed successfully!"
|
||||
exit 0
|
||||
else
|
||||
echo "" >&2
|
||||
log_error "E2E tests failed!"
|
||||
log_info "Check the test output above for details"
|
||||
exit 1
|
||||
fi
|
||||
76
maintnotifications/e2e/utils_test.go
Normal file
76
maintnotifications/e2e/utils_test.go
Normal file
@@ -0,0 +1,76 @@
package e2e

import (
    "fmt"
    "path/filepath"
    "runtime"
    "time"
)

func isTimeout(errMsg string) bool {
    return contains(errMsg, "i/o timeout") ||
        contains(errMsg, "deadline exceeded") ||
        contains(errMsg, "context deadline exceeded")
}

// isTimeoutError checks if an error is a timeout error
func isTimeoutError(err error) bool {
    if err == nil {
        return false
    }

    // Check for various timeout error types
    errStr := err.Error()
    return isTimeout(errStr)
}

// contains reports whether s contains substr (case-sensitive prefix, suffix, or substring match)
func contains(s, substr string) bool {
    return len(s) >= len(substr) &&
        (s == substr ||
            (len(s) > len(substr) &&
                (s[:len(substr)] == substr ||
                    s[len(s)-len(substr):] == substr ||
                    containsSubstring(s, substr))))
}

func containsSubstring(s, substr string) bool {
    for i := 0; i <= len(s)-len(substr); i++ {
        if s[i:i+len(substr)] == substr {
            return true
        }
    }
    return false
}

func min(a, b int) int {
    if a < b {
        return a
    }
    return b
}

func printLog(group string, isError bool, format string, args ...interface{}) {
    _, filename, line, _ := runtime.Caller(2)
    filename = filepath.Base(filename)
    finalFormat := "%s:%d [%s][%s] " + format + "\n"
    if isError {
        finalFormat = "%s:%d [%s][%s][ERROR] " + format + "\n"
    }
    ts := time.Now().Format("15:04:05.000")
    args = append([]interface{}{filename, line, ts, group}, args...)
    fmt.Printf(finalFormat, args...)
}

func actionOutputIfFailed(status *ActionStatusResponse) string {
    if status.Status != StatusFailed {
        return ""
    }
    if status.Error != nil {
        return fmt.Sprintf("%v", status.Error)
    }
    if status.Output == nil {
        return ""
    }
    return fmt.Sprintf("%+v", status.Output)
}
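As an aside, the hand-rolled contains/containsSubstring helpers above match what strings.Contains already provides (prefix and suffix matches are substrings too). An equivalent of isTimeout using the standard library would look like this; illustration only, not part of the change, and it assumes import "strings":

    func isTimeoutStd(errMsg string) bool {
        return strings.Contains(errMsg, "i/o timeout") ||
            strings.Contains(errMsg, "deadline exceeded") ||
            strings.Contains(errMsg, "context deadline exceeded")
    }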
63
maintnotifications/errors.go
Normal file
63
maintnotifications/errors.go
Normal file
@@ -0,0 +1,63 @@
package maintnotifications

import (
    "errors"

    "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
)

// Configuration errors
var (
    ErrInvalidRelaxedTimeout             = errors.New(logs.InvalidRelaxedTimeoutError())
    ErrInvalidHandoffTimeout             = errors.New(logs.InvalidHandoffTimeoutError())
    ErrInvalidHandoffWorkers             = errors.New(logs.InvalidHandoffWorkersError())
    ErrInvalidHandoffQueueSize           = errors.New(logs.InvalidHandoffQueueSizeError())
    ErrInvalidPostHandoffRelaxedDuration = errors.New(logs.InvalidPostHandoffRelaxedDurationError())
    ErrInvalidEndpointType               = errors.New(logs.InvalidEndpointTypeError())
    ErrInvalidMaintNotifications         = errors.New(logs.InvalidMaintNotificationsError())
    ErrMaxHandoffRetriesReached          = errors.New(logs.MaxHandoffRetriesReachedError())

    // Configuration validation errors
    ErrInvalidHandoffRetries = errors.New(logs.InvalidHandoffRetriesError())
)

// Integration errors
var (
    ErrInvalidClient = errors.New(logs.InvalidClientError())
)

// Handoff errors
var (
    ErrHandoffQueueFull = errors.New(logs.HandoffQueueFullError())
)

// Notification errors
var (
    ErrInvalidNotification = errors.New(logs.InvalidNotificationError())
)

// connection handoff errors
var (
    // ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
    // and should not be used until the handoff is complete
    ErrConnectionMarkedForHandoff = errors.New("" + logs.ConnectionMarkedForHandoffErrorMessage)
    // ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
    ErrConnectionInvalidHandoffState = errors.New("" + logs.ConnectionInvalidHandoffStateErrorMessage)
)

// general errors
var (
    ErrShutdown = errors.New(logs.ShutdownError())
)

// circuit breaker errors
var (
    ErrCircuitBreakerOpen = errors.New("" + logs.CircuitBreakerOpenErrorMessage)
)

// circuit breaker configuration errors
var (
    ErrInvalidCircuitBreakerFailureThreshold = errors.New(logs.InvalidCircuitBreakerFailureThresholdError())
    ErrInvalidCircuitBreakerResetTimeout     = errors.New(logs.InvalidCircuitBreakerResetTimeoutError())
    ErrInvalidCircuitBreakerMaxRequests      = errors.New(logs.InvalidCircuitBreakerMaxRequestsError())
)
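Callers can match these sentinel errors with errors.Is. A minimal sketch; doHandoff is a hypothetical placeholder for whatever operation surfaces these errors:

    if err := doHandoff(); err != nil {
        switch {
        case errors.Is(err, ErrHandoffQueueFull):
            // handoff queue is saturated; caller may back off or drop the request
        case errors.Is(err, ErrConnectionMarkedForHandoff):
            // connection must not be used until its handoff completes
        case errors.Is(err, ErrCircuitBreakerOpen):
            // endpoint is temporarily short-circuited
        }
    }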
101
maintnotifications/example_hooks.go
Normal file
101
maintnotifications/example_hooks.go
Normal file
@@ -0,0 +1,101 @@
package maintnotifications

import (
	"context"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/push"
)

// contextKey is a custom type for context keys to avoid collisions
type contextKey string

const (
	startTimeKey contextKey = "maint_notif_start_time"
)

// MetricsHook collects metrics about notification processing.
type MetricsHook struct {
	NotificationCounts map[string]int64
	ProcessingTimes    map[string]time.Duration
	ErrorCounts        map[string]int64
	HandoffCounts      int64 // Total handoffs initiated
	HandoffSuccesses   int64 // Successful handoffs
	HandoffFailures    int64 // Failed handoffs
}

// NewMetricsHook creates a new metrics collection hook.
func NewMetricsHook() *MetricsHook {
	return &MetricsHook{
		NotificationCounts: make(map[string]int64),
		ProcessingTimes:    make(map[string]time.Duration),
		ErrorCounts:        make(map[string]int64),
	}
}

// PreHook records the start time for processing metrics.
func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
	mh.NotificationCounts[notificationType]++

	// Log connection information if available
	if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
		internal.Logger.Printf(ctx, logs.MetricsHookProcessingNotification(notificationType, conn.GetID()))
	}

	// Store start time in context for duration calculation
	startTime := time.Now()
	_ = context.WithValue(ctx, startTimeKey, startTime) // Context not used further

	return notification, true
}

// PostHook records processing completion and any errors.
func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
	// Calculate processing duration
	if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
		duration := time.Since(startTime)
		mh.ProcessingTimes[notificationType] = duration
	}

	// Record errors
	if result != nil {
		mh.ErrorCounts[notificationType]++

		// Log error details with connection information
		if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
			internal.Logger.Printf(ctx, logs.MetricsHookRecordedError(notificationType, conn.GetID(), result))
		}
	}
}

// GetMetrics returns a summary of collected metrics.
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
	return map[string]interface{}{
		"notification_counts": mh.NotificationCounts,
		"processing_times":    mh.ProcessingTimes,
		"error_counts":        mh.ErrorCounts,
	}
}

// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status
func ExampleCircuitBreakerMonitor(poolHook *PoolHook) {
	// Get circuit breaker statistics
	stats := poolHook.GetCircuitBreakerStats()

	for _, stat := range stats {
		fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint)
		fmt.Printf("  State: %s\n", stat.State)
		fmt.Printf("  Failures: %d\n", stat.Failures)
		fmt.Printf("  Last Failure: %v\n", stat.LastFailureTime)
		fmt.Printf("  Last Success: %v\n", stat.LastSuccessTime)

		// Alert if circuit breaker is open
		if stat.State.String() == "open" {
			fmt.Printf("  ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint)
		}
	}
}
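Editor's note: the MetricsHook above only becomes useful once it is registered with a Manager. A sketch of that wiring follows, assuming the package import path github.com/redis/go-redis/v9/maintnotifications and a Manager obtained elsewhere (for example via NewManager); the function names here are illustrative, not part of this diff.

package example

import (
	"fmt"

	"github.com/redis/go-redis/v9/maintnotifications"
)

// registerMetrics attaches the example MetricsHook to an existing Manager.
func registerMetrics(manager *maintnotifications.Manager) *maintnotifications.MetricsHook {
	hook := maintnotifications.NewMetricsHook()
	manager.AddNotificationHook(hook)
	return hook
}

// reportMetrics prints the collected counters, e.g. from a debug endpoint.
func reportMetrics(hook *maintnotifications.MetricsHook) {
	if counts, ok := hook.GetMetrics()["notification_counts"].(map[string]int64); ok {
		for notificationType, count := range counts {
			fmt.Printf("%s: %d\n", notificationType, count)
		}
	}
}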
480
maintnotifications/handoff_worker.go
Normal file
480
maintnotifications/handoff_worker.go
Normal file
@@ -0,0 +1,480 @@
package maintnotifications

import (
	"context"
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
	"github.com/redis/go-redis/v9/internal/pool"
)

// handoffWorkerManager manages background workers and queue for connection handoffs
type handoffWorkerManager struct {
	// Event-driven handoff support
	handoffQueue chan HandoffRequest // Queue for handoff requests
	shutdown     chan struct{}       // Shutdown signal
	shutdownOnce sync.Once           // Ensure clean shutdown
	workerWg     sync.WaitGroup      // Track worker goroutines

	// On-demand worker management
	maxWorkers     int
	activeWorkers  atomic.Int32
	workerTimeout  time.Duration // How long workers wait for work before exiting
	workersScaling atomic.Bool

	// Simple state tracking
	pending sync.Map // map[uint64]int64 (connID -> seqID)

	// Configuration for the maintenance notifications
	config *Config

	// Pool hook reference for handoff processing
	poolHook *PoolHook

	// Circuit breaker manager for endpoint failure handling
	circuitBreakerManager *CircuitBreakerManager
}

// newHandoffWorkerManager creates a new handoff worker manager
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
	return &handoffWorkerManager{
		handoffQueue:          make(chan HandoffRequest, config.HandoffQueueSize),
		shutdown:              make(chan struct{}),
		maxWorkers:            config.MaxWorkers,
		activeWorkers:         atomic.Int32{},   // Start with no workers - create on demand
		workerTimeout:         15 * time.Second, // Workers exit after 15s of inactivity
		config:                config,
		poolHook:              poolHook,
		circuitBreakerManager: newCircuitBreakerManager(config),
	}
}

// getCurrentWorkers returns the current number of active workers (for testing)
func (hwm *handoffWorkerManager) getCurrentWorkers() int {
	return int(hwm.activeWorkers.Load())
}

// getPendingMap returns the pending map for testing purposes
func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
	return &hwm.pending
}

// getMaxWorkers returns the max workers for testing purposes
func (hwm *handoffWorkerManager) getMaxWorkers() int {
	return hwm.maxWorkers
}

// getHandoffQueue returns the handoff queue for testing purposes
func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
	return hwm.handoffQueue
}

// getCircuitBreakerStats returns circuit breaker statistics for monitoring
func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats {
	return hwm.circuitBreakerManager.GetAllStats()
}

// resetCircuitBreakers resets all circuit breakers (useful for testing)
func (hwm *handoffWorkerManager) resetCircuitBreakers() {
	hwm.circuitBreakerManager.Reset()
}

// isHandoffPending returns true if the given connection has a pending handoff
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
	_, pending := hwm.pending.Load(conn.GetID())
	return pending
}

// ensureWorkerAvailable ensures at least one worker is available to process requests
// Creates a new worker if needed and under the max limit
func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
	select {
	case <-hwm.shutdown:
		return
	default:
		if hwm.workersScaling.CompareAndSwap(false, true) {
			defer hwm.workersScaling.Store(false)
			// Check if we need a new worker
			currentWorkers := hwm.activeWorkers.Load()
			workersWas := currentWorkers
			for currentWorkers < int32(hwm.maxWorkers) {
				hwm.workerWg.Add(1)
				go hwm.onDemandWorker()
				currentWorkers++
			}
			// workersWas is always <= currentWorkers.
			// currentWorkers will be maxWorkers, but if a worker was closed
			// while we were creating new workers, just add the difference between
			// currentWorkers and the count we observed initially (i.e. the number of workers we created).
			hwm.activeWorkers.Add(currentWorkers - workersWas)
		}
	}
}

// onDemandWorker processes handoff requests and exits when idle
func (hwm *handoffWorkerManager) onDemandWorker() {
	defer func() {
		// Handle panics to ensure proper cleanup
		if r := recover(); r != nil {
			internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r))
		}

		// Decrement active worker count when exiting
		hwm.activeWorkers.Add(-1)
		hwm.workerWg.Done()
	}()

	// Create reusable timer to prevent timer leaks
	timer := time.NewTimer(hwm.workerTimeout)
	defer timer.Stop()

	for {
		// Reset timer for next iteration
		if !timer.Stop() {
			select {
			case <-timer.C:
			default:
			}
		}
		timer.Reset(hwm.workerTimeout)

		select {
		case <-hwm.shutdown:
			if internal.LogLevel.InfoOrAbove() {
				internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown())
			}
			return
		case <-timer.C:
			// Worker has been idle for too long, exit to save resources
			if internal.LogLevel.InfoOrAbove() {
				internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout))
			}
			return
		case request := <-hwm.handoffQueue:
			// Check for shutdown before processing
			select {
			case <-hwm.shutdown:
				if internal.LogLevel.InfoOrAbove() {
					internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing())
				}
				// Clean up the request before exiting
				hwm.pending.Delete(request.ConnID)
				return
			default:
				// Process the request
				hwm.processHandoffRequest(request)
			}
		}
	}
}

// processHandoffRequest processes a single handoff request
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
	// Remove from pending map
	defer hwm.pending.Delete(request.Conn.GetID())
	if internal.LogLevel.InfoOrAbove() {
		internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint))
	}

	// Create a context with handoff timeout from config
	handoffTimeout := 15 * time.Second // Default timeout
	if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
		handoffTimeout = hwm.config.HandoffTimeout
	}
	ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
	defer cancel()

	// Create a context that also respects the shutdown signal
	shutdownCtx, shutdownCancel := context.WithCancel(ctx)
	defer shutdownCancel()

	// Monitor shutdown signal in a separate goroutine
	go func() {
		select {
		case <-hwm.shutdown:
			shutdownCancel()
		case <-shutdownCtx.Done():
		}
	}()

	// Perform the handoff with cancellable context
	shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
	minRetryBackoff := 500 * time.Millisecond
	if err != nil {
		if shouldRetry {
			now := time.Now()
			deadline, ok := shutdownCtx.Deadline()
			thirdOfTimeout := handoffTimeout / 3
			if !ok || deadline.Before(now) {
				// wait a third of the timeout before retrying if there is no deadline or it has already passed
				deadline = now.Add(thirdOfTimeout)
			}
			afterTime := deadline.Sub(now)
			if afterTime < minRetryBackoff {
				afterTime = minRetryBackoff
			}

			if internal.LogLevel.InfoOrAbove() {
				// Get current retry count for better logging
				currentRetries := request.Conn.HandoffRetries()
				maxRetries := 3 // Default fallback
				if hwm.config != nil {
					maxRetries = hwm.config.MaxHandoffRetries
				}
				internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err))
			}
			time.AfterFunc(afterTime, func() {
				if err := hwm.queueHandoff(request.Conn); err != nil {
					if internal.LogLevel.WarnOrAbove() {
						internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err))
					}
					hwm.closeConnFromRequest(context.Background(), request, err)
				}
			})
			return
		} else {
			go hwm.closeConnFromRequest(ctx, request, err)
		}

		// Clear handoff state if not returned for retry
		seqID := request.Conn.GetMovingSeqID()
		connID := request.Conn.GetID()
		if hwm.poolHook.operationsManager != nil {
			hwm.poolHook.operationsManager.UntrackOperationWithConnID(seqID, connID)
		}
	}
}

// queueHandoff queues a handoff request for processing.
// If an error is returned, the connection will be removed from the pool.
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
	// Get handoff info atomically to prevent race conditions
	shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
	// on retries the connection will not be marked for handoff, but it will have retries > 0;
	// if shouldHandoff is false and retries is 0, we are neither retrying nor performing a handoff
	if !shouldHandoff && conn.HandoffRetries() == 0 {
		if internal.LogLevel.InfoOrAbove() {
			internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID()))
		}
		return errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID()))
	}

	// Create handoff request with atomically retrieved data
	request := HandoffRequest{
		Conn:     conn,
		ConnID:   conn.GetID(),
		Endpoint: endpoint,
		SeqID:    seqID,
		Pool:     hwm.poolHook.pool, // Include pool for connection removal on failure
	}

	select {
	// priority to shutdown
	case <-hwm.shutdown:
		return ErrShutdown
	default:
		select {
		case <-hwm.shutdown:
			return ErrShutdown
		case hwm.handoffQueue <- request:
			// Store in pending map
			hwm.pending.Store(request.ConnID, request.SeqID)
			// Ensure we have a worker to process this request
			hwm.ensureWorkerAvailable()
			return nil
		default:
			select {
			case <-hwm.shutdown:
				return ErrShutdown
			case hwm.handoffQueue <- request:
				// Store in pending map
				hwm.pending.Store(request.ConnID, request.SeqID)
				// Ensure we have a worker to process this request
				hwm.ensureWorkerAvailable()
				return nil
			case <-time.After(100 * time.Millisecond): // give workers a chance to process
				// Queue is full - log and attempt scaling
				queueLen := len(hwm.handoffQueue)
				queueCap := cap(hwm.handoffQueue)
				if internal.LogLevel.WarnOrAbove() {
					internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap))
				}
			}
		}
	}

	// Ensure we have workers available to handle the load
	hwm.ensureWorkerAvailable()
	return ErrHandoffQueueFull
}

// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete
func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
	hwm.shutdownOnce.Do(func() {
		close(hwm.shutdown)
		// workers will exit when they finish their current request

		// Shutdown circuit breaker manager cleanup goroutine
		if hwm.circuitBreakerManager != nil {
			hwm.circuitBreakerManager.Shutdown()
		}
	})

	// Wait for workers to complete
	done := make(chan struct{})
	go func() {
		hwm.workerWg.Wait()
		close(done)
	}()

	select {
	case <-done:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// performConnectionHandoff performs the actual connection handoff.
// When an error is returned, the handoff should be retried unless err is ErrMaxHandoffRetriesReached.
func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
	// Clear handoff state after successful handoff
	connID := conn.GetID()

	newEndpoint := conn.GetHandoffEndpoint()
	if newEndpoint == "" {
		return false, ErrConnectionInvalidHandoffState
	}

	// Use circuit breaker to protect against failing endpoints
	circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint)

	// Check if circuit breaker is open before attempting handoff
	if circuitBreaker.IsOpen() {
		internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint))
		return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
	}

	// Perform the handoff
	shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID)

	// Update circuit breaker based on result
	if err != nil {
		// Only track dial/network errors in circuit breaker, not initialization errors
		if shouldRetry {
			circuitBreaker.recordFailure()
		}
		return shouldRetry, err
	}

	// Success - record in circuit breaker
	circuitBreaker.recordSuccess()
	return false, nil
}

// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) {
	retries := conn.IncrementAndGetHandoffRetries(1)
	internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String()))
	maxRetries := 3 // Default fallback
	if hwm.config != nil {
		maxRetries = hwm.config.MaxHandoffRetries
	}

	if retries > maxRetries {
		if internal.LogLevel.WarnOrAbove() {
			internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries))
		}
		// won't retry on ErrMaxHandoffRetriesReached
		return false, ErrMaxHandoffRetriesReached
	}

	// Create endpoint-specific dialer
	endpointDialer := hwm.createEndpointDialer(newEndpoint)

	// Create new connection to the new endpoint
	newNetConn, err := endpointDialer(ctx)
	if err != nil {
		internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err))
		// will retry
		// Maybe a network error - retry after a delay
		return true, err
	}

	// Get the old connection
	oldConn := conn.GetNetConn()

	// Apply relaxed timeout to the new connection for the configured post-handoff duration.
	// This gives the new connection more time to handle operations during cluster transition.
	// Setting this here (before initializing the connection) ensures that the connection is going
	// to use the relaxed timeout for the first operation (auth/ACL select).
	if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 {
		relaxedTimeout := hwm.config.RelaxedTimeout
		// Set relaxed timeout with deadline - no background goroutine needed
		deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
		conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)

		if internal.LogLevel.InfoOrAbove() {
			internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000")))
		}
	}

	// Replace the connection and execute initialization
	err = conn.SetNetConnAndInitConn(ctx, newNetConn)
	if err != nil {
		// won't retry
		// Initialization failed - remove the connection
		return false, err
	}
	defer func() {
		if oldConn != nil {
			oldConn.Close()
		}
	}()

	conn.ClearHandoffState()
	internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint))

	return false, nil
}

// createEndpointDialer creates a dialer function that connects to a specific endpoint
func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
	return func(ctx context.Context) (net.Conn, error) {
		// Parse endpoint to extract host and port
		host, port, err := net.SplitHostPort(endpoint)
		if err != nil {
			// If no port specified, assume default Redis port
			host = endpoint
			if port == "" {
				port = "6379"
			}
		}

		// Use the base dialer to connect to the new endpoint
		return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port))
	}
}

// closeConnFromRequest closes the connection and logs the reason
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
	pooler := request.Pool
	conn := request.Conn
	if pooler != nil {
		pooler.Remove(ctx, conn, err)
		if internal.LogLevel.WarnOrAbove() {
			internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
		}
	} else {
		conn.Close()
		if internal.LogLevel.WarnOrAbove() {
			internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err))
		}
	}
}
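Editor's note: queueHandoff above layers three selects so that shutdown always wins, a non-blocking send is tried first, and only then does the caller wait a bounded 100 ms before reporting a full queue. The standalone sketch below restates that pattern in isolation with generic types; errShutdown, errQueueFull and enqueue are illustrative names, not part of this diff.

package example

import (
	"errors"
	"time"
)

var (
	errShutdown  = errors.New("shutdown")
	errQueueFull = errors.New("queue full")
)

// enqueue mirrors the queueHandoff pattern: shutdown has priority, then a
// non-blocking send, then a bounded wait before giving up.
func enqueue[T any](queue chan T, shutdown <-chan struct{}, item T, wait time.Duration) error {
	select {
	case <-shutdown: // shutdown always wins
		return errShutdown
	default:
	}

	select {
	case <-shutdown:
		return errShutdown
	case queue <- item: // fast path: queue has room
		return nil
	default:
	}

	select {
	case <-shutdown:
		return errShutdown
	case queue <- item: // slow path: give consumers a chance to drain
		return nil
	case <-time.After(wait):
		return errQueueFull
	}
}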
60
maintnotifications/hooks.go
Normal file
60
maintnotifications/hooks.go
Normal file
@@ -0,0 +1,60 @@
package maintnotifications

import (
	"context"
	"slices"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/push"
)

// LoggingHook is an example hook implementation that logs all notifications.
type LoggingHook struct {
	LogLevel int // 0=Error, 1=Warn, 2=Info, 3=Debug
}

// PreHook logs the notification before processing and allows modification.
func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
	if lh.LogLevel >= 2 { // Info level
		// Log the notification type and content
		connID := uint64(0)
		if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
			connID = conn.GetID()
		}
		seqID := int64(0)
		if slices.Contains(maintenanceNotificationTypes, notificationType) {
			// seqID is the second element in the notification array
			if len(notification) > 1 {
				if parsedSeqID, ok := notification[1].(int64); !ok {
					seqID = 0
				} else {
					seqID = parsedSeqID
				}
			}
		}
		internal.Logger.Printf(ctx, logs.ProcessingNotification(connID, seqID, notificationType, notification))
	}
	return notification, true // Continue processing with unmodified notification
}

// PostHook logs the result after processing.
func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
	connID := uint64(0)
	if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
		connID = conn.GetID()
	}
	if result != nil && lh.LogLevel >= 1 { // Warning level
		internal.Logger.Printf(ctx, logs.ProcessingNotificationFailed(connID, notificationType, result, notification))
	} else if lh.LogLevel >= 3 { // Debug level
		internal.Logger.Printf(ctx, logs.ProcessingNotificationSucceeded(connID, notificationType))
	}
}

// NewLoggingHook creates a new logging hook with the specified log level.
// Log levels: 0=Error, 1=Warn, 2=Info, 3=Debug
func NewLoggingHook(logLevel int) *LoggingHook {
	return &LoggingHook{LogLevel: logLevel}
}
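Editor's note: LoggingHook satisfies the NotificationHook interface declared in manager.go below, so enabling it is a one-liner against a Manager. A minimal sketch, assuming the import path github.com/redis/go-redis/v9/maintnotifications and a Manager obtained elsewhere; enableNotificationLogging is an illustrative name.

package example

import "github.com/redis/go-redis/v9/maintnotifications"

// enableNotificationLogging attaches a debug-level LoggingHook to a Manager.
func enableNotificationLogging(manager *maintnotifications.Manager) {
	// 0=Error, 1=Warn, 2=Info, 3=Debug, per the LoggingHook documentation above.
	manager.AddNotificationHook(maintnotifications.NewLoggingHook(3))
}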
320
maintnotifications/manager.go
Normal file
320
maintnotifications/manager.go
Normal file
@@ -0,0 +1,320 @@
package maintnotifications

import (
	"context"
	"errors"
	"fmt"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/interfaces"
	"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/push"
)

// Push notification type constants for maintenance
const (
	NotificationMoving      = "MOVING"
	NotificationMigrating   = "MIGRATING"
	NotificationMigrated    = "MIGRATED"
	NotificationFailingOver = "FAILING_OVER"
	NotificationFailedOver  = "FAILED_OVER"
)

// maintenanceNotificationTypes contains all notification types that maintenance handles
var maintenanceNotificationTypes = []string{
	NotificationMoving,
	NotificationMigrating,
	NotificationMigrated,
	NotificationFailingOver,
	NotificationFailedOver,
}

// NotificationHook is called before and after notification processing.
// PreHook can modify the notification and return false to skip processing.
// PostHook is called after successful processing.
type NotificationHook interface {
	PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool)
	PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error)
}

// MovingOperationKey provides a unique key for tracking MOVING operations
// that combines sequence ID with connection identifier to handle duplicate
// sequence IDs across multiple connections to the same node.
type MovingOperationKey struct {
	SeqID  int64  // Sequence ID from MOVING notification
	ConnID uint64 // Unique connection identifier
}

// String returns a string representation of the key for debugging
func (k MovingOperationKey) String() string {
	return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
}

// Manager provides a simplified upgrade functionality with hooks and atomic state.
type Manager struct {
	client  interfaces.ClientInterface
	config  *Config
	options interfaces.OptionsInterface
	pool    pool.Pooler

	// MOVING operation tracking - using sync.Map for better concurrent performance
	activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation

	// Atomic state tracking - no locks needed for state queries
	activeOperationCount atomic.Int64 // Number of active operations
	closed               atomic.Bool  // Manager closed state

	// Notification hooks for extensibility
	hooks        []NotificationHook
	hooksMu      sync.RWMutex // Protects hooks slice
	poolHooksRef *PoolHook
}

// MovingOperation tracks an active MOVING operation.
type MovingOperation struct {
	SeqID       int64
	NewEndpoint string
	StartTime   time.Time
	Deadline    time.Time
}

// NewManager creates a new simplified manager.
func NewManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*Manager, error) {
	if client == nil {
		return nil, ErrInvalidClient
	}

	hm := &Manager{
		client:  client,
		pool:    pool,
		options: client.GetOptions(),
		config:  config.Clone(),
		hooks:   make([]NotificationHook, 0),
	}

	// Set up push notification handling
	if err := hm.setupPushNotifications(); err != nil {
		return nil, err
	}

	return hm, nil
}

// InitPoolHook creates a pool hook with a custom dialer and registers it with the pool.
func (hm *Manager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
	poolHook := hm.createPoolHook(baseDialer)
	hm.pool.AddPoolHook(poolHook)
}

// setupPushNotifications sets up push notification handling by registering with the client's processor.
func (hm *Manager) setupPushNotifications() error {
	processor := hm.client.GetPushProcessor()
	if processor == nil {
		return ErrInvalidClient // Client doesn't support push notifications
	}

	// Create our notification handler
	handler := &NotificationHandler{manager: hm, operationsManager: hm}

	// Register handlers for all upgrade notifications with the client's processor
	for _, notificationType := range maintenanceNotificationTypes {
		if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
			return errors.New(logs.FailedToRegisterHandler(notificationType, err))
		}
	}

	return nil
}

// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
	// Create composite key
	key := MovingOperationKey{
		SeqID:  seqID,
		ConnID: connID,
	}

	// Create MOVING operation record
	movingOp := &MovingOperation{
		SeqID:       seqID,
		NewEndpoint: newEndpoint,
		StartTime:   time.Now(),
		Deadline:    deadline,
	}

	// Use LoadOrStore for atomic check-and-set operation
	if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
		// Duplicate MOVING notification, ignore
		if internal.LogLevel.DebugOrAbove() { // Debug level
			internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID))
		}
		return nil
	}
	if internal.LogLevel.DebugOrAbove() { // Debug level
		internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID))
	}

	// Increment active operation count atomically
	hm.activeOperationCount.Add(1)

	return nil
}

// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) {
	// Create composite key
	key := MovingOperationKey{
		SeqID:  seqID,
		ConnID: connID,
	}

	// Remove from active operations atomically
	if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
		if internal.LogLevel.DebugOrAbove() { // Debug level
			internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID))
		}
		// Decrement active operation count only if operation existed
		hm.activeOperationCount.Add(-1)
	} else {
		if internal.LogLevel.DebugOrAbove() { // Debug level
			internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID))
		}
	}
}

// GetActiveMovingOperations returns active operations with composite keys.
// WARNING: This method creates a new map and copies all operations on every call.
// Use sparingly, especially in hot paths or high-frequency logging.
func (hm *Manager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
	result := make(map[MovingOperationKey]*MovingOperation)

	// Iterate over sync.Map to build result
	hm.activeMovingOps.Range(func(key, value interface{}) bool {
		k := key.(MovingOperationKey)
		op := value.(*MovingOperation)

		// Create a copy to avoid sharing references
		result[k] = &MovingOperation{
			SeqID:       op.SeqID,
			NewEndpoint: op.NewEndpoint,
			StartTime:   op.StartTime,
			Deadline:    op.Deadline,
		}
		return true // Continue iteration
	})

	return result
}

// IsHandoffInProgress returns true if any handoff is in progress.
// Uses atomic counter for lock-free operation.
func (hm *Manager) IsHandoffInProgress() bool {
	return hm.activeOperationCount.Load() > 0
}

// GetActiveOperationCount returns the number of active operations.
// Uses atomic counter for lock-free operation.
func (hm *Manager) GetActiveOperationCount() int64 {
	return hm.activeOperationCount.Load()
}

// Close closes the manager.
func (hm *Manager) Close() error {
	// Use atomic operation for thread-safe close check
	if !hm.closed.CompareAndSwap(false, true) {
		return nil // Already closed
	}

	// Shutdown the pool hook if it exists
	if hm.poolHooksRef != nil {
		// Use a timeout to prevent hanging indefinitely
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		err := hm.poolHooksRef.Shutdown(shutdownCtx)
		if err != nil {
			// was not able to close pool hook, keep closed state false
			hm.closed.Store(false)
			return err
		}
		// Remove the pool hook from the pool
		if hm.pool != nil {
			hm.pool.RemovePoolHook(hm.poolHooksRef)
		}
	}

	// Clear all active operations
	hm.activeMovingOps.Range(func(key, value interface{}) bool {
		hm.activeMovingOps.Delete(key)
		return true
	})

	// Reset counter
	hm.activeOperationCount.Store(0)

	return nil
}

// GetState returns current state using atomic counter for lock-free operation.
func (hm *Manager) GetState() State {
	if hm.activeOperationCount.Load() > 0 {
		return StateMoving
	}
	return StateIdle
}

// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
func (hm *Manager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
	hm.hooksMu.RLock()
	defer hm.hooksMu.RUnlock()

	currentNotification := notification

	for _, hook := range hm.hooks {
		modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification)
		if !shouldContinue {
			return modifiedNotification, false
		}
		currentNotification = modifiedNotification
	}

	return currentNotification, true
}

// processPostHooks calls all post-hooks with the processing result.
func (hm *Manager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
	hm.hooksMu.RLock()
	defer hm.hooksMu.RUnlock()

	for _, hook := range hm.hooks {
		hook.PostHook(ctx, notificationCtx, notificationType, notification, result)
	}
}

// createPoolHook creates a pool hook with this manager already set.
func (hm *Manager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
	if hm.poolHooksRef != nil {
		return hm.poolHooksRef
	}
	// Get pool size from client options for better worker defaults
	poolSize := 0
	if hm.options != nil {
		poolSize = hm.options.GetPoolSize()
	}

	hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
	hm.poolHooksRef.SetPool(hm.pool)

	return hm.poolHooksRef
}

func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) {
	hm.hooksMu.Lock()
	defer hm.hooksMu.Unlock()
	hm.hooks = append(hm.hooks, notificationHook)
}
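Editor's note: processPreHooks above short-circuits as soon as a hook's PreHook returns false, which is how a hook can veto handling of a notification. The sketch below shows a hook built around that contract; dropMigratingHook is an illustrative type, not part of this diff, and it assumes the package is importable as github.com/redis/go-redis/v9/maintnotifications.

package example

import (
	"context"

	"github.com/redis/go-redis/v9/maintnotifications"
	"github.com/redis/go-redis/v9/push"
)

// dropMigratingHook vetoes processing of MIGRATING notifications; returning
// false from PreHook stops the hook chain and skips the handler.
type dropMigratingHook struct{}

func (dropMigratingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
	if notificationType == maintnotifications.NotificationMigrating {
		return notification, false // skip further hooks and processing
	}
	return notification, true
}

func (dropMigratingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
	// No post-processing needed for this sketch.
}

// Registration would be: manager.AddNotificationHook(dropMigratingHook{})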
260
maintnotifications/manager_test.go
Normal file
260
maintnotifications/manager_test.go
Normal file
@@ -0,0 +1,260 @@
package maintnotifications

import (
	"context"
	"net"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/internal/interfaces"
)

// MockClient implements interfaces.ClientInterface for testing
type MockClient struct {
	options interfaces.OptionsInterface
}

func (mc *MockClient) GetOptions() interfaces.OptionsInterface {
	return mc.options
}

func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor {
	return &MockPushProcessor{}
}

// MockPushProcessor implements interfaces.NotificationProcessor for testing
type MockPushProcessor struct{}

func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error {
	return nil
}

func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error {
	return nil
}

func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} {
	return nil
}

// MockOptions implements interfaces.OptionsInterface for testing
type MockOptions struct{}

func (mo *MockOptions) GetReadTimeout() time.Duration {
	return 5 * time.Second
}

func (mo *MockOptions) GetWriteTimeout() time.Duration {
	return 5 * time.Second
}

func (mo *MockOptions) GetAddr() string {
	return "localhost:6379"
}

func (mo *MockOptions) IsTLSEnabled() bool {
	return false
}

func (mo *MockOptions) GetProtocol() int {
	return 3 // RESP3
}

func (mo *MockOptions) GetPoolSize() int {
	return 10
}

func (mo *MockOptions) GetNetwork() string {
	return "tcp"
}

func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) {
	return func(ctx context.Context) (net.Conn, error) {
		return nil, nil
	}
}

func TestManagerRefactoring(t *testing.T) {
	t.Run("AtomicStateTracking", func(t *testing.T) {
		config := DefaultConfig()
		client := &MockClient{options: &MockOptions{}}

		manager, err := NewManager(client, nil, config)
		if err != nil {
			t.Fatalf("Failed to create maintnotifications manager: %v", err)
		}
		defer manager.Close()

		// Test initial state
		if manager.IsHandoffInProgress() {
			t.Error("Expected no handoff in progress initially")
		}

		if manager.GetActiveOperationCount() != 0 {
			t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
		}

		if manager.GetState() != StateIdle {
			t.Errorf("Expected StateIdle, got %v", manager.GetState())
		}

		// Add an operation
		ctx := context.Background()
		deadline := time.Now().Add(30 * time.Second)
		err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1)
		if err != nil {
			t.Fatalf("Failed to track operation: %v", err)
		}

		// Test state after adding operation
		if !manager.IsHandoffInProgress() {
			t.Error("Expected handoff in progress after adding operation")
		}

		if manager.GetActiveOperationCount() != 1 {
			t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount())
		}

		if manager.GetState() != StateMoving {
			t.Errorf("Expected StateMoving, got %v", manager.GetState())
		}

		// Remove the operation
		manager.UntrackOperationWithConnID(12345, 1)

		// Test state after removing operation
		if manager.IsHandoffInProgress() {
			t.Error("Expected no handoff in progress after removing operation")
		}

		if manager.GetActiveOperationCount() != 0 {
			t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
		}

		if manager.GetState() != StateIdle {
			t.Errorf("Expected StateIdle, got %v", manager.GetState())
		}
	})

	t.Run("SyncMapPerformance", func(t *testing.T) {
		config := DefaultConfig()
		client := &MockClient{options: &MockOptions{}}

		manager, err := NewManager(client, nil, config)
		if err != nil {
			t.Fatalf("Failed to create maintnotifications manager: %v", err)
		}
		defer manager.Close()

		ctx := context.Background()
		deadline := time.Now().Add(30 * time.Second)

		// Test concurrent operations
		const numOps = 100
		for i := 0; i < numOps; i++ {
			err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i))
			if err != nil {
				t.Fatalf("Failed to track operation %d: %v", i, err)
			}
		}

		if manager.GetActiveOperationCount() != numOps {
			t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount())
		}

		// Test GetActiveMovingOperations
		operations := manager.GetActiveMovingOperations()
		if len(operations) != numOps {
			t.Errorf("Expected %d operations in map, got %d", numOps, len(operations))
		}

		// Remove all operations
		for i := 0; i < numOps; i++ {
			manager.UntrackOperationWithConnID(int64(i), uint64(i))
		}

		if manager.GetActiveOperationCount() != 0 {
			t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount())
		}
	})

	t.Run("DuplicateOperationHandling", func(t *testing.T) {
		config := DefaultConfig()
		client := &MockClient{options: &MockOptions{}}

		manager, err := NewManager(client, nil, config)
		if err != nil {
			t.Fatalf("Failed to create maintnotifications manager: %v", err)
		}
		defer manager.Close()

		ctx := context.Background()
		deadline := time.Now().Add(30 * time.Second)

		// Add operation
		err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
		if err != nil {
			t.Fatalf("Failed to track operation: %v", err)
		}

		// Try to add duplicate operation
		err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
		if err != nil {
			t.Fatalf("Duplicate operation should not return error: %v", err)
		}

		// Should still have only 1 operation
		if manager.GetActiveOperationCount() != 1 {
			t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount())
		}
	})

	t.Run("NotificationTypeConstants", func(t *testing.T) {
		// Test that constants are properly defined
		expectedTypes := []string{
			NotificationMoving,
			NotificationMigrating,
			NotificationMigrated,
			NotificationFailingOver,
			NotificationFailedOver,
		}

		if len(maintenanceNotificationTypes) != len(expectedTypes) {
			t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(maintenanceNotificationTypes))
		}

		// Test that all expected types are present
		typeMap := make(map[string]bool)
		for _, notificationType := range maintenanceNotificationTypes {
			typeMap[notificationType] = true
		}

		for _, expected := range expectedTypes {
			if !typeMap[expected] {
				t.Errorf("Expected notification type %s not found in maintenanceNotificationTypes", expected)
			}
		}

		// Test that maintenanceNotificationTypes contains all expected constants
		expectedConstants := []string{
			NotificationMoving,
			NotificationMigrating,
			NotificationMigrated,
			NotificationFailingOver,
			NotificationFailedOver,
		}

		for _, expected := range expectedConstants {
			found := false
			for _, actual := range maintenanceNotificationTypes {
				if actual == expected {
					found = true
					break
				}
			}
			if !found {
				t.Errorf("Expected constant %s not found in maintenanceNotificationTypes", expected)
			}
		}
	})
}
180
maintnotifications/pool_hook.go
Normal file
180
maintnotifications/pool_hook.go
Normal file
@@ -0,0 +1,180 @@
package maintnotifications

import (
	"context"
	"net"
	"sync"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
	"github.com/redis/go-redis/v9/internal/pool"
)

// OperationsManagerInterface defines the interface for completing handoff operations
type OperationsManagerInterface interface {
	TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
	UntrackOperationWithConnID(seqID int64, connID uint64)
}

// HandoffRequest represents a request to handoff a connection to a new endpoint
type HandoffRequest struct {
	Conn     *pool.Conn
	ConnID   uint64 // Unique connection identifier
	Endpoint string
	SeqID    int64
	Pool     pool.Pooler // Pool to remove connection from on failure
}

// PoolHook implements pool.PoolHook for Redis-specific connection handling
// with maintenance notifications support.
type PoolHook struct {
	// Base dialer for creating connections to new endpoints during handoffs
	// args are network and address
	baseDialer func(context.Context, string, string) (net.Conn, error)

	// Network type (e.g., "tcp", "unix")
	network string

	// Worker manager for background handoff processing
	workerManager *handoffWorkerManager

	// Configuration for the maintenance notifications
	config *Config

	// Operations manager interface for operation completion tracking
	operationsManager OperationsManagerInterface

	// Pool interface for removing connections on handoff failure
	pool pool.Pooler
}

// NewPoolHook creates a new pool hook
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface) *PoolHook {
	return NewPoolHookWithPoolSize(baseDialer, network, config, operationsManager, 0)
}

// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface, poolSize int) *PoolHook {
	// Apply defaults if config is nil or has zero values
	if config == nil {
		config = config.ApplyDefaultsWithPoolSize(poolSize)
	}

	ph := &PoolHook{
		// baseDialer is used to create connections to new endpoints during handoffs
		baseDialer:        baseDialer,
		network:           network,
		config:            config,
		operationsManager: operationsManager,
	}

	// Create worker manager
	ph.workerManager = newHandoffWorkerManager(config, ph)

	return ph
}

// SetPool sets the pool interface for removing connections on handoff failure
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
	ph.pool = pooler
}

// GetCurrentWorkers returns the current number of active workers (for testing)
func (ph *PoolHook) GetCurrentWorkers() int {
	return ph.workerManager.getCurrentWorkers()
}

// IsHandoffPending returns true if the given connection has a pending handoff
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
	return ph.workerManager.isHandoffPending(conn)
}

// GetPendingMap returns the pending map for testing purposes
func (ph *PoolHook) GetPendingMap() *sync.Map {
	return ph.workerManager.getPendingMap()
}

// GetMaxWorkers returns the max workers for testing purposes
func (ph *PoolHook) GetMaxWorkers() int {
	return ph.workerManager.getMaxWorkers()
}

// GetHandoffQueue returns the handoff queue for testing purposes
func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest {
	return ph.workerManager.getHandoffQueue()
}

// GetCircuitBreakerStats returns circuit breaker statistics for monitoring
func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats {
	return ph.workerManager.getCircuitBreakerStats()
}

// ResetCircuitBreakers resets all circuit breakers (useful for testing)
func (ph *PoolHook) ResetCircuitBreakers() {
	ph.workerManager.resetCircuitBreakers()
}

// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error {
	// NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is
	// in a handoff state at the moment.

	// Check if connection is usable (not in a handoff state).
	// Should not happen since the pool will not return a connection that is not usable.
	if !conn.IsUsable() {
		return ErrConnectionMarkedForHandoff
	}

	// Check if connection is marked for handoff, which means it will be queued for handoff on put.
	if conn.ShouldHandoff() {
		return ErrConnectionMarkedForHandoff
	}

	return nil
}

// OnPut is called when a connection is returned to the pool
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
	// first check if we should handoff for faster rejection
	if !conn.ShouldHandoff() {
		// Default behavior (no handoff): pool the connection
		return true, false, nil
	}

	// check pending handoff to not queue the same connection twice
	if ph.workerManager.isHandoffPending(conn) {
		// Default behavior (pending handoff): pool the connection
		return true, false, nil
	}

	if err := ph.workerManager.queueHandoff(conn); err != nil {
		// Failed to queue handoff, remove the connection
		internal.Logger.Printf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err))
		// Don't pool, remove connection, no error to caller
		return false, true, nil
	}

	// Check if handoff was already processed by a worker before we can mark it as queued
	if !conn.ShouldHandoff() {
		// Handoff was already processed - this is normal and the connection should be pooled
		return true, false, nil
	}

	if err := conn.MarkQueuedForHandoff(); err != nil {
		// If marking fails, check if handoff was processed in the meantime
		if !conn.ShouldHandoff() {
			// Handoff was processed - this is normal, pool the connection
			return true, false, nil
		}
		// Other error - remove the connection
		return false, true, nil
	}
	internal.Logger.Printf(ctx, logs.MarkedForHandoff(conn.GetID()))
	return true, false, nil
}

// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
	return ph.workerManager.shutdownWorkers(ctx)
}
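Editor's note: the PoolHook above is wired up through Manager.InitPoolHook, whose base dialer must have the signature func(ctx, network, addr) (net.Conn, error); the handoff worker calls it with the PoolHook's configured network and the MOVING endpoint's host:port. A minimal sketch follows, assuming the import path github.com/redis/go-redis/v9/maintnotifications and a Manager obtained elsewhere; registerPoolHook and the plain net.Dialer stand in for the client's real wiring.

package example

import (
	"context"
	"net"
	"time"

	"github.com/redis/go-redis/v9/maintnotifications"
)

// registerPoolHook builds a base dialer and asks the Manager to create and
// register its PoolHook with the pool it was constructed with.
func registerPoolHook(manager *maintnotifications.Manager) {
	dialer := &net.Dialer{Timeout: 5 * time.Second, KeepAlive: 5 * time.Minute}

	baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
		return dialer.DialContext(ctx, network, addr)
	}

	// The Manager wraps baseDialer in a PoolHook and adds it to its pool.
	manager.InitPoolHook(baseDialer)
}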
954
maintnotifications/pool_hook_test.go
Normal file
954
maintnotifications/pool_hook_test.go
Normal file
@@ -0,0 +1,954 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// mockNetConn implements net.Conn for testing
|
||||
type mockNetConn struct {
|
||||
addr string
|
||||
shouldFailInit bool
|
||||
}
|
||||
|
||||
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
|
||||
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
|
||||
func (m *mockNetConn) Close() error { return nil }
|
||||
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
|
||||
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
|
||||
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
type mockAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (m *mockAddr) Network() string { return "tcp" }
|
||||
func (m *mockAddr) String() string { return m.addr }
|
||||
|
||||
// createMockPoolConnection creates a mock pool connection for testing
|
||||
func createMockPoolConnection() *pool.Conn {
|
||||
mockNetConn := &mockNetConn{addr: "test:6379"}
|
||||
conn := pool.NewConn(mockNetConn)
|
||||
conn.SetUsable(true) // Make connection usable for testing
|
||||
return conn
|
||||
}
|
||||
|
||||
// mockPool implements pool.Pooler for testing
|
||||
type mockPool struct {
|
||||
removedConnections map[uint64]bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (mp *mockPool) CloseConn(conn *pool.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
|
||||
// Not implemented for testing
|
||||
}
|
||||
|
||||
func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Use pool.Conn directly - no adapter needed
|
||||
mp.removedConnections[conn.GetID()] = true
|
||||
}
|
||||
|
||||
// WasRemoved safely checks if a connection was removed from the pool
|
||||
func (mp *mockPool) WasRemoved(connID uint64) bool {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
return mp.removedConnections[connID]
|
||||
}
|
||||
|
||||
func (mp *mockPool) Len() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (mp *mockPool) IdleLen() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (mp *mockPool) Stats() *pool.Stats {
|
||||
return &pool.Stats{}
|
||||
}
|
||||
|
||||
func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
|
||||
// Mock implementation - do nothing
|
||||
}
|
||||
|
||||
func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) {
|
||||
// Mock implementation - do nothing
|
||||
}
|
||||
|
||||
func (mp *mockPool) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestConnectionHook tests the Redis connection processor functionality
|
||||
func TestConnectionHook(t *testing.T) {
|
||||
// Create a base dialer for testing
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
|
||||
config := &Config{
|
||||
Mode: ModeAuto,
|
||||
EndpointType: EndpointTypeAuto,
|
||||
MaxWorkers: 1, // Use only 1 worker to ensure synchronization
|
||||
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
|
||||
MaxHandoffRetries: 3,
|
||||
}
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Verify connection is marked for handoff
|
||||
if !conn.ShouldHandoff() {
|
||||
t.Fatal("Connection should be marked for handoff")
|
||||
}
|
||||
// Set a mock initialization function with synchronization
|
||||
initConnCalled := make(chan bool, 1)
|
||||
proceedWithInit := make(chan bool, 1)
|
||||
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
|
||||
select {
|
||||
case initConnCalled <- true:
|
||||
default:
|
||||
}
|
||||
// Wait for test to proceed
|
||||
<-proceedWithInit
|
||||
return nil
|
||||
}
|
||||
conn.SetInitConnFunc(initConnFunc)
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection immediately (handoff queued)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled immediately with event-driven handoff")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when queuing handoff")
|
||||
}
|
||||
|
||||
// Wait for initialization to be called (indicates handoff started)
|
||||
select {
|
||||
case <-initConnCalled:
|
||||
// Good, initialization was called
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Timeout waiting for initialization function to be called")
|
||||
}
|
||||
|
||||
// Connection should be in pending map while initialization is blocked
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
t.Error("Connection should be in pending handoffs map")
|
||||
}
|
||||
|
||||
// Allow initialization to proceed
|
||||
proceedWithInit <- true
|
||||
|
||||
// Wait for handoff to complete with proper timeout and polling
|
||||
timeout := time.After(2 * time.Second)
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for handoff to complete")
|
||||
case <-ticker.C:
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify handoff completed (removed from pending map)
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
|
||||
t.Error("Connection should be removed from pending map after handoff")
|
||||
}
|
||||
|
||||
// Verify connection is usable again
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should be usable after successful handoff")
|
||||
}
|
||||
|
||||
// Verify handoff state is cleared
|
||||
if conn.ShouldHandoff() {
|
||||
t.Error("Connection should not be marked for handoff after completion")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandoffNotNeeded", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
// Don't mark for handoff
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error when handoff not needed: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection normally
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled when no handoff needed")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when no handoff needed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyEndpoint", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error with empty endpoint: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection (empty endpoint clears state)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled after clearing empty endpoint")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed after clearing empty endpoint")
|
||||
}
|
||||
|
||||
// State should be cleared
|
||||
if conn.ShouldHandoff() {
|
||||
t.Error("Connection should not be marked for handoff after clearing empty endpoint")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EventDrivenHandoffDialerError", func(t *testing.T) {
|
||||
// Create a failing base dialer
|
||||
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return nil, errors.New("dial failed")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Mode: ModeAuto,
|
||||
EndpointType: EndpointTypeAuto,
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 2, // Reduced retries for faster test
|
||||
HandoffTimeout: 500 * time.Millisecond, // Shorter timeout for faster test
|
||||
}
|
||||
processor := NewPoolHook(failingDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not return error to caller: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection initially (handoff queued)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled initially with event-driven handoff")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when queuing handoff")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete and fail with proper timeout and polling
|
||||
timeout := time.After(3 * time.Second)
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
// wait for handoff to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for failed handoff to complete")
|
||||
case <-ticker.C:
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connection should be removed from pending map after failed handoff
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
|
||||
t.Error("Connection should be removed from pending map after failed handoff")
|
||||
}
|
||||
|
||||
// Wait for retries to complete (with MaxHandoffRetries=2, it will retry twice then give up)
|
||||
// Each retry has a delay of handoffTimeout/2 = 250ms, so wait for all retries to complete
|
||||
time.Sleep(800 * time.Millisecond)
|
||||
|
||||
// After max retries are reached, the connection should be removed from pool
|
||||
// and handoff state should be cleared
|
||||
if conn.ShouldHandoff() {
|
||||
t.Error("Connection should not be marked for handoff after max retries reached")
|
||||
}
|
||||
|
||||
t.Logf("EventDrivenHandoffDialerError test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("BufferedDataRESP2", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// For this test, we'll just verify the logic works for connections without buffered data
|
||||
// The actual buffered data detection is handled by the pool's connection health check
|
||||
// which is outside the scope of the Redis connection processor
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection normally (no buffered data in mock)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled when no buffered data")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when no buffered data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnGet", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should not error for normal connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
|
||||
config := &Config{
|
||||
Mode: ModeAuto,
|
||||
EndpointType: EndpointTypeAuto,
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
|
||||
MaxHandoffRetries: 3,
|
||||
}
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// Simulate a pending handoff by marking for handoff and queuing
|
||||
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
processor.GetPendingMap().Delete(conn.GetID())
|
||||
})
|
||||
|
||||
t.Run("EventDrivenStateManagement", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// Test initial state - no pending handoffs
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
|
||||
t.Error("New connection should not have pending handoffs")
|
||||
}
|
||||
|
||||
// Test adding to pending map
|
||||
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
t.Error("Connection should be in pending map")
|
||||
}
|
||||
|
||||
// Test OnGet with pending handoff
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
|
||||
}
|
||||
|
||||
// Test removing from pending map and clearing handoff state
|
||||
processor.GetPendingMap().Delete(conn.GetID())
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
|
||||
t.Error("Connection should be removed from pending map")
|
||||
}
|
||||
|
||||
// Clear handoff state to simulate completed handoff
|
||||
conn.ClearHandoffState()
|
||||
conn.SetUsable(true) // Make connection usable again
|
||||
|
||||
// Test OnGet without pending handoff
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("Should not return error for non-pending connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
|
||||
// Create processor with small queue to test optimization features
|
||||
config := &Config{
|
||||
MaxWorkers: 3,
|
||||
HandoffQueueSize: 2, // Small queue to trigger optimizations
|
||||
MaxHandoffRetries: 3,
|
||||
}
|
||||
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Add small delay to simulate network latency
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create multiple connections that need handoff to fill the queue
|
||||
connections := make([]*pool.Conn, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
connections[i] = createMockPoolConnection()
|
||||
if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil {
|
||||
t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err)
|
||||
}
|
||||
// Set a mock initialization function
|
||||
connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
successCount := 0
|
||||
|
||||
// Process connections - should trigger scaling and timeout logic
|
||||
for _, conn := range connections {
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Logf("OnPut returned error (expected with timeout): %v", err)
|
||||
}
|
||||
|
||||
if shouldPool && !shouldRemove {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
// With timeout and scaling, most handoffs should eventually succeed
|
||||
if successCount == 0 {
|
||||
t.Error("Should have queued some handoffs with timeout and scaling")
|
||||
}
|
||||
|
||||
t.Logf("Successfully queued %d handoffs with optimization features", successCount)
|
||||
|
||||
// Give time for workers to process and scaling to occur
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("WorkerScalingBehavior", func(t *testing.T) {
|
||||
// Create processor with small queue to test scaling behavior
|
||||
config := &Config{
|
||||
MaxWorkers: 15, // Set to >= 10 to test explicit value preservation
|
||||
HandoffQueueSize: 1, // Very small queue to force scaling
|
||||
MaxHandoffRetries: 3,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Verify initial worker count (should be 0 with on-demand workers)
|
||||
if processor.GetCurrentWorkers() != 0 {
|
||||
t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers())
|
||||
}
|
||||
if processor.GetMaxWorkers() != 15 {
|
||||
t.Errorf("Expected maxWorkers=15, got %d", processor.GetMaxWorkers())
|
||||
}
|
||||
|
||||
// The on-demand worker behavior creates workers only when needed
|
||||
// This test just verifies the basic configuration is correct
|
||||
t.Logf("On-demand worker configuration verified - Max: %d, Current: %d",
|
||||
processor.GetMaxWorkers(), processor.GetCurrentWorkers())
|
||||
})
|
||||
|
||||
t.Run("PassiveTimeoutRestoration", func(t *testing.T) {
|
||||
// Create processor with fast post-handoff duration for testing
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3, // Allow retries for successful handoff
|
||||
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing
|
||||
RelaxedTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a connection and trigger handoff
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Process the connection to trigger handoff
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("Handoff should succeed: %v", err)
|
||||
}
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Error("Connection should be pooled after handoff")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete with proper timeout and polling
|
||||
timeout := time.After(1 * time.Second)
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for handoff to complete")
|
||||
case <-ticker.C:
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify relaxed timeout is set with deadline
|
||||
if !conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should have relaxed timeout after handoff")
|
||||
}
|
||||
|
||||
// Test that timeout is still active before deadline
|
||||
// We'll use HasRelaxedTimeout which internally checks the deadline
|
||||
if !conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should still have active relaxed timeout before deadline")
|
||||
}
|
||||
|
||||
// Wait for deadline to pass
|
||||
time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer
|
||||
|
||||
// Test that timeout is automatically restored after deadline
|
||||
// HasRelaxedTimeout should return false after deadline passes
|
||||
if conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should not have active relaxed timeout after deadline")
|
||||
}
|
||||
|
||||
// Additional verification: calling HasRelaxedTimeout again should still return false
|
||||
// and should have cleared the internal timeout values
|
||||
if conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should not have relaxed timeout after deadline (second check)")
|
||||
}
|
||||
|
||||
t.Logf("Passive timeout restoration test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("UsableFlagBehavior", func(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a new connection without setting it usable
|
||||
mockNetConn := &mockNetConn{addr: "test:6379"}
|
||||
conn := pool.NewConn(mockNetConn)
|
||||
|
||||
// Initially, connection should not be usable (not initialized)
|
||||
if conn.IsUsable() {
|
||||
t.Error("New connection should not be usable before initialization")
|
||||
}
|
||||
|
||||
// Simulate initialization by setting usable to true
|
||||
conn.SetUsable(true)
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should be usable after initialization")
|
||||
}
|
||||
|
||||
// OnGet should succeed for usable connection
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should succeed for usable connection: %v", err)
|
||||
}
|
||||
|
||||
// Mark connection for handoff
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Connection should still be usable until queued, but marked for handoff
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should still be usable after being marked for handoff (until queued)")
|
||||
}
|
||||
if !conn.ShouldHandoff() {
|
||||
t.Error("Connection should be marked for handoff")
|
||||
}
|
||||
|
||||
// OnGet should fail for connection marked for handoff
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
if err == nil {
|
||||
t.Error("OnGet should fail for connection marked for handoff")
|
||||
}
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||
}
|
||||
|
||||
// Process the connection to trigger handoff
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should succeed: %v", err)
|
||||
}
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Error("Connection should be pooled after handoff")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// After handoff completion, connection should be usable again
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should be usable after handoff completion")
|
||||
}
|
||||
|
||||
// OnGet should succeed again
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should succeed after handoff completion: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Usable flag behavior test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("StaticQueueBehavior", func(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxWorkers: 3,
|
||||
HandoffQueueSize: 50, // Explicit static queue size
|
||||
MaxHandoffRetries: 3,
|
||||
}
|
||||
|
||||
processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Verify queue capacity matches configured size
|
||||
queueCapacity := cap(processor.GetHandoffQueue())
|
||||
if queueCapacity != 50 {
|
||||
t.Errorf("Expected queue capacity 50, got %d", queueCapacity)
|
||||
}
|
||||
|
||||
// Test that queue size is static regardless of pool size
|
||||
// (No dynamic resizing should occur)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill part of the queue
|
||||
for i := 0; i < 10; i++ {
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil {
|
||||
t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err)
|
||||
}
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to queue handoff %d: %v", i, err)
|
||||
}
|
||||
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Errorf("conn[%d] should be pooled after handoff (shouldPool=%v, shouldRemove=%v)",
|
||||
i, shouldPool, shouldRemove)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify queue capacity remains static (the main purpose of this test)
|
||||
finalCapacity := cap(processor.GetHandoffQueue())
|
||||
|
||||
if finalCapacity != 50 {
|
||||
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
|
||||
}
|
||||
|
||||
// Note: We don't check queue size here because workers process items quickly
|
||||
// The important thing is that the capacity remains static regardless of pool size
|
||||
})
|
||||
|
||||
t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) {
|
||||
// Create a failing dialer that will cause handoff initialization to fail
|
||||
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Return a connection that will fail during initialization
|
||||
return &mockNetConn{addr: addr, shouldFailInit: true}, nil
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(failingDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create a mock pool that tracks removals
|
||||
mockPool := &mockPool{removedConnections: make(map[uint64]bool)}
|
||||
processor.SetPool(mockPool)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a connection and mark it for handoff
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a failing initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return fmt.Errorf("initialization failed")
|
||||
})
|
||||
|
||||
// Process the connection - handoff should fail and connection should be removed
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error: %v", err)
|
||||
}
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Error("Connection should be pooled after failed handoff attempt")
|
||||
}
|
||||
|
||||
// Wait for handoff to be attempted and fail
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify that the connection was removed from the pool
|
||||
if !mockPool.WasRemoved(conn.GetID()) {
|
||||
t.Errorf("conn[%d] should have been removed from pool after handoff failure", conn.GetID())
|
||||
}
|
||||
|
||||
t.Logf("Connection removal on handoff failure test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) {
|
||||
// Create config with short post-handoff duration for testing
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3, // Allow retries for successful handoff
|
||||
RelaxedTimeout: 5 * time.Second,
|
||||
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing
|
||||
}
|
||||
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("OnPut failed: %v", err)
|
||||
}
|
||||
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled after successful handoff")
|
||||
}
|
||||
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed after successful handoff")
|
||||
}
|
||||
|
||||
// Wait for the handoff to complete (it happens asynchronously)
|
||||
timeout := time.After(1 * time.Second)
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for handoff to complete")
|
||||
case <-ticker.C:
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that relaxed timeout was applied to the new connection
|
||||
if !conn.HasRelaxedTimeout() {
|
||||
t.Error("New connection should have relaxed timeout applied after handoff")
|
||||
}
|
||||
|
||||
// Wait for the post-handoff duration to expire
|
||||
time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration
|
||||
|
||||
// Verify that relaxed timeout was automatically cleared
|
||||
if conn.HasRelaxedTimeout() {
|
||||
t.Error("Relaxed timeout should be automatically cleared after post-handoff duration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) {
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// First mark should succeed
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("First MarkForHandoff should succeed: %v", err)
|
||||
}
|
||||
|
||||
// Second mark should fail
|
||||
if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil {
|
||||
t.Fatal("Second MarkForHandoff should return error")
|
||||
} else if err.Error() != "connection is already marked for handoff" {
|
||||
t.Fatalf("Expected specific error message, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify original handoff data is preserved
|
||||
if !conn.ShouldHandoff() {
|
||||
t.Fatal("Connection should still be marked for handoff")
|
||||
}
|
||||
if conn.GetHandoffEndpoint() != "new-endpoint:6379" {
|
||||
t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint())
|
||||
}
|
||||
if conn.GetMovingSeqID() != 1 {
|
||||
t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandoffTimeoutConfiguration", func(t *testing.T) {
|
||||
// Test that HandoffTimeout from config is actually used
|
||||
customTimeout := 2 * time.Second
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
HandoffTimeout: customTimeout, // Custom timeout
|
||||
MaxHandoffRetries: 1, // Single retry to speed up test
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create a connection that will test the timeout
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a dialer that will check the context timeout
|
||||
var timeoutVerified int32 // Use atomic for thread safety
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
// Check that the context has the expected timeout
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Error("Context should have a deadline")
|
||||
return errors.New("no deadline")
|
||||
}
|
||||
|
||||
// The deadline should be approximately customTimeout from now
|
||||
expectedDeadline := time.Now().Add(customTimeout)
|
||||
timeDiff := deadline.Sub(expectedDeadline)
|
||||
if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond {
|
||||
t.Errorf("Context deadline not as expected. Expected around %v, got %v (diff: %v)",
|
||||
expectedDeadline, deadline, timeDiff)
|
||||
} else {
|
||||
atomic.StoreInt32(&timeoutVerified, 1)
|
||||
}
|
||||
|
||||
return nil // Successful handoff
|
||||
})
|
||||
|
||||
// Trigger handoff
|
||||
shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not return error: %v", err)
|
||||
}
|
||||
|
||||
// Connection should be queued for handoff
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Errorf("Connection should be pooled for handoff processing")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
if atomic.LoadInt32(&timeoutVerified) == 0 {
|
||||
t.Error("HandoffTimeout was not properly applied to context")
|
||||
}
|
||||
|
||||
t.Logf("HandoffTimeout configuration test completed successfully")
|
||||
})
|
||||
}
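// Condensed reference (illustrative, derived from the subtests above) for the hook contract:
//
//	shouldPool, shouldRemove, err := processor.OnPut(ctx, cn)
//	  - cn marked for handoff: the handoff is queued asynchronously; shouldPool=true,
//	    shouldRemove=false, and cn stays unusable until the handoff completes.
//	  - cn marked with an empty endpoint: the handoff state is cleared and cn is pooled normally.
//	  - cn not marked: cn is pooled normally.
//
//	err := processor.OnGet(ctx, cn, false)
//	  - returns ErrConnectionMarkedForHandoff while a handoff is pending or queued.
//	  - returns nil once the handoff completes and the connection is usable again.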
|
||||
282
maintnotifications/push_notification_handler.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// NotificationHandler handles push notifications for the simplified manager.
|
||||
type NotificationHandler struct {
|
||||
manager *Manager
|
||||
operationsManager OperationsManagerInterface
|
||||
}
|
||||
|
||||
// HandlePushNotification processes push notifications with hook support.
|
||||
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) == 0 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotificationFormat(notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
notificationType, ok := notification[0].(string)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotificationTypeFormat(notification[0]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Process pre-hooks - they can modify the notification or skip processing
|
||||
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification)
|
||||
if !shouldContinue {
|
||||
return nil // Hooks decided to skip processing
|
||||
}
|
||||
|
||||
var err error
|
||||
switch notificationType {
|
||||
case NotificationMoving:
|
||||
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationMigrating:
|
||||
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationMigrated:
|
||||
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationFailingOver:
|
||||
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationFailedOver:
|
||||
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
|
||||
default:
|
||||
// Ignore other notification types (e.g., pub/sub messages)
|
||||
err = nil
|
||||
}
|
||||
|
||||
// Process post-hooks with the result
|
||||
snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// handleMoving processes MOVING notifications.
|
||||
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
|
||||
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) < 3 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("MOVING", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
seqID, ok := notification[1].(int64)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Extract timeS
|
||||
timeS, ok := notification[2].(int64)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidTimeSInMovingNotification(notification[2]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
newEndpoint := ""
|
||||
if len(notification) > 3 {
|
||||
// Extract new endpoint
|
||||
newEndpoint, ok = notification[3].(string)
|
||||
if !ok {
|
||||
stringified := fmt.Sprintf("%v", notification[3])
|
||||
// this could be <nil> which is valid
|
||||
if notification[3] == nil || stringified == internal.RedisNull {
|
||||
newEndpoint = ""
|
||||
} else {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get the connection that received this notification
|
||||
conn := handlerCtx.Conn
|
||||
if conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MOVING"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Type assert to get the underlying pool connection
|
||||
var poolConn *pool.Conn
|
||||
if pc, ok := conn.(*pool.Conn); ok {
|
||||
poolConn = pc
|
||||
} else {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// If the connection is closed or not pooled, we can ignore the notification
|
||||
// this connection won't be remembered by the pool and will be garbage collected
|
||||
// Keep pubsub connections around since they are not pooled but are long-lived
|
||||
// and should be allowed to handoff (the pubsub instance will reconnect and change
|
||||
// the underlying *pool.Conn)
|
||||
if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() {
|
||||
return nil
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
|
||||
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
|
||||
if newEndpoint == "" || newEndpoint == internal.RedisNull {
|
||||
if internal.LogLevel.DebugOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2))
|
||||
}
|
||||
// same as current endpoint
|
||||
newEndpoint = snh.manager.options.GetAddr()
|
||||
// delay the handoff for timeS/2 seconds to the same endpoint
|
||||
// do this in a goroutine to avoid blocking the notification handler
|
||||
// NOTE: This timer is started while parsing the notification, so the connection is not yet marked for handoff
|
||||
// and there should be no possibility of a race condition or double handoff.
|
||||
time.AfterFunc(time.Duration(timeS/2)*time.Second, func() {
|
||||
if poolConn == nil || poolConn.IsClosed() {
|
||||
return
|
||||
}
|
||||
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
|
||||
// Log error but don't fail the goroutine - use background context since original may be cancelled
|
||||
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err))
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
|
||||
}
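// Illustrative sketch (not the handler itself): how the MOVING fields map to the
// scheduling above. For ["MOVING", seqID, timeS, endpoint] with an empty endpoint,
// the handoff to the current endpoint is delayed by timeS/2 seconds, while the
// tracked deadline is now + timeS seconds.
func exampleMovingSchedule(timeS int64) (delay time.Duration, deadline time.Time) {
	deadline = time.Now().Add(time.Duration(timeS) * time.Second)
	delay = time.Duration(timeS/2) * time.Second
	return delay, deadline
}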
|
||||
|
||||
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
|
||||
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
|
||||
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err))
|
||||
// Connection is already marked for handoff, which is acceptable
|
||||
// This can happen if multiple MOVING notifications are received for the same connection
|
||||
return nil
|
||||
}
|
||||
// Optionally track the MOVING operation in the operations manager
|
||||
if snh.operationsManager != nil {
|
||||
connID := conn.GetID()
|
||||
// Track the operation (ignore errors since this is optional)
|
||||
_ = snh.operationsManager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
|
||||
} else {
|
||||
return errors.New(logs.ManagerNotInitialized())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigrating processes MIGRATING notifications.
|
||||
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// MIGRATING notifications indicate that a connection is about to be migrated
|
||||
// Apply relaxed timeouts to the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATING", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATING"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Apply relaxed timeout to this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout))
|
||||
}
|
||||
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigrated processes MIGRATED notifications.
|
||||
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// MIGRATED notifications indicate that a connection migration has completed
|
||||
// Restore normal timeouts for the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATED", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATED"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Clear relaxed timeout for this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
|
||||
}
|
||||
conn.ClearRelaxedTimeout()
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailingOver processes FAILING_OVER notifications.
|
||||
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// FAILING_OVER notifications indicate that a connection is about to failover
|
||||
// Apply relaxed timeouts to the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILING_OVER", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Apply relaxed timeout to this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout))
|
||||
}
|
||||
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailedOver processes FAILED_OVER notifications.
|
||||
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// FAILED_OVER notifications indicate that a connection failover has completed
|
||||
// Restore normal timeouts for the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILED_OVER", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Clear relaxed timeout for this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
|
||||
}
|
||||
conn.ClearRelaxedTimeout()
|
||||
return nil
|
||||
}
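// Summary of the per-connection effects of the handlers above (illustrative reference):
//
//	MIGRATING / FAILING_OVER -> conn.SetRelaxedTimeout(relaxed, relaxed)
//	MIGRATED  / FAILED_OVER  -> conn.ClearRelaxedTimeout()
//	MOVING                   -> conn.MarkForHandoff(endpoint, seqID); the pool hook then
//	                            performs the handoff and applies a post-handoff relaxed
//	                            timeout that expires on its own.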
|
||||
24
maintnotifications/state.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package maintnotifications
|
||||
|
||||
// State represents the current state of a maintenance operation
|
||||
type State int
|
||||
|
||||
const (
|
||||
// StateIdle indicates no upgrade is in progress
|
||||
StateIdle State = iota
|
||||
|
||||
// StateMoving indicates a connection handoff (MOVING) is in progress
|
||||
StateMoving
|
||||
)
|
||||
|
||||
// String returns a string representation of the state.
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateIdle:
|
||||
return "idle"
|
||||
case StateMoving:
|
||||
return "moving"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
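// Illustrative usage:
//
//	s := StateMoving
//	_ = s.String() // "moving"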
|
||||
146
options.go
@@ -16,6 +16,9 @@ import (
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// Limiter is the interface of a rate limiter or a circuit breaker.
|
||||
@@ -109,6 +112,16 @@ type Options struct {
|
||||
// default: 5 seconds
|
||||
DialTimeout time.Duration
|
||||
|
||||
// DialerRetries is the maximum number of retry attempts when dialing fails.
|
||||
//
|
||||
// default: 5
|
||||
DialerRetries int
|
||||
|
||||
// DialerRetryTimeout is the backoff duration between retry attempts.
|
||||
//
|
||||
// default: 100 milliseconds
|
||||
DialerRetryTimeout time.Duration
|
||||
|
||||
// ReadTimeout for socket reads. If reached, commands will fail
|
||||
// with a timeout instead of blocking. Supported values:
|
||||
//
|
||||
@@ -152,6 +165,7 @@ type Options struct {
|
||||
//
|
||||
// Note that FIFO has slightly higher overhead compared to LIFO,
|
||||
// but it helps close idle connections faster, reducing the pool size.
|
||||
// default: false
|
||||
PoolFIFO bool
|
||||
|
||||
// PoolSize is the base number of socket connections.
|
||||
@@ -232,10 +246,24 @@ type Options struct {
|
||||
// When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult
|
||||
UnstableResp3 bool
|
||||
|
||||
// Push notifications are always enabled for RESP3 connections (Protocol: 3)
|
||||
// and are not available for RESP2 connections. No configuration option is needed.
|
||||
|
||||
// PushNotificationProcessor is the processor for handling push notifications.
|
||||
// If nil, a default processor will be created for RESP3 connections.
|
||||
PushNotificationProcessor push.NotificationProcessor
|
||||
|
||||
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
|
||||
// When a node is marked as failing, it will be avoided for this duration.
|
||||
// Default is 15 seconds.
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
// MaintNotificationsConfig provides custom configuration for maintnotifications.
|
||||
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
|
||||
// cluster upgrade notifications gracefully and manage connection/pool state
|
||||
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, maintnotifications are in "auto" mode and will be enabled if the server supports it.
|
||||
MaintNotificationsConfig *maintnotifications.Config
|
||||
}
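// exampleMaintNotificationsOptions is an illustrative sketch (an example only; Mode and
// EndpointType mirror constants from the maintnotifications package, other values are
// placeholders) showing how maintnotifications can be enabled explicitly.
func exampleMaintNotificationsOptions() *Options {
	return &Options{
		Addr:     "localhost:6379",
		Protocol: 3, // RESP3 is required for push notifications
		MaintNotificationsConfig: &maintnotifications.Config{
			Mode:         maintnotifications.ModeAuto,
			EndpointType: maintnotifications.EndpointTypeAuto,
		},
	}
}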
|
||||
|
||||
func (opt *Options) init() {
|
||||
@@ -255,6 +283,12 @@ func (opt *Options) init() {
|
||||
if opt.DialTimeout == 0 {
|
||||
opt.DialTimeout = 5 * time.Second
|
||||
}
|
||||
if opt.DialerRetries == 0 {
|
||||
opt.DialerRetries = 5
|
||||
}
|
||||
if opt.DialerRetryTimeout == 0 {
|
||||
opt.DialerRetryTimeout = 100 * time.Millisecond
|
||||
}
|
||||
if opt.Dialer == nil {
|
||||
opt.Dialer = NewDialer(opt)
|
||||
}
|
||||
@@ -312,13 +346,36 @@ func (opt *Options) init() {
|
||||
case 0:
|
||||
opt.MaxRetryBackoff = 512 * time.Millisecond
|
||||
}
|
||||
|
||||
opt.MaintNotificationsConfig = opt.MaintNotificationsConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns)
|
||||
|
||||
// auto-detect endpoint type if not specified
|
||||
endpointType := opt.MaintNotificationsConfig.EndpointType
|
||||
if endpointType == "" || endpointType == maintnotifications.EndpointTypeAuto {
|
||||
// Auto-detect endpoint type if not specified
|
||||
endpointType = maintnotifications.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
|
||||
}
|
||||
opt.MaintNotificationsConfig.EndpointType = endpointType
|
||||
}
|
||||
|
||||
func (opt *Options) clone() *Options {
|
||||
clone := *opt
|
||||
|
||||
// Deep clone MaintNotificationsConfig to avoid sharing between clients
|
||||
if opt.MaintNotificationsConfig != nil {
|
||||
configClone := *opt.MaintNotificationsConfig
|
||||
clone.MaintNotificationsConfig = &configClone
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
// NewDialer returns a function that will be used as the default dialer
|
||||
// when none is specified in Options.Dialer.
|
||||
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
|
||||
return NewDialer(opt)
|
||||
}
|
||||
|
||||
// NewDialer returns a function that will be used as the default dialer
|
||||
// when none is specified in Options.Dialer.
|
||||
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
|
||||
@@ -666,21 +723,84 @@ func getUserPassword(u *url.URL) (string, string) {
|
||||
func newConnPool(
|
||||
opt *Options,
|
||||
dialer func(ctx context.Context, network, addr string) (net.Conn, error),
|
||||
) *pool.ConnPool {
|
||||
) (*pool.ConnPool, error) {
|
||||
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pool.NewConnPool(&pool.Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return dialer(ctx, opt.Network, opt.Addr)
|
||||
},
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: opt.PoolSize,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
MinIdleConns: opt.MinIdleConns,
|
||||
MaxIdleConns: opt.MaxIdleConns,
|
||||
MaxActiveConns: opt.MaxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: opt.ReadBufferSize,
|
||||
WriteBufferSize: opt.WriteBufferSize,
|
||||
})
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: poolSize,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
DialerRetries: opt.DialerRetries,
|
||||
DialerRetryTimeout: opt.DialerRetryTimeout,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxActiveConns: maxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: opt.ReadBufferSize,
|
||||
WriteBufferSize: opt.WriteBufferSize,
|
||||
PushNotificationsEnabled: opt.Protocol == 3,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
|
||||
) (*pool.PubSubPool, error) {
|
||||
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pool.NewPubSubPool(&pool.Options{
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: poolSize,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
DialerRetries: opt.DialerRetries,
|
||||
DialerRetryTimeout: opt.DialerRetryTimeout,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxActiveConns: maxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: 32 * 1024,
|
||||
WriteBufferSize: 32 * 1024,
|
||||
PushNotificationsEnabled: opt.Protocol == 3,
|
||||
}, dialer), nil
|
||||
}
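// exampleSafeIntToInt32 is an illustrative sketch (an assumption, not the actual
// util.SafeIntToInt32) of a bounds-checked int -> int32 conversion like the one used above.
func exampleSafeIntToInt32(v int, name string) (int32, error) {
	const maxInt32, minInt32 = 1<<31 - 1, -(1 << 31)
	if v > maxInt32 || v < minInt32 {
		return 0, fmt.Errorf("%s value %d is out of int32 range", name, v)
	}
	return int32(v), nil
}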
|
||||
|
||||
@@ -20,6 +20,8 @@ import (
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/internal/rand"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,6 +40,7 @@ type ClusterOptions struct {
|
||||
ClientName string
|
||||
|
||||
// NewClient creates a cluster node client with provided name and options.
|
||||
// If NewClient is set by the user, the user is responsible for handling maintnotifications upgrades and push notifications.
|
||||
NewClient func(opt *Options) *Client
|
||||
|
||||
// The maximum number of retries before giving up. Command is retried
|
||||
@@ -125,10 +128,22 @@ type ClusterOptions struct {
|
||||
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
|
||||
UnstableResp3 bool
|
||||
|
||||
// PushNotificationProcessor is the processor for handling push notifications.
|
||||
// If nil, a default processor will be created for RESP3 connections.
|
||||
PushNotificationProcessor push.NotificationProcessor
|
||||
|
||||
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
|
||||
// When a node is marked as failing, it will be avoided for this duration.
|
||||
// Default is 15 seconds.
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
// MaintNotificationsConfig provides custom configuration for maintnotifications upgrades.
|
||||
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
|
||||
// cluster upgrade notifications gracefully and manage connection/pool state
|
||||
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it.
|
||||
// The ClusterClient does not work with maintnotifications directly; the node clients in the Nodes map handle them.
|
||||
MaintNotificationsConfig *maintnotifications.Config
|
||||
}
|
||||
|
||||
func (opt *ClusterOptions) init() {
|
||||
@@ -385,6 +400,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
|
||||
}
|
||||
|
||||
func (opt *ClusterOptions) clientOptions() *Options {
|
||||
// Clone MaintNotificationsConfig to avoid sharing between cluster node clients
|
||||
var maintNotificationsConfig *maintnotifications.Config
|
||||
if opt.MaintNotificationsConfig != nil {
|
||||
configClone := *opt.MaintNotificationsConfig
|
||||
maintNotificationsConfig = &configClone
|
||||
}
|
||||
|
||||
return &Options{
|
||||
ClientName: opt.ClientName,
|
||||
Dialer: opt.Dialer,
|
||||
@@ -426,8 +448,10 @@ func (opt *ClusterOptions) clientOptions() *Options {
|
||||
// much use for ClusterSlots config). This means we cannot execute the
|
||||
// READONLY command against that node -- setting readOnly to false in such
|
||||
// situations in the options below will prevent that from happening.
|
||||
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
|
||||
UnstableResp3: opt.UnstableResp3,
|
||||
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
|
||||
UnstableResp3: opt.UnstableResp3,
|
||||
MaintNotificationsConfig: maintNotificationsConfig,
|
||||
PushNotificationProcessor: opt.PushNotificationProcessor,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1730,7 +1754,7 @@ func (c *ClusterClient) processTxPipelineNode(
|
||||
}
|
||||
|
||||
func (c *ClusterClient) processTxPipelineNodeConn(
|
||||
ctx context.Context, _ *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
|
||||
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
|
||||
) error {
|
||||
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
||||
return writeCmds(wr, cmds)
|
||||
@@ -1748,7 +1772,7 @@ func (c *ClusterClient) processTxPipelineNodeConn(
|
||||
trimmedCmds := cmds[1 : len(cmds)-1]
|
||||
|
||||
if err := c.txPipelineReadQueued(
|
||||
ctx, rd, statusCmd, trimmedCmds, failedCmds,
|
||||
ctx, node, cn, rd, statusCmd, trimmedCmds, failedCmds,
|
||||
); err != nil {
|
||||
setCmdsErr(cmds, err)
|
||||
|
||||
@@ -1760,30 +1784,56 @@ func (c *ClusterClient) processTxPipelineNodeConn(
|
||||
return err
|
||||
}
|
||||
|
||||
return pipelineReadCmds(rd, trimmedCmds)
|
||||
return node.Client.pipelineReadCmds(ctx, cn, rd, trimmedCmds)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClusterClient) txPipelineReadQueued(
    ctx context.Context,
    node *clusterNode,
    cn *pool.Conn,
    rd *proto.Reader,
    statusCmd *StatusCmd,
    cmds []Cmder,
    failedCmds *cmdsMap,
) error {
    // Parse queued replies.
    // To be sure there are no buffered push notifications, we process them before reading the reply
    if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
        // Log the error but don't fail the command execution
        // Push notification processing errors shouldn't break normal Redis operations
        internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
    }
    if err := statusCmd.readReply(rd); err != nil {
        return err
    }

    for _, cmd := range cmds {
        err := statusCmd.readReply(rd)
        if err == nil || c.checkMovedErr(ctx, cmd, err, failedCmds) || isRedisError(err) {
            continue
        // To be sure there are no buffered push notifications, we process them before reading the reply
        if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
            // Log the error but don't fail the command execution
            // Push notification processing errors shouldn't break normal Redis operations
            internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
        }
        err := statusCmd.readReply(rd)
        if err != nil {
            if c.checkMovedErr(ctx, cmd, err, failedCmds) {
                // will be processed later
                continue
            }
            cmd.SetErr(err)
            if !isRedisError(err) {
                return err
            }
        }
        return err
    }

    // To be sure there are no buffered push notifications, we process them before reading the reply
    if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
        // Log the error but don't fail the command execution
        // Push notification processing errors shouldn't break normal Redis operations
        internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
    }
    // Parse number of replies.
    line, err := rd.ReadLine()
    if err != nil {
@@ -1889,12 +1939,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
    return err
}

// maintenance notifications won't work here for now
func (c *ClusterClient) pubSub() *PubSub {
    var node *clusterNode
    pubsub := &PubSub{
        opt: c.opt.clientOptions(),

        newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
        newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
            if node != nil {
                panic("node != nil")
            }
@@ -1928,18 +1978,25 @@ func (c *ClusterClient) pubSub() *PubSub {
                    return nil, err
                }
            }

            cn, err := node.Client.newConn(context.TODO())
            cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
            if err != nil {
                node = nil

                return nil, err
            }

            // will return nil if already initialized
            err = node.Client.initConn(ctx, cn)
            if err != nil {
                _ = cn.Close()
                node = nil
                return nil, err
            }
            node.Client.pubSubPool.TrackConn(cn)
            return cn, nil
        },
        closeConn: func(cn *pool.Conn) error {
            err := node.Client.connPool.CloseConn(cn)
            // Untrack connection from PubSubPool
            node.Client.pubSubPool.UntrackConn(cn)
            err := cn.Close()
            node = nil
            return err
        },

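The public PubSub API is unchanged by the hunk above; only the internal connection handling moves to the per-node pubSubPool (NewConn/TrackConn/UntrackConn). A brief sketch, with an illustrative address and channel name, of the call path these internals serve:

package main

import (
    "context"
    "fmt"

    "github.com/redis/go-redis/v9"
)

func main() {
    ctx := context.Background()
    cc := redis.NewClusterClient(&redis.ClusterOptions{
        Addrs: []string{"127.0.0.1:7000"}, // placeholder address
    })
    defer cc.Close()

    // Subscribe picks a node and creates its connection through the reworked
    // pubSubPool path shown in the hunk above.
    pubsub := cc.Subscribe(ctx, "example-channel")
    defer pubsub.Close()

    if msg, err := pubsub.ReceiveMessage(ctx); err == nil {
        fmt.Println(msg.Channel, msg.Payload)
    }
}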
12
pipeline.go
12
pipeline.go
@@ -30,9 +30,12 @@ type Pipeliner interface {
    // If a certain Redis command is not yet supported, you can use Do to execute it.
    Do(ctx context.Context, args ...interface{}) *Cmd

    // Process puts the commands to be executed into the pipeline buffer.
    // Process queues the cmd for later execution.
    Process(ctx context.Context, cmd Cmder) error

    // BatchProcess adds multiple commands to be executed into the pipeline buffer.
    BatchProcess(ctx context.Context, cmd ...Cmder) error

    // Discard discards all commands in the pipeline buffer that have not yet been executed.
    Discard()

@@ -79,7 +82,12 @@ func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd {

// Process queues the cmd for later execution.
func (c *Pipeline) Process(ctx context.Context, cmd Cmder) error {
    c.cmds = append(c.cmds, cmd)
    return c.BatchProcess(ctx, cmd)
}

// BatchProcess queues multiple cmds for later execution.
func (c *Pipeline) BatchProcess(ctx context.Context, cmd ...Cmder) error {
    c.cmds = append(c.cmds, cmd...)
    return nil
}

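A small usage sketch of the BatchProcess method added in the hunk above. It mirrors the test below, building commands with redis.NewCmd; the address is illustrative and the method exists only with this changeset applied.

package main

import (
    "context"

    "github.com/redis/go-redis/v9"
)

func main() {
    ctx := context.Background()
    rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // placeholder address
    defer rdb.Close()

    pipe := rdb.Pipeline()
    // BatchProcess queues several commands in one call instead of repeated
    // Process calls; Exec still sends them as a single pipeline.
    if err := pipe.BatchProcess(ctx,
        redis.NewCmd(ctx, "set", "key", "value"),
        redis.NewCmd(ctx, "get", "key"),
    ); err != nil {
        panic(err)
    }

    cmds, err := pipe.Exec(ctx)
    _ = cmds
    _ = err
}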
@@ -60,6 +60,39 @@ var _ = Describe("pipelining", func() {
        Expect(cmds).To(BeEmpty())
    })

    It("pipeline: basic exec", func() {
        p := client.Pipeline()
        p.Get(ctx, "key")
        p.Set(ctx, "key", "value", 0)
        p.Get(ctx, "key")
        cmds, err := p.Exec(ctx)
        Expect(err).To(Equal(redis.Nil))
        Expect(cmds).To(HaveLen(3))
        Expect(cmds[0].Err()).To(Equal(redis.Nil))
        Expect(cmds[1].(*redis.StatusCmd).Val()).To(Equal("OK"))
        Expect(cmds[1].Err()).NotTo(HaveOccurred())
        Expect(cmds[2].(*redis.StringCmd).Val()).To(Equal("value"))
        Expect(cmds[2].Err()).NotTo(HaveOccurred())
    })

    It("pipeline: exec pipeline when get conn failed", func() {
        p := client.Pipeline()
        p.Get(ctx, "key")
        p.Set(ctx, "key", "value", 0)
        p.Get(ctx, "key")

        client.Close()

        cmds, err := p.Exec(ctx)
        Expect(err).To(Equal(redis.ErrClosed))
        Expect(cmds).To(HaveLen(3))
        for _, cmd := range cmds {
            Expect(cmd.Err()).To(Equal(redis.ErrClosed))
        }

        client = redis.NewClient(redisOptions())
    })

    assertPipeline := func() {
        It("returns no errors when there are no commands", func() {
            _, err := pipe.Exec(ctx)
@@ -114,6 +147,25 @@ var _ = Describe("pipelining", func() {
            err := pipe.Do(ctx).Err()
            Expect(err).To(Equal(errors.New("redis: please enter the command to be executed")))
        })

        It("should process", func() {
            err := pipe.Process(ctx, redis.NewCmd(ctx, "asking"))
            Expect(err).To(BeNil())
            Expect(pipe.Cmds()).To(HaveLen(1))
        })

        It("should batchProcess", func() {
            err := pipe.BatchProcess(ctx, redis.NewCmd(ctx, "asking"))
            Expect(err).To(BeNil())
            Expect(pipe.Cmds()).To(HaveLen(1))

            pipe.Discard()
            Expect(pipe.Cmds()).To(HaveLen(0))

            err = pipe.BatchProcess(ctx, redis.NewCmd(ctx, "asking"), redis.NewCmd(ctx, "set", "key", "value"))
            Expect(err).To(BeNil())
            Expect(pipe.Cmds()).To(HaveLen(2))
        })
    }

    Describe("Pipeline", func() {

375
pool_pubsub_bench_test.go
Normal file
375
pool_pubsub_bench_test.go
Normal file
@@ -0,0 +1,375 @@
// Pool and PubSub Benchmark Suite
//
// This file contains comprehensive benchmarks for both pool operations and PubSub initialization.
// It's designed to be run against different branches to compare performance.
//
// Usage Examples:
//   # Run all benchmarks
//   go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go
//
//   # Run only pool benchmarks
//   go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go
//
//   # Run only PubSub benchmarks
//   go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go
//
//   # Compare between branches
//   git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt
//   git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt
//   benchcmp branch1.txt branch2.txt
//
//   # Run with memory profiling
//   go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go
//
//   # Run with CPU profiling
//   go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go

package redis_test

import (
    "context"
    "fmt"
    "net"
    "sync"
    "testing"
    "time"

    "github.com/redis/go-redis/v9"
    "github.com/redis/go-redis/v9/internal/pool"
)

// dummyDialer creates a mock connection for benchmarking
func dummyDialer(ctx context.Context) (net.Conn, error) {
    return &dummyConn{}, nil
}

// dummyConn implements net.Conn for benchmarking
type dummyConn struct{}

func (c *dummyConn) Read(b []byte) (n int, err error)  { return len(b), nil }
func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Close() error                      { return nil }
func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} }
func (c *dummyConn) RemoteAddr() net.Addr {
    return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379}
}
func (c *dummyConn) SetDeadline(t time.Time) error      { return nil }
func (c *dummyConn) SetReadDeadline(t time.Time) error  { return nil }
func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil }

// =============================================================================
// POOL BENCHMARKS
// =============================================================================

// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations
func BenchmarkPoolGetPut(b *testing.B) {
    ctx := context.Background()

    poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128}

    for _, poolSize := range poolSizes {
        b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
            connPool := pool.NewConnPool(&pool.Options{
                Dialer:          dummyDialer,
                PoolSize:        int32(poolSize),
                PoolTimeout:     time.Second,
                DialTimeout:     time.Second,
                ConnMaxIdleTime: time.Hour,
                MinIdleConns:    int32(0), // Start with no idle connections
            })
            defer connPool.Close()

            b.ResetTimer()
            b.ReportAllocs()

            b.RunParallel(func(pb *testing.PB) {
                for pb.Next() {
                    cn, err := connPool.Get(ctx)
                    if err != nil {
                        b.Fatal(err)
                    }
                    connPool.Put(ctx, cn)
                }
            })
        })
    }
}

// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns
func BenchmarkPoolGetPutWithMinIdle(b *testing.B) {
    ctx := context.Background()

    configs := []struct {
        poolSize     int
        minIdleConns int
    }{
        {8, 2},
        {16, 4},
        {32, 8},
        {64, 16},
    }

    for _, config := range configs {
        b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) {
            connPool := pool.NewConnPool(&pool.Options{
                Dialer:          dummyDialer,
                PoolSize:        int32(config.poolSize),
                MinIdleConns:    int32(config.minIdleConns),
                PoolTimeout:     time.Second,
                DialTimeout:     time.Second,
                ConnMaxIdleTime: time.Hour,
            })
            defer connPool.Close()

            b.ResetTimer()
            b.ReportAllocs()

            b.RunParallel(func(pb *testing.PB) {
                for pb.Next() {
                    cn, err := connPool.Get(ctx)
                    if err != nil {
                        b.Fatal(err)
                    }
                    connPool.Put(ctx, cn)
                }
            })
        })
    }
}

// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency
func BenchmarkPoolConcurrentGetPut(b *testing.B) {
    ctx := context.Background()

    connPool := pool.NewConnPool(&pool.Options{
        Dialer:          dummyDialer,
        PoolSize:        int32(32),
        PoolTimeout:     time.Second,
        DialTimeout:     time.Second,
        ConnMaxIdleTime: time.Hour,
        MinIdleConns:    int32(0),
    })
    defer connPool.Close()

    b.ResetTimer()
    b.ReportAllocs()

    // Test with different levels of concurrency
    concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64}

    for _, concurrency := range concurrencyLevels {
        b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
            b.SetParallelism(concurrency)
            b.RunParallel(func(pb *testing.PB) {
                for pb.Next() {
                    cn, err := connPool.Get(ctx)
                    if err != nil {
                        b.Fatal(err)
                    }
                    connPool.Put(ctx, cn)
                }
            })
        })
    }
}

// =============================================================================
// PUBSUB BENCHMARKS
// =============================================================================

// benchmarkClient creates a Redis client for benchmarking
func benchmarkClient(poolSize int) *redis.Client {
    return redis.NewClient(&redis.Options{
        Addr:         "localhost:6379", // Mock address
        DialTimeout:  time.Second,
        ReadTimeout:  time.Second,
        WriteTimeout: time.Second,
        PoolSize:     poolSize,
        MinIdleConns: 0, // Start with no idle connections for consistent benchmarks
    })
}

// BenchmarkPubSubCreation benchmarks PubSub creation and subscription
func BenchmarkPubSubCreation(b *testing.B) {
    ctx := context.Background()

    poolSizes := []int{1, 4, 8, 16, 32}

    for _, poolSize := range poolSizes {
        b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
            client := benchmarkClient(poolSize)
            defer client.Close()

            b.ResetTimer()
            b.ReportAllocs()

            for i := 0; i < b.N; i++ {
                pubsub := client.Subscribe(ctx, "test-channel")
                pubsub.Close()
            }
        })
    }
}

// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription
func BenchmarkPubSubPatternCreation(b *testing.B) {
    ctx := context.Background()

    poolSizes := []int{1, 4, 8, 16, 32}

    for _, poolSize := range poolSizes {
        b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
            client := benchmarkClient(poolSize)
            defer client.Close()

            b.ResetTimer()
            b.ReportAllocs()

            for i := 0; i < b.N; i++ {
                pubsub := client.PSubscribe(ctx, "test-*")
                pubsub.Close()
            }
        })
    }
}

// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation
func BenchmarkPubSubConcurrentCreation(b *testing.B) {
    ctx := context.Background()
    client := benchmarkClient(32)
    defer client.Close()

    concurrencyLevels := []int{1, 2, 4, 8, 16}

    for _, concurrency := range concurrencyLevels {
        b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
            b.ResetTimer()
            b.ReportAllocs()

            var wg sync.WaitGroup
            semaphore := make(chan struct{}, concurrency)

            for i := 0; i < b.N; i++ {
                wg.Add(1)
                semaphore <- struct{}{}

                go func() {
                    defer wg.Done()
                    defer func() { <-semaphore }()

                    pubsub := client.Subscribe(ctx, "test-channel")
                    pubsub.Close()
                }()
            }

            wg.Wait()
        })
    }
}

// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels
func BenchmarkPubSubMultipleChannels(b *testing.B) {
    ctx := context.Background()
    client := benchmarkClient(16)
    defer client.Close()

    channelCounts := []int{1, 5, 10, 25, 50, 100}

    for _, channelCount := range channelCounts {
        b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) {
            // Prepare channel names
            channels := make([]string, channelCount)
            for i := 0; i < channelCount; i++ {
                channels[i] = fmt.Sprintf("channel-%d", i)
            }

            b.ResetTimer()
            b.ReportAllocs()

            for i := 0; i < b.N; i++ {
                pubsub := client.Subscribe(ctx, channels...)
                pubsub.Close()
            }
        })
    }
}

// BenchmarkPubSubReuse benchmarks reusing PubSub connections
func BenchmarkPubSubReuse(b *testing.B) {
    ctx := context.Background()
    client := benchmarkClient(16)
    defer client.Close()

    b.ResetTimer()
    b.ReportAllocs()

    for i := 0; i < b.N; i++ {
        // Benchmark just the creation and closing of PubSub connections
        // This simulates reuse patterns without requiring actual Redis operations
        pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i))
        pubsub.Close()
    }
}

// =============================================================================
// COMBINED BENCHMARKS
// =============================================================================

// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations
func BenchmarkPoolAndPubSubMixed(b *testing.B) {
    ctx := context.Background()
    client := benchmarkClient(32)
    defer client.Close()

    b.ResetTimer()
    b.ReportAllocs()

    b.RunParallel(func(pb *testing.PB) {
        for pb.Next() {
            // Mix of pool stats collection and PubSub creation
            if pb.Next() {
                // Pool stats operation
                stats := client.PoolStats()
                _ = stats.Hits + stats.Misses // Use the stats to prevent optimization
            }

            if pb.Next() {
                // PubSub operation
                pubsub := client.Subscribe(ctx, "test-channel")
                pubsub.Close()
            }
        }
    })
}

// BenchmarkPoolStatsCollection benchmarks pool statistics collection
func BenchmarkPoolStatsCollection(b *testing.B) {
    client := benchmarkClient(16)
    defer client.Close()

    b.ResetTimer()
    b.ReportAllocs()

    for i := 0; i < b.N; i++ {
        stats := client.PoolStats()
        _ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization
    }
}

// BenchmarkPoolHighContention tests pool performance under high contention
func BenchmarkPoolHighContention(b *testing.B) {
    ctx := context.Background()
    client := benchmarkClient(32)
    defer client.Close()

    b.ResetTimer()
    b.ReportAllocs()

    b.RunParallel(func(pb *testing.PB) {
        for pb.Next() {
            // High contention Get/Put operations
            pubsub := client.Subscribe(ctx, "test-channel")
            pubsub.Close()
        }
    })
}
Some files were not shown because too many files have changed in this diff.