mirror of
https://github.com/redis/go-redis.git
synced 2025-12-02 06:22:31 +03:00
fix handoff state when queued for handoff
This commit is contained in:
@@ -593,17 +593,40 @@ func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
|
|||||||
// This is called from OnPut hook, where the connection is typically in IN_USE state.
|
// This is called from OnPut hook, where the connection is typically in IN_USE state.
|
||||||
// The pool will preserve the UNUSABLE state and not overwrite it with IDLE.
|
// The pool will preserve the UNUSABLE state and not overwrite it with IDLE.
|
||||||
func (cn *Conn) MarkQueuedForHandoff() error {
|
func (cn *Conn) MarkQueuedForHandoff() error {
|
||||||
// Check if marked for handoff
|
// Get current handoff state
|
||||||
if !cn.ShouldHandoff() {
|
currentState := cn.handoffStateAtomic.Load()
|
||||||
|
if currentState == nil {
|
||||||
return errors.New("connection was not marked for handoff")
|
return errors.New("connection was not marked for handoff")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state := currentState.(*HandoffState)
|
||||||
|
if !state.ShouldHandoff {
|
||||||
|
return errors.New("connection was not marked for handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new state with ShouldHandoff=false but preserve endpoint and seqID
|
||||||
|
// This prevents the connection from being queued multiple times while still
|
||||||
|
// allowing the worker to access the handoff metadata
|
||||||
|
newState := &HandoffState{
|
||||||
|
ShouldHandoff: false,
|
||||||
|
Endpoint: state.Endpoint, // Preserve endpoint for handoff processing
|
||||||
|
SeqID: state.SeqID, // Preserve seqID for handoff processing
|
||||||
|
}
|
||||||
|
|
||||||
|
// Atomic compare-and-swap to update state
|
||||||
|
if !cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
||||||
|
// State changed between load and CAS - return an error rather than retrying
|
||||||
|
return errors.New("handoff state changed during marking")
|
||||||
|
}
|
||||||
|
|
||||||
// Transition to UNUSABLE from either IN_USE (normal flow) or IDLE (edge cases/tests)
|
// Transition to UNUSABLE from either IN_USE (normal flow) or IDLE (edge cases/tests)
|
||||||
// The connection is typically in IN_USE state when OnPut is called (normal Put flow)
|
// The connection is typically in IN_USE state when OnPut is called (normal Put flow)
|
||||||
// But in some edge cases or tests, it might already be in IDLE state
|
// But in some edge cases or tests, it might already be in IDLE state
|
||||||
// The pool will detect this state change and preserve it (not overwrite with IDLE)
|
// The pool will detect this state change and preserve it (not overwrite with IDLE)
|
||||||
_, err := cn.stateMachine.TryTransition([]ConnState{StateInUse, StateIdle}, StateUnusable)
|
_, err := cn.stateMachine.TryTransition([]ConnState{StateInUse, StateIdle}, StateUnusable)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Restore the original state if transition fails
|
||||||
|
cn.handoffStateAtomic.Store(currentState)
|
||||||
return fmt.Errorf("failed to mark connection as unusable: %w", err)
|
return fmt.Errorf("failed to mark connection as unusable: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -349,13 +349,7 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c
|
|||||||
|
|
||||||
newEndpoint := conn.GetHandoffEndpoint()
|
newEndpoint := conn.GetHandoffEndpoint()
|
||||||
if newEndpoint == "" {
|
if newEndpoint == "" {
|
||||||
// Empty endpoint means handoff to current endpoint (reconnect)
|
return false, ErrConnectionInvalidHandoffState
|
||||||
// Use the current connection's remote address
|
|
||||||
if conn.RemoteAddr() != nil {
|
|
||||||
newEndpoint = conn.RemoteAddr().String()
|
|
||||||
} else {
|
|
||||||
return false, ErrConnectionInvalidHandoffState
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use circuit breaker to protect against failing endpoints
|
// Use circuit breaker to protect against failing endpoints
|
||||||
|
|||||||
@@ -245,36 +245,35 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||||
defer processor.Shutdown(context.Background())
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
// Create a mock pool that tracks removals
|
||||||
|
mockPool := &mockPool{removedConnections: make(map[uint64]bool)}
|
||||||
|
processor.SetPool(mockPool)
|
||||||
|
|
||||||
conn := createMockPoolConnection()
|
conn := createMockPoolConnection()
|
||||||
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
|
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
|
||||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set a mock initialization function
|
|
||||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("OnPut should not error with empty endpoint: %v", err)
|
t.Errorf("OnPut should not error with empty endpoint: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should pool the connection (empty endpoint triggers handoff to current endpoint)
|
// Should pool the connection (handoff will be queued and fail in worker)
|
||||||
if !shouldPool {
|
if !shouldPool {
|
||||||
t.Error("Connection should be pooled when handoff is queued")
|
t.Error("Connection should be pooled when handoff is queued")
|
||||||
}
|
}
|
||||||
if shouldRemove {
|
if shouldRemove {
|
||||||
t.Error("Connection should not be removed when handoff is queued")
|
t.Error("Connection should not be removed immediately")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for handoff to complete
|
// Wait for worker to process and fail
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
// After handoff completes, state should be cleared
|
// Connection should be removed from pool after handoff fails
|
||||||
if conn.ShouldHandoff() {
|
if !mockPool.WasRemoved(conn.GetID()) {
|
||||||
t.Error("Connection should not be marked for handoff after handoff completes")
|
t.Error("Connection should be removed from pool after empty endpoint handoff fails")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -404,12 +403,14 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
// Simulate a pending handoff by marking for handoff and queuing
|
// Simulate a pending handoff by marking for handoff and queuing
|
||||||
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||||
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
conn.MarkQueuedForHandoff() // Mark as queued (sets ShouldHandoff=false, state=UNUSABLE)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
acceptCon, err := processor.OnGet(ctx, conn, false)
|
acceptCon, err := processor.OnGet(ctx, conn, false)
|
||||||
if err != ErrConnectionMarkedForHandoffWithState {
|
// After MarkQueuedForHandoff, ShouldHandoff() returns false, so we get ErrConnectionMarkedForHandoff
|
||||||
t.Errorf("Expected ErrConnectionMarkedForHandoffWithState, got %v", err)
|
// (from IsUsable() check) instead of ErrConnectionMarkedForHandoffWithState
|
||||||
|
if err != ErrConnectionMarkedForHandoff {
|
||||||
|
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||||
}
|
}
|
||||||
if acceptCon {
|
if acceptCon {
|
||||||
t.Error("Connection should not be accepted when marked for handoff")
|
t.Error("Connection should not be accepted when marked for handoff")
|
||||||
@@ -433,7 +434,7 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
// Test adding to pending map
|
// Test adding to pending map
|
||||||
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||||
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
conn.MarkQueuedForHandoff() // Mark as queued (sets ShouldHandoff=false, state=UNUSABLE)
|
||||||
|
|
||||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||||
t.Error("Connection should be in pending map")
|
t.Error("Connection should be in pending map")
|
||||||
@@ -442,8 +443,9 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
// Test OnGet with pending handoff
|
// Test OnGet with pending handoff
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
acceptCon, err := processor.OnGet(ctx, conn, false)
|
acceptCon, err := processor.OnGet(ctx, conn, false)
|
||||||
if err != ErrConnectionMarkedForHandoffWithState {
|
// After MarkQueuedForHandoff, ShouldHandoff() returns false, so we get ErrConnectionMarkedForHandoff
|
||||||
t.Errorf("Should return ErrConnectionMarkedForHandoffWithState for pending connection, got %v", err)
|
if err != ErrConnectionMarkedForHandoff {
|
||||||
|
t.Errorf("Should return ErrConnectionMarkedForHandoff for pending connection, got %v", err)
|
||||||
}
|
}
|
||||||
if acceptCon {
|
if acceptCon {
|
||||||
t.Error("Should not accept connection with pending handoff")
|
t.Error("Should not accept connection with pending handoff")
|
||||||
|
|||||||
Reference in New Issue
Block a user