diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 72664b4d..761c9991 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -593,17 +593,40 @@ func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error { // This is called from OnPut hook, where the connection is typically in IN_USE state. // The pool will preserve the UNUSABLE state and not overwrite it with IDLE. func (cn *Conn) MarkQueuedForHandoff() error { - // Check if marked for handoff - if !cn.ShouldHandoff() { + // Get current handoff state + currentState := cn.handoffStateAtomic.Load() + if currentState == nil { return errors.New("connection was not marked for handoff") } + state := currentState.(*HandoffState) + if !state.ShouldHandoff { + return errors.New("connection was not marked for handoff") + } + + // Create new state with ShouldHandoff=false but preserve endpoint and seqID + // This prevents the connection from being queued multiple times while still + // allowing the worker to access the handoff metadata + newState := &HandoffState{ + ShouldHandoff: false, + Endpoint: state.Endpoint, // Preserve endpoint for handoff processing + SeqID: state.SeqID, // Preserve seqID for handoff processing + } + + // Atomic compare-and-swap to update state + if !cn.handoffStateAtomic.CompareAndSwap(currentState, newState) { + // State changed between load and CAS - retry or return error + return errors.New("handoff state changed during marking") + } + // Transition to UNUSABLE from either IN_USE (normal flow) or IDLE (edge cases/tests) // The connection is typically in IN_USE state when OnPut is called (normal Put flow) // But in some edge cases or tests, it might already be in IDLE state // The pool will detect this state change and preserve it (not overwrite with IDLE) _, err := cn.stateMachine.TryTransition([]ConnState{StateInUse, StateIdle}, StateUnusable) if err != nil { + // Restore the original state if transition fails + cn.handoffStateAtomic.Store(currentState) return fmt.Errorf("failed to mark connection as unusable: %w", err) } return nil diff --git a/maintnotifications/handoff_worker.go b/maintnotifications/handoff_worker.go index be15747c..e042e4c6 100644 --- a/maintnotifications/handoff_worker.go +++ b/maintnotifications/handoff_worker.go @@ -349,13 +349,7 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c newEndpoint := conn.GetHandoffEndpoint() if newEndpoint == "" { - // Empty endpoint means handoff to current endpoint (reconnect) - // Use the current connection's remote address - if conn.RemoteAddr() != nil { - newEndpoint = conn.RemoteAddr().String() - } else { - return false, ErrConnectionInvalidHandoffState - } + return false, ErrConnectionInvalidHandoffState } // Use circuit breaker to protect against failing endpoints diff --git a/maintnotifications/pool_hook_test.go b/maintnotifications/pool_hook_test.go index 2a3abb73..b11a8dbf 100644 --- a/maintnotifications/pool_hook_test.go +++ b/maintnotifications/pool_hook_test.go @@ -245,36 +245,35 @@ func TestConnectionHook(t *testing.T) { processor := NewPoolHook(baseDialer, "tcp", config, nil) defer processor.Shutdown(context.Background()) + // Create a mock pool that tracks removals + mockPool := &mockPool{removedConnections: make(map[uint64]bool)} + processor.SetPool(mockPool) + conn := createMockPoolConnection() if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint t.Fatalf("Failed to mark connection for handoff: %v", err) } - // Set a mock initialization function - conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { - return nil - }) - ctx := context.Background() shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) if err != nil { t.Errorf("OnPut should not error with empty endpoint: %v", err) } - // Should pool the connection (empty endpoint triggers handoff to current endpoint) + // Should pool the connection (handoff will be queued and fail in worker) if !shouldPool { t.Error("Connection should be pooled when handoff is queued") } if shouldRemove { - t.Error("Connection should not be removed when handoff is queued") + t.Error("Connection should not be removed immediately") } - // Wait for handoff to complete + // Wait for worker to process and fail time.Sleep(100 * time.Millisecond) - // After handoff completes, state should be cleared - if conn.ShouldHandoff() { - t.Error("Connection should not be marked for handoff after handoff completes") + // Connection should be removed from pool after handoff fails + if !mockPool.WasRemoved(conn.GetID()) { + t.Error("Connection should be removed from pool after empty endpoint handoff fails") } }) @@ -404,12 +403,14 @@ func TestConnectionHook(t *testing.T) { // Simulate a pending handoff by marking for handoff and queuing conn.MarkForHandoff("new-endpoint:6379", 12345) processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID - conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) + conn.MarkQueuedForHandoff() // Mark as queued (sets ShouldHandoff=false, state=UNUSABLE) ctx := context.Background() acceptCon, err := processor.OnGet(ctx, conn, false) - if err != ErrConnectionMarkedForHandoffWithState { - t.Errorf("Expected ErrConnectionMarkedForHandoffWithState, got %v", err) + // After MarkQueuedForHandoff, ShouldHandoff() returns false, so we get ErrConnectionMarkedForHandoff + // (from IsUsable() check) instead of ErrConnectionMarkedForHandoffWithState + if err != ErrConnectionMarkedForHandoff { + t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) } if acceptCon { t.Error("Connection should not be accepted when marked for handoff") @@ -433,7 +434,7 @@ func TestConnectionHook(t *testing.T) { // Test adding to pending map conn.MarkForHandoff("new-endpoint:6379", 12345) processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID - conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) + conn.MarkQueuedForHandoff() // Mark as queued (sets ShouldHandoff=false, state=UNUSABLE) if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending { t.Error("Connection should be in pending map") @@ -442,8 +443,9 @@ func TestConnectionHook(t *testing.T) { // Test OnGet with pending handoff ctx := context.Background() acceptCon, err := processor.OnGet(ctx, conn, false) - if err != ErrConnectionMarkedForHandoffWithState { - t.Errorf("Should return ErrConnectionMarkedForHandoffWithState for pending connection, got %v", err) + // After MarkQueuedForHandoff, ShouldHandoff() returns false, so we get ErrConnectionMarkedForHandoff + if err != ErrConnectionMarkedForHandoff { + t.Errorf("Should return ErrConnectionMarkedForHandoff for pending connection, got %v", err) } if acceptCon { t.Error("Should not accept connection with pending handoff")