1
0
mirror of https://github.com/redis/go-redis.git synced 2025-12-02 06:22:31 +03:00

improve reauth state management. fix tests

This commit is contained in:
Nedyalko Dyakov
2025-10-24 15:05:54 +03:00
parent 92433e6f2a
commit 21bd243bf5
3 changed files with 24 additions and 46 deletions

View File

@@ -179,40 +179,27 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
r.workers <- struct{}{}
}()
var err error
timeout := time.After(r.reAuthTimeout)
// Create timeout context for connection acquisition
// This prevents indefinite waiting if the connection is stuck
ctx, cancel := context.WithTimeout(context.Background(), r.reAuthTimeout)
defer cancel()
// Try to acquire the connection for re-authentication
// We need to ensure the connection is IDLE (not IN_USE) before transitioning to UNUSABLE
// This prevents re-authentication from interfering with active commands
const baseDelay = 10 * time.Microsecond
acquired := false
attempt := 0
for !acquired {
select {
case <-timeout:
// Timeout occurred, cannot acquire connection
err = pool.ErrConnUnusableTimeout
reAuthFn(err)
return
default:
// Try to atomically transition from IDLE to UNUSABLE
// This ensures we only acquire connections that are not actively in use
stateMachine := conn.GetStateMachine()
if stateMachine != nil {
_, err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
if err == nil {
// Successfully acquired: connection was IDLE, now UNUSABLE
acquired = true
}
}
if !acquired {
// Exponential backoff: 10, 20, 40, 80... up to 5120 microseconds
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
attempt++
}
}
// Use AwaitAndTransition to wait for the connection to become IDLE
stateMachine := conn.GetStateMachine()
if stateMachine == nil {
// No state machine - should not happen, but handle gracefully
reAuthFn(pool.ErrConnUnusableTimeout)
return
}
_, err := stateMachine.AwaitAndTransition(ctx, []pool.ConnState{pool.StateIdle}, pool.StateUnusable)
if err != nil {
// Timeout or other error occurred, cannot acquire connection
reAuthFn(err)
return
}
// safety first
@@ -222,10 +209,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
}
// Release the connection: transition from UNUSABLE back to IDLE
stateMachine := conn.GetStateMachine()
if stateMachine != nil {
stateMachine.Transition(pool.StateIdle)
}
stateMachine.Transition(pool.StateIdle)
}()
}

View File

@@ -256,12 +256,6 @@ func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
// Get handoff info atomically to prevent race conditions
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
// Special case: empty endpoint means clear handoff state
if endpoint == "" {
conn.ClearHandoffState()
return nil
}
// on retries the connection will not be marked for handoff, but it will have retries > 0
// if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff
if !shouldHandoff && conn.HandoffRetries() == 0 {

View File

@@ -391,8 +391,8 @@ func TestConnectionHook(t *testing.T) {
ctx := context.Background()
acceptCon, err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
if err != ErrConnectionMarkedForHandoffWithState {
t.Errorf("Expected ErrConnectionMarkedForHandoffWithState, got %v", err)
}
if acceptCon {
t.Error("Connection should not be accepted when marked for handoff")
@@ -425,8 +425,8 @@ func TestConnectionHook(t *testing.T) {
// Test OnGet with pending handoff
ctx := context.Background()
acceptCon, err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
if err != ErrConnectionMarkedForHandoffWithState {
t.Errorf("Should return ErrConnectionMarkedForHandoffWithState for pending connection, got %v", err)
}
if acceptCon {
t.Error("Should not accept connection with pending handoff")
@@ -678,8 +678,8 @@ func TestConnectionHook(t *testing.T) {
if err == nil {
t.Error("OnGet should fail for connection marked for handoff")
}
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
if err != ErrConnectionMarkedForHandoffWithState {
t.Errorf("Expected ErrConnectionMarkedForHandoffWithState, got %v", err)
}
if acceptConn {
t.Error("Connection should not be accepted when marked for handoff")