diff --git a/internal/auth/streaming/pool_hook.go b/internal/auth/streaming/pool_hook.go index b29d7905..f9b9dd2c 100644 --- a/internal/auth/streaming/pool_hook.go +++ b/internal/auth/streaming/pool_hook.go @@ -200,7 +200,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, // This ensures we only acquire connections that are not actively in use stateMachine := conn.GetStateMachine() if stateMachine != nil { - err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) + _, err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) if err == nil { // Successfully acquired: connection was IDLE, now UNUSABLE acquired = true diff --git a/internal/auth/streaming/pool_hook_state_test.go b/internal/auth/streaming/pool_hook_state_test.go index 3160f0f5..a4cfc328 100644 --- a/internal/auth/streaming/pool_hook_state_test.go +++ b/internal/auth/streaming/pool_hook_state_test.go @@ -30,7 +30,7 @@ func TestReAuthOnlyWhenIdle(t *testing.T) { } // Try to transition to UNUSABLE (for reauth) - should fail - err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) + _, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) if err == nil { t.Error("Expected error when trying to transition IN_USE → UNUSABLE, but got none") } @@ -51,7 +51,7 @@ func TestReAuthOnlyWhenIdle(t *testing.T) { } // Now try to transition to UNUSABLE - should succeed - err = cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) + _, err = cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) if err != nil { t.Errorf("Failed to transition IDLE → UNUSABLE: %v", err) } @@ -99,7 +99,7 @@ func TestReAuthWaitsForConnectionToBeIdle(t *testing.T) { default: reAuthAttempts.Add(1) // Try to atomically transition from IDLE to UNUSABLE - err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) + _, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) if err == nil { // Successfully acquired acquired = true @@ -185,7 +185,7 @@ func TestConcurrentReAuthAndUsage(t *testing.T) { defer wg.Done() for i := 0; i < 50; i++ { // Try to acquire for re-auth - err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) + _, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) if err == nil { reAuthCount.Add(1) // Simulate re-auth work @@ -228,7 +228,7 @@ func TestReAuthRespectsClosed(t *testing.T) { cn.GetStateMachine().Transition(pool.StateClosed) // Try to transition to UNUSABLE - should fail - err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) + _, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable) if err == nil { t.Error("Expected error when trying to transition CLOSED → UNUSABLE, but got none") } diff --git a/internal/pool/conn.go b/internal/pool/conn.go index cd3503d5..6755eda3 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -167,7 +167,7 @@ func (cn *Conn) CompareAndSwapUsable(old, new bool) bool { if new { // Trying to make usable - transition from UNUSABLE to IDLE // This should only work from UNUSABLE or INITIALIZING states - err := cn.stateMachine.TryTransition( + _, err := cn.stateMachine.TryTransition( []ConnState{StateInitializing, StateUnusable}, StateIdle, ) @@ -175,7 +175,7 @@ func (cn *Conn) CompareAndSwapUsable(old, new bool) bool { } else { // Trying to make unusable - transition from IDLE to UNUSABLE // This is typically for acquiring the connection for background operations - err := cn.stateMachine.TryTransition( + _, err := cn.stateMachine.TryTransition( []ConnState{StateIdle}, StateUnusable, ) @@ -200,10 +200,13 @@ func (cn *Conn) IsUsable() bool { // SetUsable sets the usable flag for the connection (lock-free). // +// Deprecated: Use GetStateMachine().Transition() directly for better state management. +// This method is kept for backwards compatibility. +// // This should be called to mark a connection as usable after initialization or // to release it after a background operation completes. // -// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions. +// Prefer CompareAndSwapUsed() when acquiring exclusive access to avoid race conditions. func (cn *Conn) SetUsable(usable bool) { if usable { // Transition to IDLE state (ready to be acquired) @@ -226,6 +229,9 @@ func (cn *Conn) IsInited() bool { // CompareAndSwapUsed atomically compares and swaps the used flag (lock-free). // +// Deprecated: Use GetStateMachine().TryTransition() directly for better state management. +// This method is kept for backwards compatibility. +// // This is the preferred method for acquiring a connection from the pool, as it // ensures that only one goroutine marks the connection as used. // @@ -242,17 +248,20 @@ func (cn *Conn) CompareAndSwapUsed(old, new bool) bool { if !old && new { // Acquiring: IDLE → IN_USE - err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse) + _, err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse) return err == nil } else { // Releasing: IN_USE → IDLE - err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle) + _, err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle) return err == nil } } // IsUsed returns true if the connection is currently in use (lock-free). // +// Deprecated: Use GetStateMachine().GetState() == StateInUse directly for better clarity. +// This method is kept for backwards compatibility. +// // A connection is "used" when it has been retrieved from the pool and is // actively processing a command. Background operations (like re-auth) should // wait until the connection is not used before executing commands. @@ -526,28 +535,32 @@ func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) err // Wait for and transition to INITIALIZING state - this prevents concurrent initializations // Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth) // If another goroutine is initializing, we'll wait for it to finish - err := cn.stateMachine.AwaitAndTransition( - ctx, + // Add 1ms timeout to prevent indefinite blocking + waitCtx, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + + finalState, err := cn.stateMachine.AwaitAndTransition( + waitCtx, []ConnState{StateCreated, StateIdle, StateUnusable}, StateInitializing, ) if err != nil { - return fmt.Errorf("cannot initialize connection from state %s: %w", cn.stateMachine.GetState(), err) + return fmt.Errorf("cannot initialize connection from state %s: %w", finalState, err) } // Replace the underlying connection cn.SetNetConn(netConn) // Execute initialization + // NOTE: ExecuteInitConn (via baseClient.initConn) will transition to IDLE on success + // or CLOSED on failure. We don't need to do it here. initErr := cn.ExecuteInitConn(ctx) if initErr != nil { - // Initialization failed - transition to closed - cn.stateMachine.Transition(StateClosed) + // ExecuteInitConn already transitioned to CLOSED, just return the error return initErr } - // Initialization succeeded - transition to IDLE (ready to be acquired) - cn.stateMachine.Transition(StateIdle) + // ExecuteInitConn already transitioned to IDLE return nil } @@ -577,7 +590,8 @@ func (cn *Conn) MarkQueuedForHandoff() error { } // Mark connection as unusable while queued for handoff - cn.SetUsable(false) + // Use state machine directly instead of deprecated SetUsable + cn.stateMachine.Transition(StateUnusable) return nil } @@ -606,7 +620,8 @@ func (cn *Conn) ClearHandoffState() { cn.handoffRetriesAtomic.Store(0) // Mark connection as usable again - cn.SetUsable(true) + // Use state machine directly instead of deprecated SetUsable + cn.stateMachine.Transition(StateIdle) } // HasBufferedData safely checks if the connection has buffered data. diff --git a/internal/pool/conn_state.go b/internal/pool/conn_state.go index 89fac85f..7e29fcdd 100644 --- a/internal/pool/conn_state.go +++ b/internal/pool/conn_state.go @@ -113,19 +113,20 @@ func (sm *ConnStateMachine) GetState() ConnState { } // TryTransition attempts an immediate state transition without waiting. -// Returns ErrInvalidStateTransition if current state is not in validFromStates. +// Returns the current state after the transition attempt and an error if the transition failed. +// The returned state is the CURRENT state (after the attempt), not the previous state. // This is faster than AwaitAndTransition when you don't need to wait. // Uses compare-and-swap to atomically transition, preventing concurrent transitions. // This method does NOT wait - it fails immediately if the transition cannot be performed. // // Performance: Zero allocations on success path (hot path). -func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) error { +func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) (ConnState, error) { // Try each valid from state with CAS // This ensures only ONE goroutine can successfully transition at a time for _, fromState := range validFromStates { // Fast path: check if we're already in target state if fromState == targetState && sm.GetState() == targetState { - return nil + return targetState, nil } // Try to atomically swap from fromState to targetState @@ -134,14 +135,15 @@ func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetSta // Success! We transitioned atomically // Notify any waiters sm.notifyWaiters() - return nil + return targetState, nil } } // All CAS attempts failed - state is not valid for this transition + // Return the current state so caller can decide what to do // Note: This error path allocates, but it's the exceptional case currentState := sm.GetState() - return fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)", + return currentState, fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)", ErrInvalidStateTransition, currentState, targetState, validFromStates) } @@ -155,6 +157,8 @@ func (sm *ConnStateMachine) Transition(targetState ConnState) { // AwaitAndTransition waits for the connection to reach one of the valid states, // then atomically transitions to the target state. +// Returns the current state after the transition attempt and an error if the operation failed. +// The returned state is the CURRENT state (after the attempt), not the previous state. // Returns error if timeout expires or context is cancelled. // // This method implements FIFO fairness - the first caller to wait gets priority @@ -167,19 +171,19 @@ func (sm *ConnStateMachine) AwaitAndTransition( ctx context.Context, validFromStates []ConnState, targetState ConnState, -) error { +) (ConnState, error) { // Fast path: try immediate transition with CAS to prevent race conditions for _, fromState := range validFromStates { // Check if we're already in target state if fromState == targetState && sm.GetState() == targetState { - return nil + return targetState, nil } // Try to atomically swap from fromState to targetState if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) { // Success! We transitioned atomically sm.notifyWaiters() - return nil + return targetState, nil } } @@ -188,7 +192,7 @@ func (sm *ConnStateMachine) AwaitAndTransition( // Check if closed if currentState == StateClosed { - return ErrStateMachineClosed + return currentState, ErrStateMachineClosed } // Slow path: need to wait for state change @@ -216,9 +220,10 @@ func (sm *ConnStateMachine) AwaitAndTransition( sm.mu.Lock() sm.waiters.Remove(elem) sm.mu.Unlock() - return ctx.Err() + return sm.GetState(), ctx.Err() case err := <-w.done: - return err + // Transition completed (or failed) + return sm.GetState(), err } } @@ -235,27 +240,35 @@ func (sm *ConnStateMachine) notifyWaiters() { // Process waiters in FIFO order until no more can be processed // We loop instead of recursing to avoid stack overflow and mutex issues for { - currentState := sm.GetState() processed := false // Find the first waiter that can proceed for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() { w := elem.Value.(*waiter) + // Read current state inside the loop to get the latest value + currentState := sm.GetState() + // Check if current state is valid for this waiter if _, valid := w.validStates[currentState]; valid { - // Remove from queue + // Remove from queue first sm.waiters.Remove(elem) - // Perform transition - sm.state.Store(uint32(w.targetState)) - - // Notify waiter (non-blocking due to buffered channel) - w.done <- nil - - // Mark that we processed a waiter and break to check for more - processed = true - break + // Use CAS to ensure state hasn't changed since we checked + // This prevents race condition where another thread changes state + // between our check and our transition + if sm.state.CompareAndSwap(uint32(currentState), uint32(w.targetState)) { + // Successfully transitioned - notify waiter + w.done <- nil + processed = true + break + } else { + // State changed - re-add waiter to front of queue and retry + sm.waiters.PushFront(w) + // Continue to next iteration to re-read state + processed = true + break + } } } diff --git a/internal/pool/conn_state_test.go b/internal/pool/conn_state_test.go index 1f2e23a7..40d83155 100644 --- a/internal/pool/conn_state_test.go +++ b/internal/pool/conn_state_test.go @@ -74,7 +74,7 @@ func TestConnStateMachine_TryTransition(t *testing.T) { sm := NewConnStateMachine() sm.Transition(tt.initialState) - err := sm.TryTransition(tt.validStates, tt.targetState) + _, err := sm.TryTransition(tt.validStates, tt.targetState) if tt.expectError && err == nil { t.Error("expected error but got none") @@ -99,7 +99,7 @@ func TestConnStateMachine_AwaitAndTransition_FastPath(t *testing.T) { ctx := context.Background() // Fast path: already in valid state - err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable) + _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable) if err != nil { t.Errorf("unexpected error: %v", err) } @@ -117,7 +117,7 @@ func TestConnStateMachine_AwaitAndTransition_Timeout(t *testing.T) { defer cancel() // Wait for a state that will never come - err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable) + _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable) if err == nil { t.Error("expected timeout error but got none") } @@ -150,7 +150,7 @@ func TestConnStateMachine_AwaitAndTransition_FIFO(t *testing.T) { startBarrier.Wait() ctx := context.Background() - err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateIdle) + _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateIdle) if err != nil { t.Errorf("waiter %d got error: %v", waiterID, err) return @@ -206,7 +206,7 @@ func TestConnStateMachine_ConcurrentAccess(t *testing.T) { for j := 0; j < numIterations; j++ { // Try to transition from READY to REAUTH_IN_PROGRESS - err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable) + _, err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable) if err == nil { successCount.Add(1) // Transition back to READY @@ -287,7 +287,7 @@ func TestConnStateMachine_PreventsConcurrentInitialization(t *testing.T) { startBarrier.Wait() // Try to transition to INITIALIZING - err := sm.TryTransition([]ConnState{StateIdle}, StateInitializing) + _, err := sm.TryTransition([]ConnState{StateIdle}, StateInitializing) if err == nil { successCount.Add(1) @@ -353,7 +353,7 @@ func TestConnStateMachine_AwaitAndTransitionWaitsForInitialization(t *testing.T) ctx := context.Background() // Try to transition to INITIALIZING - should wait if another is initializing - err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing) + _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing) if err != nil { t.Errorf("Goroutine %d: failed to transition: %v", id, err) return @@ -419,7 +419,7 @@ func TestConnStateMachine_FIFOOrdering(t *testing.T) { ctx := context.Background() // This should queue in FIFO order - err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing) + _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing) if err != nil { t.Errorf("Goroutine %d: failed to transition: %v", id, err) return @@ -482,7 +482,7 @@ func TestConnStateMachine_FIFOWithFastPath(t *testing.T) { ctx := context.Background() // This might use fast path (CAS) or slow path (queue) - err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing) + _, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing) if err != nil { t.Errorf("Goroutine %d: failed to transition: %v", id, err) return @@ -528,7 +528,7 @@ func BenchmarkConnStateMachine_TryTransition(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = sm.TryTransition([]ConnState{StateIdle}, StateUnusable) + _, _ = sm.TryTransition([]ConnState{StateIdle}, StateUnusable) sm.Transition(StateIdle) } } @@ -543,7 +543,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) { sm.Transition(StateIdle) // Test IDLE → IN_USE transition - err := sm.TryTransition([]ConnState{StateIdle}, StateInUse) + _, err := sm.TryTransition([]ConnState{StateIdle}, StateInUse) if err != nil { t.Errorf("failed to transition from IDLE to IN_USE: %v", err) } @@ -552,7 +552,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) { } // Test IN_USE → IDLE transition - err = sm.TryTransition([]ConnState{StateInUse}, StateIdle) + _, err = sm.TryTransition([]ConnState{StateInUse}, StateIdle) if err != nil { t.Errorf("failed to transition from IN_USE to IDLE: %v", err) } @@ -570,7 +570,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - err := sm.TryTransition([]ConnState{StateIdle}, StateInUse) + _, err := sm.TryTransition([]ConnState{StateIdle}, StateInUse) if err == nil { successCount.Add(1) } @@ -641,7 +641,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) { sm.Transition(StateIdle) // Test IDLE → UNUSABLE transition (for background operations) - err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable) + _, err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable) if err != nil { t.Errorf("failed to transition from IDLE to UNUSABLE: %v", err) } @@ -650,7 +650,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) { } // Test UNUSABLE → IDLE transition (after background operation completes) - err = sm.TryTransition([]ConnState{StateUnusable}, StateIdle) + _, err = sm.TryTransition([]ConnState{StateUnusable}, StateIdle) if err != nil { t.Errorf("failed to transition from UNUSABLE to IDLE: %v", err) } @@ -661,7 +661,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) { // Test that we can transition from IN_USE to UNUSABLE if needed // (e.g., for urgent handoff while connection is in use) sm.Transition(StateInUse) - err = sm.TryTransition([]ConnState{StateInUse}, StateUnusable) + _, err = sm.TryTransition([]ConnState{StateInUse}, StateUnusable) if err != nil { t.Errorf("failed to transition from IN_USE to UNUSABLE: %v", err) } @@ -672,7 +672,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) { // Test UNUSABLE → INITIALIZING transition (for handoff) sm.Transition(StateIdle) sm.Transition(StateUnusable) - err = sm.TryTransition([]ConnState{StateUnusable}, StateInitializing) + _, err = sm.TryTransition([]ConnState{StateUnusable}, StateInitializing) if err != nil { t.Errorf("failed to transition from UNUSABLE to INITIALIZING: %v", err) } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 4915bf62..7ee02999 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -244,9 +244,9 @@ func (p *ConnPool) addIdleConn() error { return err } - // Mark connection as usable after successful creation - // This is essential for normal pool operations - cn.SetUsable(true) + // NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn() + // when first acquired from the pool. Do NOT transition to IDLE here - that happens + // after initialization completes. p.connsMu.Lock() defer p.connsMu.Unlock() @@ -286,9 +286,9 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { return nil, err } - // Mark connection as usable after successful creation - // This is essential for normal pool operations - cn.SetUsable(true) + // NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn() + // when first used. Do NOT transition to IDLE here - that happens after initialization completes. + // The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success) if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) { _ = cn.Close() @@ -584,15 +584,17 @@ func (p *ConnPool) popIdle() (*Conn, error) { } attempts++ - if cn.CompareAndSwapUsed(false, true) { - if cn.IsUsable() { - p.idleConnsLen.Add(-1) - break - } - cn.SetUsed(false) + // Try to atomically transition to IN_USE using state machine + // Accept both CREATED (uninitialized) and IDLE (initialized) states + _, err := cn.GetStateMachine().TryTransition([]ConnState{StateCreated, StateIdle}, StateInUse) + if err == nil { + // Successfully acquired the connection + p.idleConnsLen.Add(-1) + break } - // Connection is not usable, put it back in the pool + // Connection is not in a valid state (might be UNUSABLE for re-auth, INITIALIZING, etc.) + // Put it back in the pool if p.cfg.PoolFIFO { // FIFO: put at end (will be picked up last since we pop from front) p.idleConns = append(p.idleConns, cn) @@ -661,6 +663,11 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { var shouldCloseConn bool if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns { + // Transition to IDLE state BEFORE adding to pool + // This prevents race condition where another goroutine could acquire + // a connection that's still in IN_USE state + cn.GetStateMachine().Transition(StateIdle) + // unusable conns are expected to become usable at some point (background process is reconnecting them) // put them at the opposite end of the queue if !cn.IsUsable() { @@ -684,11 +691,6 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { shouldCloseConn = true } - // if the connection is not going to be closed, mark it as not used - if !shouldCloseConn { - cn.SetUsed(false) - } - p.freeTurn() if shouldCloseConn { diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 712d482d..648e5ae4 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -44,6 +44,13 @@ func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) { if p.cn == nil { return nil, ErrClosed } + + // NOTE: SingleConnPool is NOT thread-safe by design and is used in special scenarios: + // - During initialization (connection is in INITIALIZING state) + // - During re-authentication (connection is in UNUSABLE state) + // - For transactions (connection might be in various states) + // We use SetUsed() which forces the transition, rather than TryTransition() which + // would fail if the connection is not in IDLE/CREATED state. p.cn.SetUsed(true) p.cn.SetUsedAt(time.Now()) return p.cn, nil diff --git a/redis.go b/redis.go index dcd7b59a..9a1c8773 100644 --- a/redis.go +++ b/redis.go @@ -366,28 +366,82 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error { } func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { - if !cn.Inited.CompareAndSwap(false, true) { + // This function is called in two scenarios: + // 1. First-time init: Connection is in CREATED state (from pool.Get()) + // - We need to transition CREATED → INITIALIZING and do the initialization + // - If another goroutine is already initializing, we WAIT for it to finish + // 2. Re-initialization: Connection is in INITIALIZING state (from SetNetConnAndInitConn()) + // - We're already in INITIALIZING, so just proceed with initialization + + currentState := cn.GetStateMachine().GetState() + + // Fast path: Check if already initialized (IDLE or IN_USE) + if currentState == pool.StateIdle || currentState == pool.StateInUse { return nil } - var err error + + // If in CREATED state, try to transition to INITIALIZING + if currentState == pool.StateCreated { + finalState, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateCreated}, pool.StateInitializing) + if err != nil { + // Another goroutine is initializing or connection is in unexpected state + // Check what state we're in now + if finalState == pool.StateIdle || finalState == pool.StateInUse { + // Already initialized by another goroutine + return nil + } + + if finalState == pool.StateInitializing { + // Another goroutine is initializing - WAIT for it to complete + // Use AwaitAndTransition to wait for IDLE or IN_USE state + // Add 1ms timeout to prevent indefinite blocking + waitCtx, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + + finalState, err := cn.GetStateMachine().AwaitAndTransition( + waitCtx, + []pool.ConnState{pool.StateIdle, pool.StateInUse}, + pool.StateIdle, // Target is IDLE (but we're already there, so this is a no-op) + ) + if err != nil { + return err + } + // Verify we're now initialized + if finalState == pool.StateIdle || finalState == pool.StateInUse { + return nil + } + // Unexpected state after waiting + return fmt.Errorf("connection in unexpected state after initialization: %s", finalState) + } + + // Unexpected state (CLOSED, UNUSABLE, etc.) + return err + } + } + + // At this point, we're in INITIALIZING state and we own the initialization + // If we fail, we must transition to CLOSED + var initErr error connPool := pool.NewSingleConnPool(c.connPool, cn) conn := newConn(c.opt, connPool, &c.hooksMixin) username, password := "", "" if c.opt.StreamingCredentialsProvider != nil { - credListener, err := c.streamingCredentialsManager.Listener( + credListener, initErr := c.streamingCredentialsManager.Listener( cn, c.reAuthConnection(), c.onAuthenticationErr(), ) - if err != nil { - return fmt.Errorf("failed to create credentials listener: %w", err) + if initErr != nil { + cn.GetStateMachine().Transition(pool.StateClosed) + return fmt.Errorf("failed to create credentials listener: %w", initErr) } - credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider. + credentials, unsubscribeFromCredentialsProvider, initErr := c.opt.StreamingCredentialsProvider. Subscribe(credListener) - if err != nil { - return fmt.Errorf("failed to subscribe to streaming credentials: %w", err) + if initErr != nil { + cn.GetStateMachine().Transition(pool.StateClosed) + return fmt.Errorf("failed to subscribe to streaming credentials: %w", initErr) } c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider) @@ -395,9 +449,10 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { username, password = credentials.BasicAuth() } else if c.opt.CredentialsProviderContext != nil { - username, password, err = c.opt.CredentialsProviderContext(ctx) - if err != nil { - return fmt.Errorf("failed to get credentials from context provider: %w", err) + username, password, initErr = c.opt.CredentialsProviderContext(ctx) + if initErr != nil { + cn.GetStateMachine().Transition(pool.StateClosed) + return fmt.Errorf("failed to get credentials from context provider: %w", initErr) } } else if c.opt.CredentialsProvider != nil { username, password = c.opt.CredentialsProvider() @@ -407,9 +462,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { // for redis-server versions that do not support the HELLO command, // RESP2 will continue to be used. - if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil { + if initErr = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); initErr == nil { // Authentication successful with HELLO command - } else if !isRedisError(err) { + } else if !isRedisError(initErr) { // When the server responds with the RESP protocol and the result is not a normal // execution result of the HELLO command, we consider it to be an indication that // the server does not support the HELLO command. @@ -417,20 +472,22 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { // or it could be DragonflyDB or a third-party redis-proxy. They all respond // with different error string results for unsupported commands, making it // difficult to rely on error strings to determine all results. - return err + cn.GetStateMachine().Transition(pool.StateClosed) + return initErr } else if password != "" { // Try legacy AUTH command if HELLO failed if username != "" { - err = conn.AuthACL(ctx, username, password).Err() + initErr = conn.AuthACL(ctx, username, password).Err() } else { - err = conn.Auth(ctx, password).Err() + initErr = conn.Auth(ctx, password).Err() } - if err != nil { - return fmt.Errorf("failed to authenticate: %w", err) + if initErr != nil { + cn.GetStateMachine().Transition(pool.StateClosed) + return fmt.Errorf("failed to authenticate: %w", initErr) } } - _, err = conn.Pipelined(ctx, func(pipe Pipeliner) error { + _, initErr = conn.Pipelined(ctx, func(pipe Pipeliner) error { if c.opt.DB > 0 { pipe.Select(ctx, c.opt.DB) } @@ -445,8 +502,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return nil }) - if err != nil { - return fmt.Errorf("failed to initialize connection options: %w", err) + if initErr != nil { + cn.GetStateMachine().Transition(pool.StateClosed) + return fmt.Errorf("failed to initialize connection options: %w", initErr) } // Enable maintnotifications if maintnotifications are configured @@ -465,6 +523,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { if maintNotifHandshakeErr != nil { if !isRedisError(maintNotifHandshakeErr) { // if not redis error, fail the connection + cn.GetStateMachine().Transition(pool.StateClosed) return maintNotifHandshakeErr } c.optLock.Lock() @@ -473,15 +532,16 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { case maintnotifications.ModeEnabled: // enabled mode, fail the connection c.optLock.Unlock() + cn.GetStateMachine().Transition(pool.StateClosed) return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr) default: // will handle auto and any other internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr) c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled c.optLock.Unlock() // auto mode, disable maintnotifications and continue - if err := c.disableMaintNotificationsUpgrades(); err != nil { + if initErr := c.disableMaintNotificationsUpgrades(); initErr != nil { // Log error but continue - auto mode should be resilient - internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", err) + internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", initErr) } } } else { @@ -505,22 +565,31 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { p.ClientSetInfo(ctx, WithLibraryVersion(libVer)) // Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid // out of order responses later on. - if _, err = p.Exec(ctx); err != nil && !isRedisError(err) { - return err + if _, initErr = p.Exec(ctx); initErr != nil && !isRedisError(initErr) { + cn.GetStateMachine().Transition(pool.StateClosed) + return initErr } } - // mark the connection as usable and inited - // once returned to the pool as idle, this connection can be used by other clients - cn.SetUsable(true) - cn.SetUsed(false) - cn.Inited.Store(true) - // Set the connection initialization function for potential reconnections + // This must be set before transitioning to IDLE so that handoff/reauth can use it cn.SetInitConnFunc(c.createInitConnFunc()) + // Initialization succeeded - transition to IDLE state + // This marks the connection as initialized and ready for use + // NOTE: The connection is still owned by the calling goroutine at this point + // and won't be available to other goroutines until it's Put() back into the pool + cn.GetStateMachine().Transition(pool.StateIdle) + + // Call OnConnect hook if configured + // The connection is in IDLE state but still owned by this goroutine + // If OnConnect needs to send commands, it can use the connection safely if c.opt.OnConnect != nil { - return c.opt.OnConnect(ctx, conn) + if initErr = c.opt.OnConnect(ctx, conn); initErr != nil { + // OnConnect failed - transition to closed + cn.GetStateMachine().Transition(pool.StateClosed) + return initErr + } } return nil