1
0
mirror of https://github.com/redis/go-redis.git synced 2025-12-02 06:22:31 +03:00

polish state machine

This commit is contained in:
Nedyalko Dyakov
2025-10-24 13:09:30 +03:00
parent 0a754660ef
commit 5721512a79
8 changed files with 215 additions and 109 deletions

View File

@@ -200,7 +200,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
// This ensures we only acquire connections that are not actively in use // This ensures we only acquire connections that are not actively in use
stateMachine := conn.GetStateMachine() stateMachine := conn.GetStateMachine()
if stateMachine != nil { if stateMachine != nil {
err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) _, err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
if err == nil { if err == nil {
// Successfully acquired: connection was IDLE, now UNUSABLE // Successfully acquired: connection was IDLE, now UNUSABLE
acquired = true acquired = true

View File

@@ -30,7 +30,7 @@ func TestReAuthOnlyWhenIdle(t *testing.T) {
} }
// Try to transition to UNUSABLE (for reauth) - should fail // Try to transition to UNUSABLE (for reauth) - should fail
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) _, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
if err == nil { if err == nil {
t.Error("Expected error when trying to transition IN_USE → UNUSABLE, but got none") t.Error("Expected error when trying to transition IN_USE → UNUSABLE, but got none")
} }
@@ -51,7 +51,7 @@ func TestReAuthOnlyWhenIdle(t *testing.T) {
} }
// Now try to transition to UNUSABLE - should succeed // Now try to transition to UNUSABLE - should succeed
err = cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) _, err = cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
if err != nil { if err != nil {
t.Errorf("Failed to transition IDLE → UNUSABLE: %v", err) t.Errorf("Failed to transition IDLE → UNUSABLE: %v", err)
} }
@@ -99,7 +99,7 @@ func TestReAuthWaitsForConnectionToBeIdle(t *testing.T) {
default: default:
reAuthAttempts.Add(1) reAuthAttempts.Add(1)
// Try to atomically transition from IDLE to UNUSABLE // Try to atomically transition from IDLE to UNUSABLE
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) _, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
if err == nil { if err == nil {
// Successfully acquired // Successfully acquired
acquired = true acquired = true
@@ -185,7 +185,7 @@ func TestConcurrentReAuthAndUsage(t *testing.T) {
defer wg.Done() defer wg.Done()
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
// Try to acquire for re-auth // Try to acquire for re-auth
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) _, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
if err == nil { if err == nil {
reAuthCount.Add(1) reAuthCount.Add(1)
// Simulate re-auth work // Simulate re-auth work
@@ -228,7 +228,7 @@ func TestReAuthRespectsClosed(t *testing.T) {
cn.GetStateMachine().Transition(pool.StateClosed) cn.GetStateMachine().Transition(pool.StateClosed)
// Try to transition to UNUSABLE - should fail // Try to transition to UNUSABLE - should fail
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) _, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
if err == nil { if err == nil {
t.Error("Expected error when trying to transition CLOSED → UNUSABLE, but got none") t.Error("Expected error when trying to transition CLOSED → UNUSABLE, but got none")
} }

View File

@@ -167,7 +167,7 @@ func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
if new { if new {
// Trying to make usable - transition from UNUSABLE to IDLE // Trying to make usable - transition from UNUSABLE to IDLE
// This should only work from UNUSABLE or INITIALIZING states // This should only work from UNUSABLE or INITIALIZING states
err := cn.stateMachine.TryTransition( _, err := cn.stateMachine.TryTransition(
[]ConnState{StateInitializing, StateUnusable}, []ConnState{StateInitializing, StateUnusable},
StateIdle, StateIdle,
) )
@@ -175,7 +175,7 @@ func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
} else { } else {
// Trying to make unusable - transition from IDLE to UNUSABLE // Trying to make unusable - transition from IDLE to UNUSABLE
// This is typically for acquiring the connection for background operations // This is typically for acquiring the connection for background operations
err := cn.stateMachine.TryTransition( _, err := cn.stateMachine.TryTransition(
[]ConnState{StateIdle}, []ConnState{StateIdle},
StateUnusable, StateUnusable,
) )
@@ -200,10 +200,13 @@ func (cn *Conn) IsUsable() bool {
// SetUsable sets the usable flag for the connection (lock-free). // SetUsable sets the usable flag for the connection (lock-free).
// //
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
// This method is kept for backwards compatibility.
//
// This should be called to mark a connection as usable after initialization or // This should be called to mark a connection as usable after initialization or
// to release it after a background operation completes. // to release it after a background operation completes.
// //
// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions. // Prefer CompareAndSwapUsed() when acquiring exclusive access to avoid race conditions.
func (cn *Conn) SetUsable(usable bool) { func (cn *Conn) SetUsable(usable bool) {
if usable { if usable {
// Transition to IDLE state (ready to be acquired) // Transition to IDLE state (ready to be acquired)
@@ -226,6 +229,9 @@ func (cn *Conn) IsInited() bool {
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free). // CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
// //
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
// This method is kept for backwards compatibility.
//
// This is the preferred method for acquiring a connection from the pool, as it // This is the preferred method for acquiring a connection from the pool, as it
// ensures that only one goroutine marks the connection as used. // ensures that only one goroutine marks the connection as used.
// //
@@ -242,17 +248,20 @@ func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
if !old && new { if !old && new {
// Acquiring: IDLE → IN_USE // Acquiring: IDLE → IN_USE
err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse) _, err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse)
return err == nil return err == nil
} else { } else {
// Releasing: IN_USE → IDLE // Releasing: IN_USE → IDLE
err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle) _, err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle)
return err == nil return err == nil
} }
} }
// IsUsed returns true if the connection is currently in use (lock-free). // IsUsed returns true if the connection is currently in use (lock-free).
// //
// Deprecated: Use GetStateMachine().GetState() == StateInUse directly for better clarity.
// This method is kept for backwards compatibility.
//
// A connection is "used" when it has been retrieved from the pool and is // A connection is "used" when it has been retrieved from the pool and is
// actively processing a command. Background operations (like re-auth) should // actively processing a command. Background operations (like re-auth) should
// wait until the connection is not used before executing commands. // wait until the connection is not used before executing commands.
@@ -526,28 +535,32 @@ func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) err
// Wait for and transition to INITIALIZING state - this prevents concurrent initializations // Wait for and transition to INITIALIZING state - this prevents concurrent initializations
// Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth) // Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth)
// If another goroutine is initializing, we'll wait for it to finish // If another goroutine is initializing, we'll wait for it to finish
err := cn.stateMachine.AwaitAndTransition( // Add 1ms timeout to prevent indefinite blocking
ctx, waitCtx, cancel := context.WithTimeout(ctx, time.Millisecond)
defer cancel()
finalState, err := cn.stateMachine.AwaitAndTransition(
waitCtx,
[]ConnState{StateCreated, StateIdle, StateUnusable}, []ConnState{StateCreated, StateIdle, StateUnusable},
StateInitializing, StateInitializing,
) )
if err != nil { if err != nil {
return fmt.Errorf("cannot initialize connection from state %s: %w", cn.stateMachine.GetState(), err) return fmt.Errorf("cannot initialize connection from state %s: %w", finalState, err)
} }
// Replace the underlying connection // Replace the underlying connection
cn.SetNetConn(netConn) cn.SetNetConn(netConn)
// Execute initialization // Execute initialization
// NOTE: ExecuteInitConn (via baseClient.initConn) will transition to IDLE on success
// or CLOSED on failure. We don't need to do it here.
initErr := cn.ExecuteInitConn(ctx) initErr := cn.ExecuteInitConn(ctx)
if initErr != nil { if initErr != nil {
// Initialization failed - transition to closed // ExecuteInitConn already transitioned to CLOSED, just return the error
cn.stateMachine.Transition(StateClosed)
return initErr return initErr
} }
// Initialization succeeded - transition to IDLE (ready to be acquired) // ExecuteInitConn already transitioned to IDLE
cn.stateMachine.Transition(StateIdle)
return nil return nil
} }
@@ -577,7 +590,8 @@ func (cn *Conn) MarkQueuedForHandoff() error {
} }
// Mark connection as unusable while queued for handoff // Mark connection as unusable while queued for handoff
cn.SetUsable(false) // Use state machine directly instead of deprecated SetUsable
cn.stateMachine.Transition(StateUnusable)
return nil return nil
} }
@@ -606,7 +620,8 @@ func (cn *Conn) ClearHandoffState() {
cn.handoffRetriesAtomic.Store(0) cn.handoffRetriesAtomic.Store(0)
// Mark connection as usable again // Mark connection as usable again
cn.SetUsable(true) // Use state machine directly instead of deprecated SetUsable
cn.stateMachine.Transition(StateIdle)
} }
// HasBufferedData safely checks if the connection has buffered data. // HasBufferedData safely checks if the connection has buffered data.

View File

@@ -113,19 +113,20 @@ func (sm *ConnStateMachine) GetState() ConnState {
} }
// TryTransition attempts an immediate state transition without waiting. // TryTransition attempts an immediate state transition without waiting.
// Returns ErrInvalidStateTransition if current state is not in validFromStates. // Returns the current state after the transition attempt and an error if the transition failed.
// The returned state is the CURRENT state (after the attempt), not the previous state.
// This is faster than AwaitAndTransition when you don't need to wait. // This is faster than AwaitAndTransition when you don't need to wait.
// Uses compare-and-swap to atomically transition, preventing concurrent transitions. // Uses compare-and-swap to atomically transition, preventing concurrent transitions.
// This method does NOT wait - it fails immediately if the transition cannot be performed. // This method does NOT wait - it fails immediately if the transition cannot be performed.
// //
// Performance: Zero allocations on success path (hot path). // Performance: Zero allocations on success path (hot path).
func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) error { func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) (ConnState, error) {
// Try each valid from state with CAS // Try each valid from state with CAS
// This ensures only ONE goroutine can successfully transition at a time // This ensures only ONE goroutine can successfully transition at a time
for _, fromState := range validFromStates { for _, fromState := range validFromStates {
// Fast path: check if we're already in target state // Fast path: check if we're already in target state
if fromState == targetState && sm.GetState() == targetState { if fromState == targetState && sm.GetState() == targetState {
return nil return targetState, nil
} }
// Try to atomically swap from fromState to targetState // Try to atomically swap from fromState to targetState
@@ -134,14 +135,15 @@ func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetSta
// Success! We transitioned atomically // Success! We transitioned atomically
// Notify any waiters // Notify any waiters
sm.notifyWaiters() sm.notifyWaiters()
return nil return targetState, nil
} }
} }
// All CAS attempts failed - state is not valid for this transition // All CAS attempts failed - state is not valid for this transition
// Return the current state so caller can decide what to do
// Note: This error path allocates, but it's the exceptional case // Note: This error path allocates, but it's the exceptional case
currentState := sm.GetState() currentState := sm.GetState()
return fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)", return currentState, fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)",
ErrInvalidStateTransition, currentState, targetState, validFromStates) ErrInvalidStateTransition, currentState, targetState, validFromStates)
} }
@@ -155,6 +157,8 @@ func (sm *ConnStateMachine) Transition(targetState ConnState) {
// AwaitAndTransition waits for the connection to reach one of the valid states, // AwaitAndTransition waits for the connection to reach one of the valid states,
// then atomically transitions to the target state. // then atomically transitions to the target state.
// Returns the current state after the transition attempt and an error if the operation failed.
// The returned state is the CURRENT state (after the attempt), not the previous state.
// Returns error if timeout expires or context is cancelled. // Returns error if timeout expires or context is cancelled.
// //
// This method implements FIFO fairness - the first caller to wait gets priority // This method implements FIFO fairness - the first caller to wait gets priority
@@ -167,19 +171,19 @@ func (sm *ConnStateMachine) AwaitAndTransition(
ctx context.Context, ctx context.Context,
validFromStates []ConnState, validFromStates []ConnState,
targetState ConnState, targetState ConnState,
) error { ) (ConnState, error) {
// Fast path: try immediate transition with CAS to prevent race conditions // Fast path: try immediate transition with CAS to prevent race conditions
for _, fromState := range validFromStates { for _, fromState := range validFromStates {
// Check if we're already in target state // Check if we're already in target state
if fromState == targetState && sm.GetState() == targetState { if fromState == targetState && sm.GetState() == targetState {
return nil return targetState, nil
} }
// Try to atomically swap from fromState to targetState // Try to atomically swap from fromState to targetState
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) { if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
// Success! We transitioned atomically // Success! We transitioned atomically
sm.notifyWaiters() sm.notifyWaiters()
return nil return targetState, nil
} }
} }
@@ -188,7 +192,7 @@ func (sm *ConnStateMachine) AwaitAndTransition(
// Check if closed // Check if closed
if currentState == StateClosed { if currentState == StateClosed {
return ErrStateMachineClosed return currentState, ErrStateMachineClosed
} }
// Slow path: need to wait for state change // Slow path: need to wait for state change
@@ -216,9 +220,10 @@ func (sm *ConnStateMachine) AwaitAndTransition(
sm.mu.Lock() sm.mu.Lock()
sm.waiters.Remove(elem) sm.waiters.Remove(elem)
sm.mu.Unlock() sm.mu.Unlock()
return ctx.Err() return sm.GetState(), ctx.Err()
case err := <-w.done: case err := <-w.done:
return err // Transition completed (or failed)
return sm.GetState(), err
} }
} }
@@ -235,27 +240,35 @@ func (sm *ConnStateMachine) notifyWaiters() {
// Process waiters in FIFO order until no more can be processed // Process waiters in FIFO order until no more can be processed
// We loop instead of recursing to avoid stack overflow and mutex issues // We loop instead of recursing to avoid stack overflow and mutex issues
for { for {
currentState := sm.GetState()
processed := false processed := false
// Find the first waiter that can proceed // Find the first waiter that can proceed
for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() { for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() {
w := elem.Value.(*waiter) w := elem.Value.(*waiter)
// Read current state inside the loop to get the latest value
currentState := sm.GetState()
// Check if current state is valid for this waiter // Check if current state is valid for this waiter
if _, valid := w.validStates[currentState]; valid { if _, valid := w.validStates[currentState]; valid {
// Remove from queue // Remove from queue first
sm.waiters.Remove(elem) sm.waiters.Remove(elem)
// Perform transition // Use CAS to ensure state hasn't changed since we checked
sm.state.Store(uint32(w.targetState)) // This prevents race condition where another thread changes state
// between our check and our transition
// Notify waiter (non-blocking due to buffered channel) if sm.state.CompareAndSwap(uint32(currentState), uint32(w.targetState)) {
w.done <- nil // Successfully transitioned - notify waiter
w.done <- nil
// Mark that we processed a waiter and break to check for more processed = true
processed = true break
break } else {
// State changed - re-add waiter to front of queue and retry
sm.waiters.PushFront(w)
// Continue to next iteration to re-read state
processed = true
break
}
} }
} }

View File

@@ -74,7 +74,7 @@ func TestConnStateMachine_TryTransition(t *testing.T) {
sm := NewConnStateMachine() sm := NewConnStateMachine()
sm.Transition(tt.initialState) sm.Transition(tt.initialState)
err := sm.TryTransition(tt.validStates, tt.targetState) _, err := sm.TryTransition(tt.validStates, tt.targetState)
if tt.expectError && err == nil { if tt.expectError && err == nil {
t.Error("expected error but got none") t.Error("expected error but got none")
@@ -99,7 +99,7 @@ func TestConnStateMachine_AwaitAndTransition_FastPath(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// Fast path: already in valid state // Fast path: already in valid state
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable) _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
if err != nil { if err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
@@ -117,7 +117,7 @@ func TestConnStateMachine_AwaitAndTransition_Timeout(t *testing.T) {
defer cancel() defer cancel()
// Wait for a state that will never come // Wait for a state that will never come
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable) _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
if err == nil { if err == nil {
t.Error("expected timeout error but got none") t.Error("expected timeout error but got none")
} }
@@ -150,7 +150,7 @@ func TestConnStateMachine_AwaitAndTransition_FIFO(t *testing.T) {
startBarrier.Wait() startBarrier.Wait()
ctx := context.Background() ctx := context.Background()
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateIdle) _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateIdle)
if err != nil { if err != nil {
t.Errorf("waiter %d got error: %v", waiterID, err) t.Errorf("waiter %d got error: %v", waiterID, err)
return return
@@ -206,7 +206,7 @@ func TestConnStateMachine_ConcurrentAccess(t *testing.T) {
for j := 0; j < numIterations; j++ { for j := 0; j < numIterations; j++ {
// Try to transition from IDLE to UNUSABLE // Try to transition from IDLE to UNUSABLE
err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable) _, err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
if err == nil { if err == nil {
successCount.Add(1) successCount.Add(1)
// Transition back to IDLE // Transition back to IDLE
@@ -287,7 +287,7 @@ func TestConnStateMachine_PreventsConcurrentInitialization(t *testing.T) {
startBarrier.Wait() startBarrier.Wait()
// Try to transition to INITIALIZING // Try to transition to INITIALIZING
err := sm.TryTransition([]ConnState{StateIdle}, StateInitializing) _, err := sm.TryTransition([]ConnState{StateIdle}, StateInitializing)
if err == nil { if err == nil {
successCount.Add(1) successCount.Add(1)
@@ -353,7 +353,7 @@ func TestConnStateMachine_AwaitAndTransitionWaitsForInitialization(t *testing.T)
ctx := context.Background() ctx := context.Background()
// Try to transition to INITIALIZING - should wait if another is initializing // Try to transition to INITIALIZING - should wait if another is initializing
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing) _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
if err != nil { if err != nil {
t.Errorf("Goroutine %d: failed to transition: %v", id, err) t.Errorf("Goroutine %d: failed to transition: %v", id, err)
return return
@@ -419,7 +419,7 @@ func TestConnStateMachine_FIFOOrdering(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// This should queue in FIFO order // This should queue in FIFO order
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing) _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
if err != nil { if err != nil {
t.Errorf("Goroutine %d: failed to transition: %v", id, err) t.Errorf("Goroutine %d: failed to transition: %v", id, err)
return return
@@ -482,7 +482,7 @@ func TestConnStateMachine_FIFOWithFastPath(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// This might use fast path (CAS) or slow path (queue) // This might use fast path (CAS) or slow path (queue)
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing) _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
if err != nil { if err != nil {
t.Errorf("Goroutine %d: failed to transition: %v", id, err) t.Errorf("Goroutine %d: failed to transition: %v", id, err)
return return
@@ -528,7 +528,7 @@ func BenchmarkConnStateMachine_TryTransition(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_ = sm.TryTransition([]ConnState{StateIdle}, StateUnusable) _, _ = sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
sm.Transition(StateIdle) sm.Transition(StateIdle)
} }
} }
@@ -543,7 +543,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) {
sm.Transition(StateIdle) sm.Transition(StateIdle)
// Test IDLE → IN_USE transition // Test IDLE → IN_USE transition
err := sm.TryTransition([]ConnState{StateIdle}, StateInUse) _, err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
if err != nil { if err != nil {
t.Errorf("failed to transition from IDLE to IN_USE: %v", err) t.Errorf("failed to transition from IDLE to IN_USE: %v", err)
} }
@@ -552,7 +552,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) {
} }
// Test IN_USE → IDLE transition // Test IN_USE → IDLE transition
err = sm.TryTransition([]ConnState{StateInUse}, StateIdle) _, err = sm.TryTransition([]ConnState{StateInUse}, StateIdle)
if err != nil { if err != nil {
t.Errorf("failed to transition from IN_USE to IDLE: %v", err) t.Errorf("failed to transition from IN_USE to IDLE: %v", err)
} }
@@ -570,7 +570,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
err := sm.TryTransition([]ConnState{StateIdle}, StateInUse) _, err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
if err == nil { if err == nil {
successCount.Add(1) successCount.Add(1)
} }
@@ -641,7 +641,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
sm.Transition(StateIdle) sm.Transition(StateIdle)
// Test IDLE → UNUSABLE transition (for background operations) // Test IDLE → UNUSABLE transition (for background operations)
err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable) _, err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
if err != nil { if err != nil {
t.Errorf("failed to transition from IDLE to UNUSABLE: %v", err) t.Errorf("failed to transition from IDLE to UNUSABLE: %v", err)
} }
@@ -650,7 +650,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
} }
// Test UNUSABLE → IDLE transition (after background operation completes) // Test UNUSABLE → IDLE transition (after background operation completes)
err = sm.TryTransition([]ConnState{StateUnusable}, StateIdle) _, err = sm.TryTransition([]ConnState{StateUnusable}, StateIdle)
if err != nil { if err != nil {
t.Errorf("failed to transition from UNUSABLE to IDLE: %v", err) t.Errorf("failed to transition from UNUSABLE to IDLE: %v", err)
} }
@@ -661,7 +661,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
// Test that we can transition from IN_USE to UNUSABLE if needed // Test that we can transition from IN_USE to UNUSABLE if needed
// (e.g., for urgent handoff while connection is in use) // (e.g., for urgent handoff while connection is in use)
sm.Transition(StateInUse) sm.Transition(StateInUse)
err = sm.TryTransition([]ConnState{StateInUse}, StateUnusable) _, err = sm.TryTransition([]ConnState{StateInUse}, StateUnusable)
if err != nil { if err != nil {
t.Errorf("failed to transition from IN_USE to UNUSABLE: %v", err) t.Errorf("failed to transition from IN_USE to UNUSABLE: %v", err)
} }
@@ -672,7 +672,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
// Test UNUSABLE → INITIALIZING transition (for handoff) // Test UNUSABLE → INITIALIZING transition (for handoff)
sm.Transition(StateIdle) sm.Transition(StateIdle)
sm.Transition(StateUnusable) sm.Transition(StateUnusable)
err = sm.TryTransition([]ConnState{StateUnusable}, StateInitializing) _, err = sm.TryTransition([]ConnState{StateUnusable}, StateInitializing)
if err != nil { if err != nil {
t.Errorf("failed to transition from UNUSABLE to INITIALIZING: %v", err) t.Errorf("failed to transition from UNUSABLE to INITIALIZING: %v", err)
} }

View File

@@ -244,9 +244,9 @@ func (p *ConnPool) addIdleConn() error {
return err return err
} }
// Mark connection as usable after successful creation // NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn()
// This is essential for normal pool operations // when first acquired from the pool. Do NOT transition to IDLE here - that happens
cn.SetUsable(true) // after initialization completes.
p.connsMu.Lock() p.connsMu.Lock()
defer p.connsMu.Unlock() defer p.connsMu.Unlock()
@@ -286,9 +286,9 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, err return nil, err
} }
// Mark connection as usable after successful creation // NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn()
// This is essential for normal pool operations // when first used. Do NOT transition to IDLE here - that happens after initialization completes.
cn.SetUsable(true) // The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success)
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) { if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
_ = cn.Close() _ = cn.Close()
@@ -584,15 +584,17 @@ func (p *ConnPool) popIdle() (*Conn, error) {
} }
attempts++ attempts++
if cn.CompareAndSwapUsed(false, true) { // Try to atomically transition to IN_USE using state machine
if cn.IsUsable() { // Accept both CREATED (uninitialized) and IDLE (initialized) states
p.idleConnsLen.Add(-1) _, err := cn.GetStateMachine().TryTransition([]ConnState{StateCreated, StateIdle}, StateInUse)
break if err == nil {
} // Successfully acquired the connection
cn.SetUsed(false) p.idleConnsLen.Add(-1)
break
} }
// Connection is not usable, put it back in the pool // Connection is not in a valid state (might be UNUSABLE for re-auth, INITIALIZING, etc.)
// Put it back in the pool
if p.cfg.PoolFIFO { if p.cfg.PoolFIFO {
// FIFO: put at end (will be picked up last since we pop from front) // FIFO: put at end (will be picked up last since we pop from front)
p.idleConns = append(p.idleConns, cn) p.idleConns = append(p.idleConns, cn)
@@ -661,6 +663,11 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
var shouldCloseConn bool var shouldCloseConn bool
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns { if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
// Transition to IDLE state BEFORE adding to pool
// This prevents race condition where another goroutine could acquire
// a connection that's still in IN_USE state
cn.GetStateMachine().Transition(StateIdle)
// unusable conns are expected to become usable at some point (background process is reconnecting them) // unusable conns are expected to become usable at some point (background process is reconnecting them)
// put them at the opposite end of the queue // put them at the opposite end of the queue
if !cn.IsUsable() { if !cn.IsUsable() {
@@ -684,11 +691,6 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
shouldCloseConn = true shouldCloseConn = true
} }
// if the connection is not going to be closed, mark it as not used
if !shouldCloseConn {
cn.SetUsed(false)
}
p.freeTurn() p.freeTurn()
if shouldCloseConn { if shouldCloseConn {

View File

@@ -44,6 +44,13 @@ func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) {
if p.cn == nil { if p.cn == nil {
return nil, ErrClosed return nil, ErrClosed
} }
// NOTE: SingleConnPool is NOT thread-safe by design and is used in special scenarios:
// - During initialization (connection is in INITIALIZING state)
// - During re-authentication (connection is in UNUSABLE state)
// - For transactions (connection might be in various states)
// We use SetUsed() which forces the transition, rather than TryTransition() which
// would fail if the connection is not in IDLE/CREATED state.
p.cn.SetUsed(true) p.cn.SetUsed(true)
p.cn.SetUsedAt(time.Now()) p.cn.SetUsedAt(time.Now())
return p.cn, nil return p.cn, nil

133
redis.go
View File

@@ -366,28 +366,82 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
} }
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if !cn.Inited.CompareAndSwap(false, true) { // This function is called in two scenarios:
// 1. First-time init: Connection is in CREATED state (from pool.Get())
// - We need to transition CREATED → INITIALIZING and do the initialization
// - If another goroutine is already initializing, we WAIT for it to finish
// 2. Re-initialization: Connection is in INITIALIZING state (from SetNetConnAndInitConn())
// - We're already in INITIALIZING, so just proceed with initialization
currentState := cn.GetStateMachine().GetState()
// Fast path: Check if already initialized (IDLE or IN_USE)
if currentState == pool.StateIdle || currentState == pool.StateInUse {
return nil return nil
} }
var err error
// If in CREATED state, try to transition to INITIALIZING
if currentState == pool.StateCreated {
finalState, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateCreated}, pool.StateInitializing)
if err != nil {
// Another goroutine is initializing or connection is in unexpected state
// Check what state we're in now
if finalState == pool.StateIdle || finalState == pool.StateInUse {
// Already initialized by another goroutine
return nil
}
if finalState == pool.StateInitializing {
// Another goroutine is initializing - WAIT for it to complete
// Use AwaitAndTransition to wait for IDLE or IN_USE state
// Add 1ms timeout to prevent indefinite blocking
waitCtx, cancel := context.WithTimeout(ctx, time.Millisecond)
defer cancel()
finalState, err := cn.GetStateMachine().AwaitAndTransition(
waitCtx,
[]pool.ConnState{pool.StateIdle, pool.StateInUse},
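pool.StateIdle, // Target is IDLE: a no-op when the connection is already IDLE. NOTE(review): IN_USE is also an accepted source state above, so this may move an IN_USE connection to IDLE - confirm that is intended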
)
if err != nil {
return err
}
// Verify we're now initialized
if finalState == pool.StateIdle || finalState == pool.StateInUse {
return nil
}
// Unexpected state after waiting
return fmt.Errorf("connection in unexpected state after initialization: %s", finalState)
}
// Unexpected state (CLOSED, UNUSABLE, etc.)
return err
}
}
// At this point, we're in INITIALIZING state and we own the initialization
// If we fail, we must transition to CLOSED
var initErr error
connPool := pool.NewSingleConnPool(c.connPool, cn) connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool, &c.hooksMixin) conn := newConn(c.opt, connPool, &c.hooksMixin)
username, password := "", "" username, password := "", ""
if c.opt.StreamingCredentialsProvider != nil { if c.opt.StreamingCredentialsProvider != nil {
credListener, err := c.streamingCredentialsManager.Listener( credListener, initErr := c.streamingCredentialsManager.Listener(
cn, cn,
c.reAuthConnection(), c.reAuthConnection(),
c.onAuthenticationErr(), c.onAuthenticationErr(),
) )
if err != nil { if initErr != nil {
return fmt.Errorf("failed to create credentials listener: %w", err) cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to create credentials listener: %w", initErr)
} }
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider. credentials, unsubscribeFromCredentialsProvider, initErr := c.opt.StreamingCredentialsProvider.
Subscribe(credListener) Subscribe(credListener)
if err != nil { if initErr != nil {
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err) cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to subscribe to streaming credentials: %w", initErr)
} }
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider) c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
@@ -395,9 +449,10 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
username, password = credentials.BasicAuth() username, password = credentials.BasicAuth()
} else if c.opt.CredentialsProviderContext != nil { } else if c.opt.CredentialsProviderContext != nil {
username, password, err = c.opt.CredentialsProviderContext(ctx) username, password, initErr = c.opt.CredentialsProviderContext(ctx)
if err != nil { if initErr != nil {
return fmt.Errorf("failed to get credentials from context provider: %w", err) cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to get credentials from context provider: %w", initErr)
} }
} else if c.opt.CredentialsProvider != nil { } else if c.opt.CredentialsProvider != nil {
username, password = c.opt.CredentialsProvider() username, password = c.opt.CredentialsProvider()
@@ -407,9 +462,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
// for redis-server versions that do not support the HELLO command, // for redis-server versions that do not support the HELLO command,
// RESP2 will continue to be used. // RESP2 will continue to be used.
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil { if initErr = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); initErr == nil {
// Authentication successful with HELLO command // Authentication successful with HELLO command
} else if !isRedisError(err) { } else if !isRedisError(initErr) {
// When the server responds with the RESP protocol and the result is not a normal // When the server responds with the RESP protocol and the result is not a normal
// execution result of the HELLO command, we consider it to be an indication that // execution result of the HELLO command, we consider it to be an indication that
// the server does not support the HELLO command. // the server does not support the HELLO command.
@@ -417,20 +472,22 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
// or it could be DragonflyDB or a third-party redis-proxy. They all respond // or it could be DragonflyDB or a third-party redis-proxy. They all respond
// with different error string results for unsupported commands, making it // with different error string results for unsupported commands, making it
// difficult to rely on error strings to determine all results. // difficult to rely on error strings to determine all results.
return err cn.GetStateMachine().Transition(pool.StateClosed)
return initErr
} else if password != "" { } else if password != "" {
// Try legacy AUTH command if HELLO failed // Try legacy AUTH command if HELLO failed
if username != "" { if username != "" {
err = conn.AuthACL(ctx, username, password).Err() initErr = conn.AuthACL(ctx, username, password).Err()
} else { } else {
err = conn.Auth(ctx, password).Err() initErr = conn.Auth(ctx, password).Err()
} }
if err != nil { if initErr != nil {
return fmt.Errorf("failed to authenticate: %w", err) cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to authenticate: %w", initErr)
} }
} }
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error { _, initErr = conn.Pipelined(ctx, func(pipe Pipeliner) error {
if c.opt.DB > 0 { if c.opt.DB > 0 {
pipe.Select(ctx, c.opt.DB) pipe.Select(ctx, c.opt.DB)
} }
@@ -445,8 +502,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return nil return nil
}) })
if err != nil { if initErr != nil {
return fmt.Errorf("failed to initialize connection options: %w", err) cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to initialize connection options: %w", initErr)
} }
// Enable maintnotifications if maintnotifications are configured // Enable maintnotifications if maintnotifications are configured
@@ -465,6 +523,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if maintNotifHandshakeErr != nil { if maintNotifHandshakeErr != nil {
if !isRedisError(maintNotifHandshakeErr) { if !isRedisError(maintNotifHandshakeErr) {
// if not redis error, fail the connection // if not redis error, fail the connection
cn.GetStateMachine().Transition(pool.StateClosed)
return maintNotifHandshakeErr return maintNotifHandshakeErr
} }
c.optLock.Lock() c.optLock.Lock()
@@ -473,15 +532,16 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
case maintnotifications.ModeEnabled: case maintnotifications.ModeEnabled:
// enabled mode, fail the connection // enabled mode, fail the connection
c.optLock.Unlock() c.optLock.Unlock()
cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr) return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr)
default: // will handle auto and any other default: // will handle auto and any other
internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr) internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr)
c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled
c.optLock.Unlock() c.optLock.Unlock()
// auto mode, disable maintnotifications and continue // auto mode, disable maintnotifications and continue
if err := c.disableMaintNotificationsUpgrades(); err != nil { if initErr := c.disableMaintNotificationsUpgrades(); initErr != nil {
// Log error but continue - auto mode should be resilient // Log error but continue - auto mode should be resilient
internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", err) internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", initErr)
} }
} }
} else { } else {
@@ -505,22 +565,31 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
p.ClientSetInfo(ctx, WithLibraryVersion(libVer)) p.ClientSetInfo(ctx, WithLibraryVersion(libVer))
// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid // Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
// out of order responses later on. // out of order responses later on.
if _, err = p.Exec(ctx); err != nil && !isRedisError(err) { if _, initErr = p.Exec(ctx); initErr != nil && !isRedisError(initErr) {
return err cn.GetStateMachine().Transition(pool.StateClosed)
return initErr
} }
} }
// mark the connection as usable and inited
// once returned to the pool as idle, this connection can be used by other clients
cn.SetUsable(true)
cn.SetUsed(false)
cn.Inited.Store(true)
// Set the connection initialization function for potential reconnections // Set the connection initialization function for potential reconnections
// This must be set before transitioning to IDLE so that handoff/reauth can use it
cn.SetInitConnFunc(c.createInitConnFunc()) cn.SetInitConnFunc(c.createInitConnFunc())
// Initialization succeeded - transition to IDLE state
// This marks the connection as initialized and ready for use
// NOTE: The connection is still owned by the calling goroutine at this point
// and won't be available to other goroutines until it's Put() back into the pool
cn.GetStateMachine().Transition(pool.StateIdle)
// Call OnConnect hook if configured
// The connection is in IDLE state but still owned by this goroutine
// If OnConnect needs to send commands, it can use the connection safely
if c.opt.OnConnect != nil { if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn) if initErr = c.opt.OnConnect(ctx, conn); initErr != nil {
// OnConnect failed - transition to closed
cn.GetStateMachine().Transition(pool.StateClosed)
return initErr
}
} }
return nil return nil