mirror of
https://github.com/redis/go-redis.git
synced 2025-11-26 06:23:09 +03:00
fix(conn): conn to have state machine (#3559)
* wip * wip, used and unusable states * polish state machine * correct handling OnPut * better errors for tests, hook should work now * fix linter * improve reauth state management. fix tests * Update internal/pool/conn.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update internal/pool/conn.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * better timeouts * empty endpoint handoff case * fix handoff state when queued for handoff * try to detect the deadlock * try to detect the deadlock x2 * delete should be called * improve tests * fix mark on uninitialized connection * Update internal/pool/conn_state_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update internal/pool/conn_state_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update internal/pool/pool.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update internal/pool/conn_state.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update internal/pool/conn.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix error from copilot * address copilot comment * fix(pool): pool performance (#3565) * perf(pool): replace hookManager RWMutex with atomic.Pointer and add predefined state slices - Replace hookManager RWMutex with atomic.Pointer for lock-free reads in hot paths - Add predefined state slices to avoid allocations (validFromInUse, validFromCreatedOrIdle, etc.) - Add Clone() method to PoolHookManager for atomic updates - Update AddPoolHook/RemovePoolHook to use copy-on-write pattern - Update all hookManager access points to use atomic Load() Performance improvements: - Eliminates RWMutex contention in Get/Put/Remove hot paths - Reduces allocations by reusing predefined state slices - Lock-free reads allow better CPU cache utilization * perf(pool): eliminate mutex overhead in state machine hot path The state machine was calling notifyWaiters() on EVERY Get/Put operation, which acquired a mutex even when no waiters were present (the common case). Fix: Use atomic waiterCount to check for waiters BEFORE acquiring mutex. This eliminates mutex contention in the hot path (Get/Put operations). Implementation: - Added atomic.Int32 waiterCount field to ConnStateMachine - Increment when adding waiter, decrement when removing - Check waiterCount atomically before acquiring mutex in notifyWaiters() Performance impact: - Before: mutex lock/unlock on every Get/Put (even with no waiters) - After: lock-free atomic check, only acquire mutex if waiters exist - Expected improvement: ~30-50% for Get/Put operations * perf(pool): use predefined state slices to eliminate allocations in hot path The pool was creating new slice literals on EVERY Get/Put operation: - popIdle(): []ConnState{StateCreated, StateIdle} - putConn(): []ConnState{StateInUse} - CompareAndSwapUsed(): []ConnState{StateIdle} and []ConnState{StateInUse} - MarkUnusableForHandoff(): []ConnState{StateInUse, StateIdle, StateCreated} These allocations were happening millions of times per second in the hot path. Fix: Use predefined global slices defined in conn_state.go: - validFromInUse - validFromCreatedOrIdle - validFromCreatedInUseOrIdle Performance impact: - Before: 4 slice allocations per Get/Put cycle - After: 0 allocations (use predefined slices) - Expected improvement: ~30-40% reduction in allocations and GC pressure * perf(pool): optimize TryTransition to reduce atomic operations Further optimize the hot path by: 1. Remove redundant GetState() call in the loop 2. Only check waiterCount after successful CAS (not before loop) 3. Inline the waiterCount check to avoid notifyWaiters() call overhead This reduces atomic operations from 4-5 per Get/Put to 2-3: - Before: GetState() + CAS + waiterCount.Load() + notifyWaiters mutex check - After: CAS + waiterCount.Load() (only if CAS succeeds) Performance impact: - Eliminates 1-2 atomic operations per Get/Put - Expected improvement: ~10-15% for Get/Put operations * perf(pool): add fast path for Get/Put to match master performance Introduced TryTransitionFast() for the hot path (Get/Put operations): - Single CAS operation (same as master's atomic bool) - No waiter notification overhead - No loop through valid states - No error allocation Hot path flow: 1. popIdle(): Try IDLE → IN_USE (fast), fallback to CREATED → IN_USE 2. putConn(): Try IN_USE → IDLE (fast) This matches master's performance while preserving state machine for: - Background operations (handoff/reauth use UNUSABLE state) - State validation (TryTransition still available) - Waiter notification (AwaitAndTransition for blocking) Performance comparison per Get/Put cycle: - Master: 2 atomic CAS operations - State machine (before): 5 atomic operations (2.5x slower) - State machine (after): 2 atomic CAS operations (same as master!) Expected improvement: Restore to baseline ~11,373 ops/sec * combine cas * fix linter * try faster approach * fast semaphore * better inlining for hot path * fix linter issues * use new semaphore in auth as well * linter should be happy now * add comments * Update internal/pool/conn_state.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * address comment * slight reordering * try to cache time if for non-critical calculation * fix wrong benchmark * add concurrent test * fix benchmark report * add additional expect to check output * comment and variable rename --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * initConn sets IDLE state - Handle unexpected conn state changes * fix precision of time cache and usedAt * allow e2e tests to run longer * Fix broken initialization of idle connections * optimize push notif * 100ms -> 50ms * use correct timer for last health check * verify pass auth on conn creation * fix assertion * fix unsafe test * fix benchmark test * improve remove conn * re doesn't support requirepass * wait more in e2e test * flaky test * add missed method in interface * fix test assertions * silence logs and faster hooks manager * address linter comment * fix flaky test * use read instad of control * use pool size for semsize * CAS instead of reading the state * preallocate errors and states * preallocate state slices * fix flaky test * fix fast semaphore that could have been starved * try to fix the semaphore * should properly notify the waiters - this way a waiter that timesout at the same time a releaser is releasing, won't throw token. the releaser will fail to notify and will pick another waiter. this hybrid approach should be faster than channels and maintains FIFO * waiter may double-release (if closed/times out) * priority of operations * use simple approach of fifo waiters * use simple channel based semaphores * address linter and tests * remove unused benchs * change log message * address pr comments * address pr comments * fix data race --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,7 @@ type CommandRunnerStats struct {
|
||||
|
||||
// CommandRunner provides utilities for running commands during tests
|
||||
type CommandRunner struct {
|
||||
executing atomic.Bool
|
||||
client redis.UniversalClient
|
||||
stopCh chan struct{}
|
||||
operationCount atomic.Int64
|
||||
@@ -56,6 +57,10 @@ func (cr *CommandRunner) Close() {
|
||||
|
||||
// FireCommandsUntilStop runs commands continuously until stop signal
|
||||
func (cr *CommandRunner) FireCommandsUntilStop(ctx context.Context) {
|
||||
if !cr.executing.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
defer cr.executing.Store(false)
|
||||
fmt.Printf("[CR] Starting command runner...\n")
|
||||
defer fmt.Printf("[CR] Command runner stopped\n")
|
||||
// High frequency for timeout testing
|
||||
|
||||
@@ -319,6 +319,7 @@ func (cf *ClientFactory) Create(key string, options *CreateClientOptions) (redis
|
||||
}
|
||||
|
||||
var client redis.UniversalClient
|
||||
var opts interface{}
|
||||
|
||||
// Determine if this is a cluster configuration
|
||||
if len(cf.config.Endpoints) > 1 || cf.isClusterEndpoint() {
|
||||
@@ -349,6 +350,7 @@ func (cf *ClientFactory) Create(key string, options *CreateClientOptions) (redis
|
||||
}
|
||||
}
|
||||
|
||||
opts = clusterOptions
|
||||
client = redis.NewClusterClient(clusterOptions)
|
||||
} else {
|
||||
// Create single client
|
||||
@@ -379,9 +381,14 @@ func (cf *ClientFactory) Create(key string, options *CreateClientOptions) (redis
|
||||
}
|
||||
}
|
||||
|
||||
opts = clientOptions
|
||||
client = redis.NewClient(clientOptions)
|
||||
}
|
||||
|
||||
if err := client.Ping(context.Background()).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w\nOptions: %+v", err, opts)
|
||||
}
|
||||
|
||||
// Store the client
|
||||
cf.clients[key] = client
|
||||
|
||||
@@ -832,7 +839,6 @@ func (m *TestDatabaseManager) DeleteDatabase(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to trigger database deletion: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// Wait for deletion to complete
|
||||
status, err := m.faultInjector.WaitForAction(ctx, resp.ActionID,
|
||||
WithMaxWaitTime(2*time.Minute),
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
@@ -12,6 +13,8 @@ import (
|
||||
// Global log collector
|
||||
var logCollector *TestLogCollector
|
||||
|
||||
const defaultTestTimeout = 30 * time.Minute
|
||||
|
||||
// Global fault injector client
|
||||
var faultInjector *FaultInjectorClient
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ func TestEndpointTypesPushNotifications(t *testing.T) {
|
||||
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
var dump = true
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestPushNotifications(t *testing.T) {
|
||||
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Setup: Create fresh database and client factory for this test
|
||||
@@ -297,12 +297,6 @@ func TestPushNotifications(t *testing.T) {
|
||||
// once moving is received, start a second client commands runner
|
||||
p("Starting commands on second client")
|
||||
go commandsRunner2.FireCommandsUntilStop(ctx)
|
||||
defer func() {
|
||||
// stop the second runner
|
||||
commandsRunner2.Stop()
|
||||
// destroy the second client
|
||||
factory.Destroy("push-notification-client-2")
|
||||
}()
|
||||
|
||||
p("Waiting for MOVING notification on second client")
|
||||
matchNotif, fnd := tracker2.FindOrWaitForNotification("MOVING", 3*time.Minute)
|
||||
@@ -393,10 +387,15 @@ func TestPushNotifications(t *testing.T) {
|
||||
|
||||
p("MOVING notification test completed successfully")
|
||||
|
||||
p("Executing commands and collecting logs for analysis... This will take 30 seconds...")
|
||||
p("Executing commands and collecting logs for analysis... ")
|
||||
go commandsRunner.FireCommandsUntilStop(ctx)
|
||||
time.Sleep(30 * time.Second)
|
||||
go commandsRunner2.FireCommandsUntilStop(ctx)
|
||||
go commandsRunner3.FireCommandsUntilStop(ctx)
|
||||
time.Sleep(2 * time.Minute)
|
||||
commandsRunner.Stop()
|
||||
commandsRunner2.Stop()
|
||||
commandsRunner3.Stop()
|
||||
time.Sleep(1 * time.Minute)
|
||||
allLogsAnalysis := logCollector.GetAnalysis()
|
||||
trackerAnalysis := tracker.GetAnalysis()
|
||||
|
||||
@@ -437,33 +436,35 @@ func TestPushNotifications(t *testing.T) {
|
||||
e("No logs found for connection %d", connID)
|
||||
}
|
||||
}
|
||||
// checks are tracker >= logs since the tracker only tracks client1
|
||||
// logs include all clients (and some of them start logging even before all hooks are setup)
|
||||
// for example for idle connections if they receive a notification before the hook is setup
|
||||
// the action (i.e. relaxing timeouts) will be logged, but the notification will not be tracked and maybe wont be logged
|
||||
|
||||
// validate number of notifications in tracker matches number of notifications in logs
|
||||
// allow for more moving in the logs since we started a second client
|
||||
if trackerAnalysis.TotalNotifications > allLogsAnalysis.TotalNotifications {
|
||||
e("Expected %d or more notifications, got %d", trackerAnalysis.TotalNotifications, allLogsAnalysis.TotalNotifications)
|
||||
e("Expected at least %d or more notifications, got %d", trackerAnalysis.TotalNotifications, allLogsAnalysis.TotalNotifications)
|
||||
}
|
||||
|
||||
// and per type
|
||||
// allow for more moving in the logs since we started a second client
|
||||
if trackerAnalysis.MovingCount > allLogsAnalysis.MovingCount {
|
||||
e("Expected %d or more MOVING notifications, got %d", trackerAnalysis.MovingCount, allLogsAnalysis.MovingCount)
|
||||
e("Expected at least %d or more MOVING notifications, got %d", trackerAnalysis.MovingCount, allLogsAnalysis.MovingCount)
|
||||
}
|
||||
|
||||
if trackerAnalysis.MigratingCount != allLogsAnalysis.MigratingCount {
|
||||
e("Expected %d MIGRATING notifications, got %d", trackerAnalysis.MigratingCount, allLogsAnalysis.MigratingCount)
|
||||
if trackerAnalysis.MigratingCount > allLogsAnalysis.MigratingCount {
|
||||
e("Expected at least %d MIGRATING notifications, got %d", trackerAnalysis.MigratingCount, allLogsAnalysis.MigratingCount)
|
||||
}
|
||||
|
||||
if trackerAnalysis.MigratedCount != allLogsAnalysis.MigratedCount {
|
||||
e("Expected %d MIGRATED notifications, got %d", trackerAnalysis.MigratedCount, allLogsAnalysis.MigratedCount)
|
||||
if trackerAnalysis.MigratedCount > allLogsAnalysis.MigratedCount {
|
||||
e("Expected at least %d MIGRATED notifications, got %d", trackerAnalysis.MigratedCount, allLogsAnalysis.MigratedCount)
|
||||
}
|
||||
|
||||
if trackerAnalysis.FailingOverCount != allLogsAnalysis.FailingOverCount {
|
||||
e("Expected %d FAILING_OVER notifications, got %d", trackerAnalysis.FailingOverCount, allLogsAnalysis.FailingOverCount)
|
||||
if trackerAnalysis.FailingOverCount > allLogsAnalysis.FailingOverCount {
|
||||
e("Expected at least %d FAILING_OVER notifications, got %d", trackerAnalysis.FailingOverCount, allLogsAnalysis.FailingOverCount)
|
||||
}
|
||||
|
||||
if trackerAnalysis.FailedOverCount != allLogsAnalysis.FailedOverCount {
|
||||
e("Expected %d FAILED_OVER notifications, got %d", trackerAnalysis.FailedOverCount, allLogsAnalysis.FailedOverCount)
|
||||
if trackerAnalysis.FailedOverCount > allLogsAnalysis.FailedOverCount {
|
||||
e("Expected at least %d FAILED_OVER notifications, got %d", trackerAnalysis.FailedOverCount, allLogsAnalysis.FailedOverCount)
|
||||
}
|
||||
|
||||
if trackerAnalysis.UnexpectedNotificationCount != allLogsAnalysis.UnexpectedCount {
|
||||
@@ -471,11 +472,11 @@ func TestPushNotifications(t *testing.T) {
|
||||
}
|
||||
|
||||
// unrelaxed (and relaxed) after moving wont be tracked by the hook, so we have to exclude it
|
||||
if trackerAnalysis.UnrelaxedTimeoutCount != allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving {
|
||||
e("Expected %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount)
|
||||
if trackerAnalysis.UnrelaxedTimeoutCount > allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving {
|
||||
e("Expected at least %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving)
|
||||
}
|
||||
if trackerAnalysis.RelaxedTimeoutCount != allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount {
|
||||
e("Expected %d relaxed timeouts, got %d", trackerAnalysis.RelaxedTimeoutCount, allLogsAnalysis.RelaxedTimeoutCount)
|
||||
if trackerAnalysis.RelaxedTimeoutCount > allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount {
|
||||
e("Expected at least %d relaxed timeouts, got %d", trackerAnalysis.RelaxedTimeoutCount, allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount)
|
||||
}
|
||||
|
||||
// validate all handoffs succeeded
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestStressPushNotifications(t *testing.T) {
|
||||
t.Skip("[STRESS][SKIP] Scenario tests require E2E_SCENARIO_TESTS=true")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 35*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 40*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Setup: Create fresh database and client factory for this test
|
||||
|
||||
@@ -20,7 +20,7 @@ func ТestTLSConfigurationsPushNotifications(t *testing.T) {
|
||||
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
var dump = true
|
||||
|
||||
@@ -18,21 +18,26 @@ var (
|
||||
ErrMaxHandoffRetriesReached = errors.New(logs.MaxHandoffRetriesReachedError())
|
||||
|
||||
// Configuration validation errors
|
||||
|
||||
// ErrInvalidHandoffRetries is returned when the number of handoff retries is invalid
|
||||
ErrInvalidHandoffRetries = errors.New(logs.InvalidHandoffRetriesError())
|
||||
)
|
||||
|
||||
// Integration errors
|
||||
var (
|
||||
// ErrInvalidClient is returned when the client does not support push notifications
|
||||
ErrInvalidClient = errors.New(logs.InvalidClientError())
|
||||
)
|
||||
|
||||
// Handoff errors
|
||||
var (
|
||||
// ErrHandoffQueueFull is returned when the handoff queue is full
|
||||
ErrHandoffQueueFull = errors.New(logs.HandoffQueueFullError())
|
||||
)
|
||||
|
||||
// Notification errors
|
||||
var (
|
||||
// ErrInvalidNotification is returned when a notification is in an invalid format
|
||||
ErrInvalidNotification = errors.New(logs.InvalidNotificationError())
|
||||
)
|
||||
|
||||
@@ -40,24 +45,32 @@ var (
|
||||
var (
|
||||
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
|
||||
// and should not be used until the handoff is complete
|
||||
ErrConnectionMarkedForHandoff = errors.New("" + logs.ConnectionMarkedForHandoffErrorMessage)
|
||||
ErrConnectionMarkedForHandoff = errors.New(logs.ConnectionMarkedForHandoffErrorMessage)
|
||||
// ErrConnectionMarkedForHandoffWithState is returned when a connection is marked for handoff
|
||||
// and should not be used until the handoff is complete
|
||||
ErrConnectionMarkedForHandoffWithState = errors.New(logs.ConnectionMarkedForHandoffErrorMessage + " with state")
|
||||
// ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
|
||||
ErrConnectionInvalidHandoffState = errors.New("" + logs.ConnectionInvalidHandoffStateErrorMessage)
|
||||
ErrConnectionInvalidHandoffState = errors.New(logs.ConnectionInvalidHandoffStateErrorMessage)
|
||||
)
|
||||
|
||||
// general errors
|
||||
// shutdown errors
|
||||
var (
|
||||
// ErrShutdown is returned when the maintnotifications manager is shutdown
|
||||
ErrShutdown = errors.New(logs.ShutdownError())
|
||||
)
|
||||
|
||||
// circuit breaker errors
|
||||
var (
|
||||
ErrCircuitBreakerOpen = errors.New("" + logs.CircuitBreakerOpenErrorMessage)
|
||||
// ErrCircuitBreakerOpen is returned when the circuit breaker is open
|
||||
ErrCircuitBreakerOpen = errors.New(logs.CircuitBreakerOpenErrorMessage)
|
||||
)
|
||||
|
||||
// circuit breaker configuration errors
|
||||
var (
|
||||
// ErrInvalidCircuitBreakerFailureThreshold is returned when the circuit breaker failure threshold is invalid
|
||||
ErrInvalidCircuitBreakerFailureThreshold = errors.New(logs.InvalidCircuitBreakerFailureThresholdError())
|
||||
ErrInvalidCircuitBreakerResetTimeout = errors.New(logs.InvalidCircuitBreakerResetTimeoutError())
|
||||
ErrInvalidCircuitBreakerMaxRequests = errors.New(logs.InvalidCircuitBreakerMaxRequestsError())
|
||||
// ErrInvalidCircuitBreakerResetTimeout is returned when the circuit breaker reset timeout is invalid
|
||||
ErrInvalidCircuitBreakerResetTimeout = errors.New(logs.InvalidCircuitBreakerResetTimeoutError())
|
||||
// ErrInvalidCircuitBreakerMaxRequests is returned when the circuit breaker max requests is invalid
|
||||
ErrInvalidCircuitBreakerMaxRequests = errors.New(logs.InvalidCircuitBreakerMaxRequestsError())
|
||||
)
|
||||
|
||||
@@ -175,8 +175,6 @@ func (hwm *handoffWorkerManager) onDemandWorker() {
|
||||
|
||||
// processHandoffRequest processes a single handoff request
|
||||
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
|
||||
// Remove from pending map
|
||||
defer hwm.pending.Delete(request.Conn.GetID())
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint))
|
||||
}
|
||||
@@ -228,16 +226,24 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
|
||||
}
|
||||
internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err))
|
||||
}
|
||||
// Schedule retry - keep connection in pending map until retry is queued
|
||||
time.AfterFunc(afterTime, func() {
|
||||
if err := hwm.queueHandoff(request.Conn); err != nil {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err))
|
||||
}
|
||||
// Failed to queue retry - remove from pending and close connection
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
hwm.closeConnFromRequest(context.Background(), request, err)
|
||||
} else {
|
||||
// Successfully queued retry - remove from pending (will be re-added by queueHandoff)
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
}
|
||||
})
|
||||
return
|
||||
} else {
|
||||
// Won't retry - remove from pending and close connection
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
go hwm.closeConnFromRequest(ctx, request, err)
|
||||
}
|
||||
|
||||
@@ -247,6 +253,9 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
|
||||
if hwm.poolHook.operationsManager != nil {
|
||||
hwm.poolHook.operationsManager.UntrackOperationWithConnID(seqID, connID)
|
||||
}
|
||||
} else {
|
||||
// Success - remove from pending map
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,6 +264,7 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
|
||||
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
|
||||
// Get handoff info atomically to prevent race conditions
|
||||
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
|
||||
|
||||
// on retries the connection will not be marked for handoff, but it will have retries > 0
|
||||
// if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff
|
||||
if !shouldHandoff && conn.HandoffRetries() == 0 {
|
||||
@@ -446,6 +456,8 @@ func (hwm *handoffWorkerManager) performHandoffInternal(
|
||||
// - set the connection as usable again
|
||||
// - clear the handoff state (shouldHandoff, endpoint, seqID)
|
||||
// - reset the handoff retries to 0
|
||||
// Note: Theoretically there may be a short window where the connection is in the pool
|
||||
// and IDLE (initConn completed) but still has handoff state set.
|
||||
conn.ClearHandoffState()
|
||||
internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint))
|
||||
|
||||
@@ -475,8 +487,16 @@ func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(cont
|
||||
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
|
||||
pooler := request.Pool
|
||||
conn := request.Conn
|
||||
|
||||
// Clear handoff state before closing
|
||||
conn.ClearHandoffState()
|
||||
|
||||
if pooler != nil {
|
||||
pooler.Remove(ctx, conn, err)
|
||||
// Use RemoveWithoutTurn instead of Remove to avoid freeing a turn that we don't have.
|
||||
// The handoff worker doesn't call Get(), so it doesn't have a turn to free.
|
||||
// Remove() is meant to be called after Get() and frees a turn.
|
||||
// RemoveWithoutTurn() removes and closes the connection without affecting the queue.
|
||||
pooler.RemoveWithoutTurn(ctx, conn, err)
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
|
||||
}
|
||||
|
||||
@@ -117,17 +117,15 @@ func (ph *PoolHook) ResetCircuitBreakers() {
|
||||
|
||||
// OnGet is called when a connection is retrieved from the pool
|
||||
func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
|
||||
// NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is
|
||||
// in a handoff state at the moment.
|
||||
|
||||
// Check if connection is usable (not in a handoff state)
|
||||
// Should not happen since the pool will not return a connection that is not usable.
|
||||
if !conn.IsUsable() {
|
||||
return false, ErrConnectionMarkedForHandoff
|
||||
// Check if connection is marked for handoff
|
||||
// This prevents using connections that have received MOVING notifications
|
||||
if conn.ShouldHandoff() {
|
||||
return false, ErrConnectionMarkedForHandoffWithState
|
||||
}
|
||||
|
||||
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
|
||||
if conn.ShouldHandoff() {
|
||||
// Check if connection is usable (not in UNUSABLE or CLOSED state)
|
||||
// This ensures we don't return connections that are currently being handed off or re-authenticated.
|
||||
if !conn.IsUsable() {
|
||||
return false, ErrConnectionMarkedForHandoff
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +39,9 @@ func (m *mockAddr) String() string { return m.addr }
|
||||
func createMockPoolConnection() *pool.Conn {
|
||||
mockNetConn := &mockNetConn{addr: "test:6379"}
|
||||
conn := pool.NewConn(mockNetConn)
|
||||
conn.SetUsable(true) // Make connection usable for testing
|
||||
conn.SetUsable(true) // Make connection usable for testing (transitions to IDLE)
|
||||
// Simulate real flow: connection is acquired (IDLE → IN_USE) before OnPut is called
|
||||
conn.SetUsed(true) // Transition to IN_USE state
|
||||
return conn
|
||||
}
|
||||
|
||||
@@ -73,6 +75,11 @@ func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) {
|
||||
mp.removedConnections[conn.GetID()] = true
|
||||
}
|
||||
|
||||
func (mp *mockPool) RemoveWithoutTurn(ctx context.Context, conn *pool.Conn, reason error) {
|
||||
// For mock pool, same behavior as Remove since we don't have a turn-based queue
|
||||
mp.Remove(ctx, conn, reason)
|
||||
}
|
||||
|
||||
// WasRemoved safely checks if a connection was removed from the pool
|
||||
func (mp *mockPool) WasRemoved(connID uint64) bool {
|
||||
mp.mu.Lock()
|
||||
@@ -167,7 +174,7 @@ func TestConnectionHook(t *testing.T) {
|
||||
select {
|
||||
case <-initConnCalled:
|
||||
// Good, initialization was called
|
||||
case <-time.After(1 * time.Second):
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Timeout waiting for initialization function to be called")
|
||||
}
|
||||
|
||||
@@ -231,14 +238,12 @@ func TestConnectionHook(t *testing.T) {
|
||||
t.Error("Connection should not be removed when no handoff needed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyEndpoint", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
@@ -385,10 +390,12 @@ func TestConnectionHook(t *testing.T) {
|
||||
// Simulate a pending handoff by marking for handoff and queuing
|
||||
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets ShouldHandoff=false, state=UNUSABLE)
|
||||
|
||||
ctx := context.Background()
|
||||
acceptCon, err := processor.OnGet(ctx, conn, false)
|
||||
// After MarkQueuedForHandoff, ShouldHandoff() returns false, so we get ErrConnectionMarkedForHandoff
|
||||
// (from IsUsable() check) instead of ErrConnectionMarkedForHandoffWithState
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||
}
|
||||
@@ -414,7 +421,7 @@ func TestConnectionHook(t *testing.T) {
|
||||
// Test adding to pending map
|
||||
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets ShouldHandoff=false, state=UNUSABLE)
|
||||
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
t.Error("Connection should be in pending map")
|
||||
@@ -423,8 +430,9 @@ func TestConnectionHook(t *testing.T) {
|
||||
// Test OnGet with pending handoff
|
||||
ctx := context.Background()
|
||||
acceptCon, err := processor.OnGet(ctx, conn, false)
|
||||
// After MarkQueuedForHandoff, ShouldHandoff() returns false, so we get ErrConnectionMarkedForHandoff
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
|
||||
t.Errorf("Should return ErrConnectionMarkedForHandoff for pending connection, got %v", err)
|
||||
}
|
||||
if acceptCon {
|
||||
t.Error("Should not accept connection with pending handoff")
|
||||
@@ -624,19 +632,20 @@ func TestConnectionHook(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a new connection without setting it usable
|
||||
// Create a new connection
|
||||
mockNetConn := &mockNetConn{addr: "test:6379"}
|
||||
conn := pool.NewConn(mockNetConn)
|
||||
|
||||
// Initially, connection should not be usable (not initialized)
|
||||
if conn.IsUsable() {
|
||||
t.Error("New connection should not be usable before initialization")
|
||||
// New connections in CREATED state are usable (they pass OnGet() before initialization)
|
||||
// The initialization happens AFTER OnGet() in the client code
|
||||
if !conn.IsUsable() {
|
||||
t.Error("New connection should be usable (CREATED state is usable)")
|
||||
}
|
||||
|
||||
// Simulate initialization by setting usable to true
|
||||
conn.SetUsable(true)
|
||||
// Simulate initialization by transitioning to IDLE
|
||||
conn.GetStateMachine().Transition(pool.StateIdle)
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should be usable after initialization")
|
||||
t.Error("Connection should be usable after initialization (IDLE state)")
|
||||
}
|
||||
|
||||
// OnGet should succeed for usable connection
|
||||
@@ -667,14 +676,16 @@ func TestConnectionHook(t *testing.T) {
|
||||
t.Error("Connection should be marked for handoff")
|
||||
}
|
||||
|
||||
// OnGet should fail for connection marked for handoff
|
||||
// OnGet should FAIL for connection marked for handoff
|
||||
// Even though the connection is still in a usable state, the metadata indicates
|
||||
// it should be handed off, so we reject it to prevent using a connection that
|
||||
// will be moved to a different endpoint
|
||||
acceptConn, err = processor.OnGet(ctx, conn, false)
|
||||
if err == nil {
|
||||
t.Error("OnGet should fail for connection marked for handoff")
|
||||
}
|
||||
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||
if err != ErrConnectionMarkedForHandoffWithState {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoffWithState, got %v", err)
|
||||
}
|
||||
if acceptConn {
|
||||
t.Error("Connection should not be accepted when marked for handoff")
|
||||
@@ -686,7 +697,7 @@ func TestConnectionHook(t *testing.T) {
|
||||
t.Errorf("OnPut should succeed: %v", err)
|
||||
}
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Error("Connection should be pooled after handoff")
|
||||
t.Errorf("Connection should be pooled after handoff (shouldPool=%v, shouldRemove=%v)", shouldPool, shouldRemove)
|
||||
}
|
||||
|
||||
// Wait for handoff to complete
|
||||
|
||||
Reference in New Issue
Block a user