mirror of
https://github.com/redis/go-redis.git
synced 2025-12-02 06:22:31 +03:00
polish state machine
This commit is contained in:
@@ -200,7 +200,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
|
||||
// This ensures we only acquire connections that are not actively in use
|
||||
stateMachine := conn.GetStateMachine()
|
||||
if stateMachine != nil {
|
||||
err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
_, err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
if err == nil {
|
||||
// Successfully acquired: connection was IDLE, now UNUSABLE
|
||||
acquired = true
|
||||
|
||||
@@ -30,7 +30,7 @@ func TestReAuthOnlyWhenIdle(t *testing.T) {
|
||||
}
|
||||
|
||||
// Try to transition to UNUSABLE (for reauth) - should fail
|
||||
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
_, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
if err == nil {
|
||||
t.Error("Expected error when trying to transition IN_USE → UNUSABLE, but got none")
|
||||
}
|
||||
@@ -51,7 +51,7 @@ func TestReAuthOnlyWhenIdle(t *testing.T) {
|
||||
}
|
||||
|
||||
// Now try to transition to UNUSABLE - should succeed
|
||||
err = cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
_, err = cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to transition IDLE → UNUSABLE: %v", err)
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func TestReAuthWaitsForConnectionToBeIdle(t *testing.T) {
|
||||
default:
|
||||
reAuthAttempts.Add(1)
|
||||
// Try to atomically transition from IDLE to UNUSABLE
|
||||
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
_, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
if err == nil {
|
||||
// Successfully acquired
|
||||
acquired = true
|
||||
@@ -185,7 +185,7 @@ func TestConcurrentReAuthAndUsage(t *testing.T) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 50; i++ {
|
||||
// Try to acquire for re-auth
|
||||
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
_, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
if err == nil {
|
||||
reAuthCount.Add(1)
|
||||
// Simulate re-auth work
|
||||
@@ -228,7 +228,7 @@ func TestReAuthRespectsClosed(t *testing.T) {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
|
||||
// Try to transition to UNUSABLE - should fail
|
||||
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
_, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||
if err == nil {
|
||||
t.Error("Expected error when trying to transition CLOSED → UNUSABLE, but got none")
|
||||
}
|
||||
|
||||
@@ -167,7 +167,7 @@ func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
|
||||
if new {
|
||||
// Trying to make usable - transition from UNUSABLE to IDLE
|
||||
// This should only work from UNUSABLE or INITIALIZING states
|
||||
err := cn.stateMachine.TryTransition(
|
||||
_, err := cn.stateMachine.TryTransition(
|
||||
[]ConnState{StateInitializing, StateUnusable},
|
||||
StateIdle,
|
||||
)
|
||||
@@ -175,7 +175,7 @@ func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
|
||||
} else {
|
||||
// Trying to make unusable - transition from IDLE to UNUSABLE
|
||||
// This is typically for acquiring the connection for background operations
|
||||
err := cn.stateMachine.TryTransition(
|
||||
_, err := cn.stateMachine.TryTransition(
|
||||
[]ConnState{StateIdle},
|
||||
StateUnusable,
|
||||
)
|
||||
@@ -200,10 +200,13 @@ func (cn *Conn) IsUsable() bool {
|
||||
|
||||
// SetUsable sets the usable flag for the connection (lock-free).
|
||||
//
|
||||
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// This should be called to mark a connection as usable after initialization or
|
||||
// to release it after a background operation completes.
|
||||
//
|
||||
// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions.
|
||||
// Prefer CompareAndSwapUsed() when acquiring exclusive access to avoid race conditions.
|
||||
func (cn *Conn) SetUsable(usable bool) {
|
||||
if usable {
|
||||
// Transition to IDLE state (ready to be acquired)
|
||||
@@ -226,6 +229,9 @@ func (cn *Conn) IsInited() bool {
|
||||
|
||||
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
|
||||
//
|
||||
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// This is the preferred method for acquiring a connection from the pool, as it
|
||||
// ensures that only one goroutine marks the connection as used.
|
||||
//
|
||||
@@ -242,17 +248,20 @@ func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
|
||||
|
||||
if !old && new {
|
||||
// Acquiring: IDLE → IN_USE
|
||||
err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||
_, err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||
return err == nil
|
||||
} else {
|
||||
// Releasing: IN_USE → IDLE
|
||||
err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle)
|
||||
_, err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle)
|
||||
return err == nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsUsed returns true if the connection is currently in use (lock-free).
|
||||
//
|
||||
// Deprecated: Use GetStateMachine().GetState() == StateInUse directly for better clarity.
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// A connection is "used" when it has been retrieved from the pool and is
|
||||
// actively processing a command. Background operations (like re-auth) should
|
||||
// wait until the connection is not used before executing commands.
|
||||
@@ -526,28 +535,32 @@ func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) err
|
||||
// Wait for and transition to INITIALIZING state - this prevents concurrent initializations
|
||||
// Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth)
|
||||
// If another goroutine is initializing, we'll wait for it to finish
|
||||
err := cn.stateMachine.AwaitAndTransition(
|
||||
ctx,
|
||||
// Add 1ms timeout to prevent indefinite blocking
|
||||
waitCtx, cancel := context.WithTimeout(ctx, time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
finalState, err := cn.stateMachine.AwaitAndTransition(
|
||||
waitCtx,
|
||||
[]ConnState{StateCreated, StateIdle, StateUnusable},
|
||||
StateInitializing,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot initialize connection from state %s: %w", cn.stateMachine.GetState(), err)
|
||||
return fmt.Errorf("cannot initialize connection from state %s: %w", finalState, err)
|
||||
}
|
||||
|
||||
// Replace the underlying connection
|
||||
cn.SetNetConn(netConn)
|
||||
|
||||
// Execute initialization
|
||||
// NOTE: ExecuteInitConn (via baseClient.initConn) will transition to IDLE on success
|
||||
// or CLOSED on failure. We don't need to do it here.
|
||||
initErr := cn.ExecuteInitConn(ctx)
|
||||
if initErr != nil {
|
||||
// Initialization failed - transition to closed
|
||||
cn.stateMachine.Transition(StateClosed)
|
||||
// ExecuteInitConn already transitioned to CLOSED, just return the error
|
||||
return initErr
|
||||
}
|
||||
|
||||
// Initialization succeeded - transition to IDLE (ready to be acquired)
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
// ExecuteInitConn already transitioned to IDLE
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -577,7 +590,8 @@ func (cn *Conn) MarkQueuedForHandoff() error {
|
||||
}
|
||||
|
||||
// Mark connection as unusable while queued for handoff
|
||||
cn.SetUsable(false)
|
||||
// Use state machine directly instead of deprecated SetUsable
|
||||
cn.stateMachine.Transition(StateUnusable)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -606,7 +620,8 @@ func (cn *Conn) ClearHandoffState() {
|
||||
cn.handoffRetriesAtomic.Store(0)
|
||||
|
||||
// Mark connection as usable again
|
||||
cn.SetUsable(true)
|
||||
// Use state machine directly instead of deprecated SetUsable
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
}
|
||||
|
||||
// HasBufferedData safely checks if the connection has buffered data.
|
||||
|
||||
@@ -113,19 +113,20 @@ func (sm *ConnStateMachine) GetState() ConnState {
|
||||
}
|
||||
|
||||
// TryTransition attempts an immediate state transition without waiting.
|
||||
// Returns ErrInvalidStateTransition if current state is not in validFromStates.
|
||||
// Returns the current state after the transition attempt and an error if the transition failed.
|
||||
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||
// This is faster than AwaitAndTransition when you don't need to wait.
|
||||
// Uses compare-and-swap to atomically transition, preventing concurrent transitions.
|
||||
// This method does NOT wait - it fails immediately if the transition cannot be performed.
|
||||
//
|
||||
// Performance: Zero allocations on success path (hot path).
|
||||
func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) error {
|
||||
func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) (ConnState, error) {
|
||||
// Try each valid from state with CAS
|
||||
// This ensures only ONE goroutine can successfully transition at a time
|
||||
for _, fromState := range validFromStates {
|
||||
// Fast path: check if we're already in target state
|
||||
if fromState == targetState && sm.GetState() == targetState {
|
||||
return nil
|
||||
return targetState, nil
|
||||
}
|
||||
|
||||
// Try to atomically swap from fromState to targetState
|
||||
@@ -134,14 +135,15 @@ func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetSta
|
||||
// Success! We transitioned atomically
|
||||
// Notify any waiters
|
||||
sm.notifyWaiters()
|
||||
return nil
|
||||
return targetState, nil
|
||||
}
|
||||
}
|
||||
|
||||
// All CAS attempts failed - state is not valid for this transition
|
||||
// Return the current state so caller can decide what to do
|
||||
// Note: This error path allocates, but it's the exceptional case
|
||||
currentState := sm.GetState()
|
||||
return fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)",
|
||||
return currentState, fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)",
|
||||
ErrInvalidStateTransition, currentState, targetState, validFromStates)
|
||||
}
|
||||
|
||||
@@ -155,6 +157,8 @@ func (sm *ConnStateMachine) Transition(targetState ConnState) {
|
||||
|
||||
// AwaitAndTransition waits for the connection to reach one of the valid states,
|
||||
// then atomically transitions to the target state.
|
||||
// Returns the current state after the transition attempt and an error if the operation failed.
|
||||
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||
// Returns error if timeout expires or context is cancelled.
|
||||
//
|
||||
// This method implements FIFO fairness - the first caller to wait gets priority
|
||||
@@ -167,19 +171,19 @@ func (sm *ConnStateMachine) AwaitAndTransition(
|
||||
ctx context.Context,
|
||||
validFromStates []ConnState,
|
||||
targetState ConnState,
|
||||
) error {
|
||||
) (ConnState, error) {
|
||||
// Fast path: try immediate transition with CAS to prevent race conditions
|
||||
for _, fromState := range validFromStates {
|
||||
// Check if we're already in target state
|
||||
if fromState == targetState && sm.GetState() == targetState {
|
||||
return nil
|
||||
return targetState, nil
|
||||
}
|
||||
|
||||
// Try to atomically swap from fromState to targetState
|
||||
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
|
||||
// Success! We transitioned atomically
|
||||
sm.notifyWaiters()
|
||||
return nil
|
||||
return targetState, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +192,7 @@ func (sm *ConnStateMachine) AwaitAndTransition(
|
||||
|
||||
// Check if closed
|
||||
if currentState == StateClosed {
|
||||
return ErrStateMachineClosed
|
||||
return currentState, ErrStateMachineClosed
|
||||
}
|
||||
|
||||
// Slow path: need to wait for state change
|
||||
@@ -216,9 +220,10 @@ func (sm *ConnStateMachine) AwaitAndTransition(
|
||||
sm.mu.Lock()
|
||||
sm.waiters.Remove(elem)
|
||||
sm.mu.Unlock()
|
||||
return ctx.Err()
|
||||
return sm.GetState(), ctx.Err()
|
||||
case err := <-w.done:
|
||||
return err
|
||||
// Transition completed (or failed)
|
||||
return sm.GetState(), err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,27 +240,35 @@ func (sm *ConnStateMachine) notifyWaiters() {
|
||||
// Process waiters in FIFO order until no more can be processed
|
||||
// We loop instead of recursing to avoid stack overflow and mutex issues
|
||||
for {
|
||||
currentState := sm.GetState()
|
||||
processed := false
|
||||
|
||||
// Find the first waiter that can proceed
|
||||
for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() {
|
||||
w := elem.Value.(*waiter)
|
||||
|
||||
// Read current state inside the loop to get the latest value
|
||||
currentState := sm.GetState()
|
||||
|
||||
// Check if current state is valid for this waiter
|
||||
if _, valid := w.validStates[currentState]; valid {
|
||||
// Remove from queue
|
||||
// Remove from queue first
|
||||
sm.waiters.Remove(elem)
|
||||
|
||||
// Perform transition
|
||||
sm.state.Store(uint32(w.targetState))
|
||||
|
||||
// Notify waiter (non-blocking due to buffered channel)
|
||||
// Use CAS to ensure state hasn't changed since we checked
|
||||
// This prevents race condition where another thread changes state
|
||||
// between our check and our transition
|
||||
if sm.state.CompareAndSwap(uint32(currentState), uint32(w.targetState)) {
|
||||
// Successfully transitioned - notify waiter
|
||||
w.done <- nil
|
||||
|
||||
// Mark that we processed a waiter and break to check for more
|
||||
processed = true
|
||||
break
|
||||
} else {
|
||||
// State changed - re-add waiter to front of queue and retry
|
||||
sm.waiters.PushFront(w)
|
||||
// Continue to next iteration to re-read state
|
||||
processed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ func TestConnStateMachine_TryTransition(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(tt.initialState)
|
||||
|
||||
err := sm.TryTransition(tt.validStates, tt.targetState)
|
||||
_, err := sm.TryTransition(tt.validStates, tt.targetState)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("expected error but got none")
|
||||
@@ -99,7 +99,7 @@ func TestConnStateMachine_AwaitAndTransition_FastPath(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Fast path: already in valid state
|
||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func TestConnStateMachine_AwaitAndTransition_Timeout(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
// Wait for a state that will never come
|
||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
||||
if err == nil {
|
||||
t.Error("expected timeout error but got none")
|
||||
}
|
||||
@@ -150,7 +150,7 @@ func TestConnStateMachine_AwaitAndTransition_FIFO(t *testing.T) {
|
||||
startBarrier.Wait()
|
||||
|
||||
ctx := context.Background()
|
||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateIdle)
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateIdle)
|
||||
if err != nil {
|
||||
t.Errorf("waiter %d got error: %v", waiterID, err)
|
||||
return
|
||||
@@ -206,7 +206,7 @@ func TestConnStateMachine_ConcurrentAccess(t *testing.T) {
|
||||
|
||||
for j := 0; j < numIterations; j++ {
|
||||
// Try to transition from IDLE to UNUSABLE
|
||||
err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||
_, err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||
if err == nil {
|
||||
successCount.Add(1)
|
||||
// Transition back to IDLE
|
||||
@@ -287,7 +287,7 @@ func TestConnStateMachine_PreventsConcurrentInitialization(t *testing.T) {
|
||||
startBarrier.Wait()
|
||||
|
||||
// Try to transition to INITIALIZING
|
||||
err := sm.TryTransition([]ConnState{StateIdle}, StateInitializing)
|
||||
_, err := sm.TryTransition([]ConnState{StateIdle}, StateInitializing)
|
||||
if err == nil {
|
||||
successCount.Add(1)
|
||||
|
||||
@@ -353,7 +353,7 @@ func TestConnStateMachine_AwaitAndTransitionWaitsForInitialization(t *testing.T)
|
||||
ctx := context.Background()
|
||||
|
||||
// Try to transition to INITIALIZING - should wait if another is initializing
|
||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
||||
return
|
||||
@@ -419,7 +419,7 @@ func TestConnStateMachine_FIFOOrdering(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// This should queue in FIFO order
|
||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
||||
return
|
||||
@@ -482,7 +482,7 @@ func TestConnStateMachine_FIFOWithFastPath(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// This might use fast path (CAS) or slow path (queue)
|
||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
||||
return
|
||||
@@ -528,7 +528,7 @@ func BenchmarkConnStateMachine_TryTransition(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||
_, _ = sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||
sm.Transition(StateIdle)
|
||||
}
|
||||
}
|
||||
@@ -543,7 +543,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) {
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
// Test IDLE → IN_USE transition
|
||||
err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||
_, err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from IDLE to IN_USE: %v", err)
|
||||
}
|
||||
@@ -552,7 +552,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test IN_USE → IDLE transition
|
||||
err = sm.TryTransition([]ConnState{StateInUse}, StateIdle)
|
||||
_, err = sm.TryTransition([]ConnState{StateInUse}, StateIdle)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from IN_USE to IDLE: %v", err)
|
||||
}
|
||||
@@ -570,7 +570,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||
_, err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||
if err == nil {
|
||||
successCount.Add(1)
|
||||
}
|
||||
@@ -641,7 +641,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
// Test IDLE → UNUSABLE transition (for background operations)
|
||||
err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||
_, err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from IDLE to UNUSABLE: %v", err)
|
||||
}
|
||||
@@ -650,7 +650,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test UNUSABLE → IDLE transition (after background operation completes)
|
||||
err = sm.TryTransition([]ConnState{StateUnusable}, StateIdle)
|
||||
_, err = sm.TryTransition([]ConnState{StateUnusable}, StateIdle)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from UNUSABLE to IDLE: %v", err)
|
||||
}
|
||||
@@ -661,7 +661,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
|
||||
// Test that we can transition from IN_USE to UNUSABLE if needed
|
||||
// (e.g., for urgent handoff while connection is in use)
|
||||
sm.Transition(StateInUse)
|
||||
err = sm.TryTransition([]ConnState{StateInUse}, StateUnusable)
|
||||
_, err = sm.TryTransition([]ConnState{StateInUse}, StateUnusable)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from IN_USE to UNUSABLE: %v", err)
|
||||
}
|
||||
@@ -672,7 +672,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
|
||||
// Test UNUSABLE → INITIALIZING transition (for handoff)
|
||||
sm.Transition(StateIdle)
|
||||
sm.Transition(StateUnusable)
|
||||
err = sm.TryTransition([]ConnState{StateUnusable}, StateInitializing)
|
||||
_, err = sm.TryTransition([]ConnState{StateUnusable}, StateInitializing)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from UNUSABLE to INITIALIZING: %v", err)
|
||||
}
|
||||
|
||||
@@ -244,9 +244,9 @@ func (p *ConnPool) addIdleConn() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
// NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn()
|
||||
// when first acquired from the pool. Do NOT transition to IDLE here - that happens
|
||||
// after initialization completes.
|
||||
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
@@ -286,9 +286,9 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
// NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn()
|
||||
// when first used. Do NOT transition to IDLE here - that happens after initialization completes.
|
||||
// The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success)
|
||||
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
|
||||
_ = cn.Close()
|
||||
@@ -584,15 +584,17 @@ func (p *ConnPool) popIdle() (*Conn, error) {
|
||||
}
|
||||
attempts++
|
||||
|
||||
if cn.CompareAndSwapUsed(false, true) {
|
||||
if cn.IsUsable() {
|
||||
// Try to atomically transition to IN_USE using state machine
|
||||
// Accept both CREATED (uninitialized) and IDLE (initialized) states
|
||||
_, err := cn.GetStateMachine().TryTransition([]ConnState{StateCreated, StateIdle}, StateInUse)
|
||||
if err == nil {
|
||||
// Successfully acquired the connection
|
||||
p.idleConnsLen.Add(-1)
|
||||
break
|
||||
}
|
||||
cn.SetUsed(false)
|
||||
}
|
||||
|
||||
// Connection is not usable, put it back in the pool
|
||||
// Connection is not in a valid state (might be UNUSABLE for re-auth, INITIALIZING, etc.)
|
||||
// Put it back in the pool
|
||||
if p.cfg.PoolFIFO {
|
||||
// FIFO: put at end (will be picked up last since we pop from front)
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
@@ -661,6 +663,11 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
||||
var shouldCloseConn bool
|
||||
|
||||
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
|
||||
// Transition to IDLE state BEFORE adding to pool
|
||||
// This prevents race condition where another goroutine could acquire
|
||||
// a connection that's still in IN_USE state
|
||||
cn.GetStateMachine().Transition(StateIdle)
|
||||
|
||||
// unusable conns are expected to become usable at some point (background process is reconnecting them)
|
||||
// put them at the opposite end of the queue
|
||||
if !cn.IsUsable() {
|
||||
@@ -684,11 +691,6 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
||||
shouldCloseConn = true
|
||||
}
|
||||
|
||||
// if the connection is not going to be closed, mark it as not used
|
||||
if !shouldCloseConn {
|
||||
cn.SetUsed(false)
|
||||
}
|
||||
|
||||
p.freeTurn()
|
||||
|
||||
if shouldCloseConn {
|
||||
|
||||
@@ -44,6 +44,13 @@ func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) {
|
||||
if p.cn == nil {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
// NOTE: SingleConnPool is NOT thread-safe by design and is used in special scenarios:
|
||||
// - During initialization (connection is in INITIALIZING state)
|
||||
// - During re-authentication (connection is in UNUSABLE state)
|
||||
// - For transactions (connection might be in various states)
|
||||
// We use SetUsed() which forces the transition, rather than TryTransition() which
|
||||
// would fail if the connection is not in IDLE/CREATED state.
|
||||
p.cn.SetUsed(true)
|
||||
p.cn.SetUsedAt(time.Now())
|
||||
return p.cn, nil
|
||||
|
||||
133
redis.go
133
redis.go
@@ -366,28 +366,82 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
|
||||
}
|
||||
|
||||
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
if !cn.Inited.CompareAndSwap(false, true) {
|
||||
// This function is called in two scenarios:
|
||||
// 1. First-time init: Connection is in CREATED state (from pool.Get())
|
||||
// - We need to transition CREATED → INITIALIZING and do the initialization
|
||||
// - If another goroutine is already initializing, we WAIT for it to finish
|
||||
// 2. Re-initialization: Connection is in INITIALIZING state (from SetNetConnAndInitConn())
|
||||
// - We're already in INITIALIZING, so just proceed with initialization
|
||||
|
||||
currentState := cn.GetStateMachine().GetState()
|
||||
|
||||
// Fast path: Check if already initialized (IDLE or IN_USE)
|
||||
if currentState == pool.StateIdle || currentState == pool.StateInUse {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
|
||||
// If in CREATED state, try to transition to INITIALIZING
|
||||
if currentState == pool.StateCreated {
|
||||
finalState, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateCreated}, pool.StateInitializing)
|
||||
if err != nil {
|
||||
// Another goroutine is initializing or connection is in unexpected state
|
||||
// Check what state we're in now
|
||||
if finalState == pool.StateIdle || finalState == pool.StateInUse {
|
||||
// Already initialized by another goroutine
|
||||
return nil
|
||||
}
|
||||
|
||||
if finalState == pool.StateInitializing {
|
||||
// Another goroutine is initializing - WAIT for it to complete
|
||||
// Use AwaitAndTransition to wait for IDLE or IN_USE state
|
||||
// Add 1ms timeout to prevent indefinite blocking
|
||||
waitCtx, cancel := context.WithTimeout(ctx, time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
finalState, err := cn.GetStateMachine().AwaitAndTransition(
|
||||
waitCtx,
|
||||
[]pool.ConnState{pool.StateIdle, pool.StateInUse},
|
||||
pool.StateIdle, // Target is IDLE: a no-op when already IDLE; NOTE(review): an IN_USE connection would be moved to IDLE here
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Verify we're now initialized
|
||||
if finalState == pool.StateIdle || finalState == pool.StateInUse {
|
||||
return nil
|
||||
}
|
||||
// Unexpected state after waiting
|
||||
return fmt.Errorf("connection in unexpected state after initialization: %s", finalState)
|
||||
}
|
||||
|
||||
// Unexpected state (CLOSED, UNUSABLE, etc.)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// At this point, we're in INITIALIZING state and we own the initialization
|
||||
// If we fail, we must transition to CLOSED
|
||||
var initErr error
|
||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||
conn := newConn(c.opt, connPool, &c.hooksMixin)
|
||||
|
||||
username, password := "", ""
|
||||
if c.opt.StreamingCredentialsProvider != nil {
|
||||
credListener, err := c.streamingCredentialsManager.Listener(
|
||||
credListener, initErr := c.streamingCredentialsManager.Listener(
|
||||
cn,
|
||||
c.reAuthConnection(),
|
||||
c.onAuthenticationErr(),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create credentials listener: %w", err)
|
||||
if initErr != nil {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to create credentials listener: %w", initErr)
|
||||
}
|
||||
|
||||
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||
credentials, unsubscribeFromCredentialsProvider, initErr := c.opt.StreamingCredentialsProvider.
|
||||
Subscribe(credListener)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
||||
if initErr != nil {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", initErr)
|
||||
}
|
||||
|
||||
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
|
||||
@@ -395,9 +449,10 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
|
||||
username, password = credentials.BasicAuth()
|
||||
} else if c.opt.CredentialsProviderContext != nil {
|
||||
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get credentials from context provider: %w", err)
|
||||
username, password, initErr = c.opt.CredentialsProviderContext(ctx)
|
||||
if initErr != nil {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to get credentials from context provider: %w", initErr)
|
||||
}
|
||||
} else if c.opt.CredentialsProvider != nil {
|
||||
username, password = c.opt.CredentialsProvider()
|
||||
@@ -407,9 +462,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
|
||||
// for redis-server versions that do not support the HELLO command,
|
||||
// RESP2 will continue to be used.
|
||||
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
||||
if initErr = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); initErr == nil {
|
||||
// Authentication successful with HELLO command
|
||||
} else if !isRedisError(err) {
|
||||
} else if !isRedisError(initErr) {
|
||||
// When the server responds with the RESP protocol and the result is not a normal
|
||||
// execution result of the HELLO command, we consider it to be an indication that
|
||||
// the server does not support the HELLO command.
|
||||
@@ -417,20 +472,22 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
// or it could be DragonflyDB or a third-party redis-proxy. They all respond
|
||||
// with different error string results for unsupported commands, making it
|
||||
// difficult to rely on error strings to determine all results.
|
||||
return err
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return initErr
|
||||
} else if password != "" {
|
||||
// Try legacy AUTH command if HELLO failed
|
||||
if username != "" {
|
||||
err = conn.AuthACL(ctx, username, password).Err()
|
||||
initErr = conn.AuthACL(ctx, username, password).Err()
|
||||
} else {
|
||||
err = conn.Auth(ctx, password).Err()
|
||||
initErr = conn.Auth(ctx, password).Err()
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to authenticate: %w", err)
|
||||
if initErr != nil {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to authenticate: %w", initErr)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
||||
_, initErr = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
||||
if c.opt.DB > 0 {
|
||||
pipe.Select(ctx, c.opt.DB)
|
||||
}
|
||||
@@ -445,8 +502,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize connection options: %w", err)
|
||||
if initErr != nil {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to initialize connection options: %w", initErr)
|
||||
}
|
||||
|
||||
// Enable maintnotifications if maintnotifications are configured
|
||||
@@ -465,6 +523,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
if maintNotifHandshakeErr != nil {
|
||||
if !isRedisError(maintNotifHandshakeErr) {
|
||||
// if not redis error, fail the connection
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return maintNotifHandshakeErr
|
||||
}
|
||||
c.optLock.Lock()
|
||||
@@ -473,15 +532,16 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
case maintnotifications.ModeEnabled:
|
||||
// enabled mode, fail the connection
|
||||
c.optLock.Unlock()
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr)
|
||||
default: // will handle auto and any other
|
||||
internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr)
|
||||
c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled
|
||||
c.optLock.Unlock()
|
||||
// auto mode, disable maintnotifications and continue
|
||||
if err := c.disableMaintNotificationsUpgrades(); err != nil {
|
||||
if initErr := c.disableMaintNotificationsUpgrades(); initErr != nil {
|
||||
// Log error but continue - auto mode should be resilient
|
||||
internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", err)
|
||||
internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", initErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -505,22 +565,31 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
p.ClientSetInfo(ctx, WithLibraryVersion(libVer))
|
||||
// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
|
||||
// out of order responses later on.
|
||||
if _, err = p.Exec(ctx); err != nil && !isRedisError(err) {
|
||||
return err
|
||||
if _, initErr = p.Exec(ctx); initErr != nil && !isRedisError(initErr) {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return initErr
|
||||
}
|
||||
}
|
||||
|
||||
// mark the connection as usable and inited
|
||||
// once returned to the pool as idle, this connection can be used by other clients
|
||||
cn.SetUsable(true)
|
||||
cn.SetUsed(false)
|
||||
cn.Inited.Store(true)
|
||||
|
||||
// Set the connection initialization function for potential reconnections
|
||||
// This must be set before transitioning to IDLE so that handoff/reauth can use it
|
||||
cn.SetInitConnFunc(c.createInitConnFunc())
|
||||
|
||||
// Initialization succeeded - transition to IDLE state
|
||||
// This marks the connection as initialized and ready for use
|
||||
// NOTE: The connection is still owned by the calling goroutine at this point
|
||||
// and won't be available to other goroutines until it's Put() back into the pool
|
||||
cn.GetStateMachine().Transition(pool.StateIdle)
|
||||
|
||||
// Call OnConnect hook if configured
|
||||
// The connection is in IDLE state but still owned by this goroutine
|
||||
// If OnConnect needs to send commands, it can use the connection safely
|
||||
if c.opt.OnConnect != nil {
|
||||
return c.opt.OnConnect(ctx, conn)
|
||||
if initErr = c.opt.OnConnect(ctx, conn); initErr != nil {
|
||||
// OnConnect failed - transition to closed
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return initErr
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user