1
0
mirror of https://github.com/redis/go-redis.git synced 2025-12-03 18:31:14 +03:00

improve reauth state management. fix tests

This commit is contained in:
Nedyalko Dyakov
2025-10-24 15:05:54 +03:00
parent 92433e6f2a
commit 21bd243bf5
3 changed files with 24 additions and 46 deletions

View File

@@ -179,40 +179,27 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
r.workers <- struct{}{} r.workers <- struct{}{}
}() }()
var err error // Create timeout context for connection acquisition
timeout := time.After(r.reAuthTimeout) // This prevents indefinite waiting if the connection is stuck
ctx, cancel := context.WithTimeout(context.Background(), r.reAuthTimeout)
defer cancel()
// Try to acquire the connection for re-authentication // Try to acquire the connection for re-authentication
// We need to ensure the connection is IDLE (not IN_USE) before transitioning to UNUSABLE // We need to ensure the connection is IDLE (not IN_USE) before transitioning to UNUSABLE
// This prevents re-authentication from interfering with active commands // This prevents re-authentication from interfering with active commands
const baseDelay = 10 * time.Microsecond // Use AwaitAndTransition to wait for the connection to become IDLE
acquired := false stateMachine := conn.GetStateMachine()
attempt := 0 if stateMachine == nil {
for !acquired { // No state machine - should not happen, but handle gracefully
select { reAuthFn(pool.ErrConnUnusableTimeout)
case <-timeout: return
// Timeout occurred, cannot acquire connection }
err = pool.ErrConnUnusableTimeout
reAuthFn(err) _, err := stateMachine.AwaitAndTransition(ctx, []pool.ConnState{pool.StateIdle}, pool.StateUnusable)
return if err != nil {
default: // Timeout or other error occurred, cannot acquire connection
// Try to atomically transition from IDLE to UNUSABLE reAuthFn(err)
// This ensures we only acquire connections that are not actively in use return
stateMachine := conn.GetStateMachine()
if stateMachine != nil {
_, err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
if err == nil {
// Successfully acquired: connection was IDLE, now UNUSABLE
acquired = true
}
}
if !acquired {
// Exponential backoff: 10, 20, 40, 80... up to 5120 microseconds
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
attempt++
}
}
} }
// safety first // safety first
@@ -222,10 +209,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
} }
// Release the connection: transition from UNUSABLE back to IDLE // Release the connection: transition from UNUSABLE back to IDLE
stateMachine := conn.GetStateMachine() stateMachine.Transition(pool.StateIdle)
if stateMachine != nil {
stateMachine.Transition(pool.StateIdle)
}
}() }()
} }

View File

@@ -256,12 +256,6 @@ func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
// Get handoff info atomically to prevent race conditions // Get handoff info atomically to prevent race conditions
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo() shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
// Special case: empty endpoint means clear handoff state
if endpoint == "" {
conn.ClearHandoffState()
return nil
}
// on retries the connection will not be marked for handoff, but it will have retries > 0 // on retries the connection will not be marked for handoff, but it will have retries > 0
// if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff // if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff
if !shouldHandoff && conn.HandoffRetries() == 0 { if !shouldHandoff && conn.HandoffRetries() == 0 {

View File

@@ -391,8 +391,8 @@ func TestConnectionHook(t *testing.T) {
ctx := context.Background() ctx := context.Background()
acceptCon, err := processor.OnGet(ctx, conn, false) acceptCon, err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff { if err != ErrConnectionMarkedForHandoffWithState {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) t.Errorf("Expected ErrConnectionMarkedForHandoffWithState, got %v", err)
} }
if acceptCon { if acceptCon {
t.Error("Connection should not be accepted when marked for handoff") t.Error("Connection should not be accepted when marked for handoff")
@@ -425,8 +425,8 @@ func TestConnectionHook(t *testing.T) {
// Test OnGet with pending handoff // Test OnGet with pending handoff
ctx := context.Background() ctx := context.Background()
acceptCon, err := processor.OnGet(ctx, conn, false) acceptCon, err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff { if err != ErrConnectionMarkedForHandoffWithState {
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection") t.Errorf("Should return ErrConnectionMarkedForHandoffWithState for pending connection, got %v", err)
} }
if acceptCon { if acceptCon {
t.Error("Should not accept connection with pending handoff") t.Error("Should not accept connection with pending handoff")
@@ -678,8 +678,8 @@ func TestConnectionHook(t *testing.T) {
if err == nil { if err == nil {
t.Error("OnGet should fail for connection marked for handoff") t.Error("OnGet should fail for connection marked for handoff")
} }
if err != ErrConnectionMarkedForHandoff { if err != ErrConnectionMarkedForHandoffWithState {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) t.Errorf("Expected ErrConnectionMarkedForHandoffWithState, got %v", err)
} }
if acceptConn { if acceptConn {
t.Error("Connection should not be accepted when marked for handoff") t.Error("Connection should not be accepted when marked for handoff")