mirror of
https://github.com/redis/go-redis.git
synced 2025-12-03 18:31:14 +03:00
polish state machine
This commit is contained in:
@@ -200,7 +200,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
|
|||||||
// This ensures we only acquire connections that are not actively in use
|
// This ensures we only acquire connections that are not actively in use
|
||||||
stateMachine := conn.GetStateMachine()
|
stateMachine := conn.GetStateMachine()
|
||||||
if stateMachine != nil {
|
if stateMachine != nil {
|
||||||
err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
_, err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// Successfully acquired: connection was IDLE, now UNUSABLE
|
// Successfully acquired: connection was IDLE, now UNUSABLE
|
||||||
acquired = true
|
acquired = true
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func TestReAuthOnlyWhenIdle(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Try to transition to UNUSABLE (for reauth) - should fail
|
// Try to transition to UNUSABLE (for reauth) - should fail
|
||||||
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
_, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error when trying to transition IN_USE → UNUSABLE, but got none")
|
t.Error("Expected error when trying to transition IN_USE → UNUSABLE, but got none")
|
||||||
}
|
}
|
||||||
@@ -51,7 +51,7 @@ func TestReAuthOnlyWhenIdle(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Now try to transition to UNUSABLE - should succeed
|
// Now try to transition to UNUSABLE - should succeed
|
||||||
err = cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
_, err = cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed to transition IDLE → UNUSABLE: %v", err)
|
t.Errorf("Failed to transition IDLE → UNUSABLE: %v", err)
|
||||||
}
|
}
|
||||||
@@ -99,7 +99,7 @@ func TestReAuthWaitsForConnectionToBeIdle(t *testing.T) {
|
|||||||
default:
|
default:
|
||||||
reAuthAttempts.Add(1)
|
reAuthAttempts.Add(1)
|
||||||
// Try to atomically transition from IDLE to UNUSABLE
|
// Try to atomically transition from IDLE to UNUSABLE
|
||||||
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
_, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// Successfully acquired
|
// Successfully acquired
|
||||||
acquired = true
|
acquired = true
|
||||||
@@ -185,7 +185,7 @@ func TestConcurrentReAuthAndUsage(t *testing.T) {
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for i := 0; i < 50; i++ {
|
for i := 0; i < 50; i++ {
|
||||||
// Try to acquire for re-auth
|
// Try to acquire for re-auth
|
||||||
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
_, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
reAuthCount.Add(1)
|
reAuthCount.Add(1)
|
||||||
// Simulate re-auth work
|
// Simulate re-auth work
|
||||||
@@ -228,7 +228,7 @@ func TestReAuthRespectsClosed(t *testing.T) {
|
|||||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
|
|
||||||
// Try to transition to UNUSABLE - should fail
|
// Try to transition to UNUSABLE - should fail
|
||||||
err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
_, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error when trying to transition CLOSED → UNUSABLE, but got none")
|
t.Error("Expected error when trying to transition CLOSED → UNUSABLE, but got none")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
|
|||||||
if new {
|
if new {
|
||||||
// Trying to make usable - transition from UNUSABLE to IDLE
|
// Trying to make usable - transition from UNUSABLE to IDLE
|
||||||
// This should only work from UNUSABLE or INITIALIZING states
|
// This should only work from UNUSABLE or INITIALIZING states
|
||||||
err := cn.stateMachine.TryTransition(
|
_, err := cn.stateMachine.TryTransition(
|
||||||
[]ConnState{StateInitializing, StateUnusable},
|
[]ConnState{StateInitializing, StateUnusable},
|
||||||
StateIdle,
|
StateIdle,
|
||||||
)
|
)
|
||||||
@@ -175,7 +175,7 @@ func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
|
|||||||
} else {
|
} else {
|
||||||
// Trying to make unusable - transition from IDLE to UNUSABLE
|
// Trying to make unusable - transition from IDLE to UNUSABLE
|
||||||
// This is typically for acquiring the connection for background operations
|
// This is typically for acquiring the connection for background operations
|
||||||
err := cn.stateMachine.TryTransition(
|
_, err := cn.stateMachine.TryTransition(
|
||||||
[]ConnState{StateIdle},
|
[]ConnState{StateIdle},
|
||||||
StateUnusable,
|
StateUnusable,
|
||||||
)
|
)
|
||||||
@@ -200,10 +200,13 @@ func (cn *Conn) IsUsable() bool {
|
|||||||
|
|
||||||
// SetUsable sets the usable flag for the connection (lock-free).
|
// SetUsable sets the usable flag for the connection (lock-free).
|
||||||
//
|
//
|
||||||
|
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||||
|
// This method is kept for backwards compatibility.
|
||||||
|
//
|
||||||
// This should be called to mark a connection as usable after initialization or
|
// This should be called to mark a connection as usable after initialization or
|
||||||
// to release it after a background operation completes.
|
// to release it after a background operation completes.
|
||||||
//
|
//
|
||||||
// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions.
|
// Prefer CompareAndSwapUsed() when acquiring exclusive access to avoid race conditions.
|
||||||
func (cn *Conn) SetUsable(usable bool) {
|
func (cn *Conn) SetUsable(usable bool) {
|
||||||
if usable {
|
if usable {
|
||||||
// Transition to IDLE state (ready to be acquired)
|
// Transition to IDLE state (ready to be acquired)
|
||||||
@@ -226,6 +229,9 @@ func (cn *Conn) IsInited() bool {
|
|||||||
|
|
||||||
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
|
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
|
||||||
//
|
//
|
||||||
|
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
|
||||||
|
// This method is kept for backwards compatibility.
|
||||||
|
//
|
||||||
// This is the preferred method for acquiring a connection from the pool, as it
|
// This is the preferred method for acquiring a connection from the pool, as it
|
||||||
// ensures that only one goroutine marks the connection as used.
|
// ensures that only one goroutine marks the connection as used.
|
||||||
//
|
//
|
||||||
@@ -242,17 +248,20 @@ func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
|
|||||||
|
|
||||||
if !old && new {
|
if !old && new {
|
||||||
// Acquiring: IDLE → IN_USE
|
// Acquiring: IDLE → IN_USE
|
||||||
err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse)
|
_, err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||||
return err == nil
|
return err == nil
|
||||||
} else {
|
} else {
|
||||||
// Releasing: IN_USE → IDLE
|
// Releasing: IN_USE → IDLE
|
||||||
err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle)
|
_, err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUsed returns true if the connection is currently in use (lock-free).
|
// IsUsed returns true if the connection is currently in use (lock-free).
|
||||||
//
|
//
|
||||||
|
// Deprecated: Use GetStateMachine().GetState() == StateInUse directly for better clarity.
|
||||||
|
// This method is kept for backwards compatibility.
|
||||||
|
//
|
||||||
// A connection is "used" when it has been retrieved from the pool and is
|
// A connection is "used" when it has been retrieved from the pool and is
|
||||||
// actively processing a command. Background operations (like re-auth) should
|
// actively processing a command. Background operations (like re-auth) should
|
||||||
// wait until the connection is not used before executing commands.
|
// wait until the connection is not used before executing commands.
|
||||||
@@ -526,28 +535,32 @@ func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) err
|
|||||||
// Wait for and transition to INITIALIZING state - this prevents concurrent initializations
|
// Wait for and transition to INITIALIZING state - this prevents concurrent initializations
|
||||||
// Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth)
|
// Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth)
|
||||||
// If another goroutine is initializing, we'll wait for it to finish
|
// If another goroutine is initializing, we'll wait for it to finish
|
||||||
err := cn.stateMachine.AwaitAndTransition(
|
// Add 1ms timeout to prevent indefinite blocking
|
||||||
ctx,
|
waitCtx, cancel := context.WithTimeout(ctx, time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
finalState, err := cn.stateMachine.AwaitAndTransition(
|
||||||
|
waitCtx,
|
||||||
[]ConnState{StateCreated, StateIdle, StateUnusable},
|
[]ConnState{StateCreated, StateIdle, StateUnusable},
|
||||||
StateInitializing,
|
StateInitializing,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot initialize connection from state %s: %w", cn.stateMachine.GetState(), err)
|
return fmt.Errorf("cannot initialize connection from state %s: %w", finalState, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replace the underlying connection
|
// Replace the underlying connection
|
||||||
cn.SetNetConn(netConn)
|
cn.SetNetConn(netConn)
|
||||||
|
|
||||||
// Execute initialization
|
// Execute initialization
|
||||||
|
// NOTE: ExecuteInitConn (via baseClient.initConn) will transition to IDLE on success
|
||||||
|
// or CLOSED on failure. We don't need to do it here.
|
||||||
initErr := cn.ExecuteInitConn(ctx)
|
initErr := cn.ExecuteInitConn(ctx)
|
||||||
if initErr != nil {
|
if initErr != nil {
|
||||||
// Initialization failed - transition to closed
|
// ExecuteInitConn already transitioned to CLOSED, just return the error
|
||||||
cn.stateMachine.Transition(StateClosed)
|
|
||||||
return initErr
|
return initErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialization succeeded - transition to IDLE (ready to be acquired)
|
// ExecuteInitConn already transitioned to IDLE
|
||||||
cn.stateMachine.Transition(StateIdle)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -577,7 +590,8 @@ func (cn *Conn) MarkQueuedForHandoff() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Mark connection as unusable while queued for handoff
|
// Mark connection as unusable while queued for handoff
|
||||||
cn.SetUsable(false)
|
// Use state machine directly instead of deprecated SetUsable
|
||||||
|
cn.stateMachine.Transition(StateUnusable)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -606,7 +620,8 @@ func (cn *Conn) ClearHandoffState() {
|
|||||||
cn.handoffRetriesAtomic.Store(0)
|
cn.handoffRetriesAtomic.Store(0)
|
||||||
|
|
||||||
// Mark connection as usable again
|
// Mark connection as usable again
|
||||||
cn.SetUsable(true)
|
// Use state machine directly instead of deprecated SetUsable
|
||||||
|
cn.stateMachine.Transition(StateIdle)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasBufferedData safely checks if the connection has buffered data.
|
// HasBufferedData safely checks if the connection has buffered data.
|
||||||
|
|||||||
@@ -113,19 +113,20 @@ func (sm *ConnStateMachine) GetState() ConnState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TryTransition attempts an immediate state transition without waiting.
|
// TryTransition attempts an immediate state transition without waiting.
|
||||||
// Returns ErrInvalidStateTransition if current state is not in validFromStates.
|
// Returns the current state after the transition attempt and an error if the transition failed.
|
||||||
|
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||||
// This is faster than AwaitAndTransition when you don't need to wait.
|
// This is faster than AwaitAndTransition when you don't need to wait.
|
||||||
// Uses compare-and-swap to atomically transition, preventing concurrent transitions.
|
// Uses compare-and-swap to atomically transition, preventing concurrent transitions.
|
||||||
// This method does NOT wait - it fails immediately if the transition cannot be performed.
|
// This method does NOT wait - it fails immediately if the transition cannot be performed.
|
||||||
//
|
//
|
||||||
// Performance: Zero allocations on success path (hot path).
|
// Performance: Zero allocations on success path (hot path).
|
||||||
func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) error {
|
func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) (ConnState, error) {
|
||||||
// Try each valid from state with CAS
|
// Try each valid from state with CAS
|
||||||
// This ensures only ONE goroutine can successfully transition at a time
|
// This ensures only ONE goroutine can successfully transition at a time
|
||||||
for _, fromState := range validFromStates {
|
for _, fromState := range validFromStates {
|
||||||
// Fast path: check if we're already in target state
|
// Fast path: check if we're already in target state
|
||||||
if fromState == targetState && sm.GetState() == targetState {
|
if fromState == targetState && sm.GetState() == targetState {
|
||||||
return nil
|
return targetState, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to atomically swap from fromState to targetState
|
// Try to atomically swap from fromState to targetState
|
||||||
@@ -134,14 +135,15 @@ func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetSta
|
|||||||
// Success! We transitioned atomically
|
// Success! We transitioned atomically
|
||||||
// Notify any waiters
|
// Notify any waiters
|
||||||
sm.notifyWaiters()
|
sm.notifyWaiters()
|
||||||
return nil
|
return targetState, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// All CAS attempts failed - state is not valid for this transition
|
// All CAS attempts failed - state is not valid for this transition
|
||||||
|
// Return the current state so caller can decide what to do
|
||||||
// Note: This error path allocates, but it's the exceptional case
|
// Note: This error path allocates, but it's the exceptional case
|
||||||
currentState := sm.GetState()
|
currentState := sm.GetState()
|
||||||
return fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)",
|
return currentState, fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)",
|
||||||
ErrInvalidStateTransition, currentState, targetState, validFromStates)
|
ErrInvalidStateTransition, currentState, targetState, validFromStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,6 +157,8 @@ func (sm *ConnStateMachine) Transition(targetState ConnState) {
|
|||||||
|
|
||||||
// AwaitAndTransition waits for the connection to reach one of the valid states,
|
// AwaitAndTransition waits for the connection to reach one of the valid states,
|
||||||
// then atomically transitions to the target state.
|
// then atomically transitions to the target state.
|
||||||
|
// Returns the current state after the transition attempt and an error if the operation failed.
|
||||||
|
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||||
// Returns error if timeout expires or context is cancelled.
|
// Returns error if timeout expires or context is cancelled.
|
||||||
//
|
//
|
||||||
// This method implements FIFO fairness - the first caller to wait gets priority
|
// This method implements FIFO fairness - the first caller to wait gets priority
|
||||||
@@ -167,19 +171,19 @@ func (sm *ConnStateMachine) AwaitAndTransition(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
validFromStates []ConnState,
|
validFromStates []ConnState,
|
||||||
targetState ConnState,
|
targetState ConnState,
|
||||||
) error {
|
) (ConnState, error) {
|
||||||
// Fast path: try immediate transition with CAS to prevent race conditions
|
// Fast path: try immediate transition with CAS to prevent race conditions
|
||||||
for _, fromState := range validFromStates {
|
for _, fromState := range validFromStates {
|
||||||
// Check if we're already in target state
|
// Check if we're already in target state
|
||||||
if fromState == targetState && sm.GetState() == targetState {
|
if fromState == targetState && sm.GetState() == targetState {
|
||||||
return nil
|
return targetState, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to atomically swap from fromState to targetState
|
// Try to atomically swap from fromState to targetState
|
||||||
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
|
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
|
||||||
// Success! We transitioned atomically
|
// Success! We transitioned atomically
|
||||||
sm.notifyWaiters()
|
sm.notifyWaiters()
|
||||||
return nil
|
return targetState, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,7 +192,7 @@ func (sm *ConnStateMachine) AwaitAndTransition(
|
|||||||
|
|
||||||
// Check if closed
|
// Check if closed
|
||||||
if currentState == StateClosed {
|
if currentState == StateClosed {
|
||||||
return ErrStateMachineClosed
|
return currentState, ErrStateMachineClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Slow path: need to wait for state change
|
// Slow path: need to wait for state change
|
||||||
@@ -216,9 +220,10 @@ func (sm *ConnStateMachine) AwaitAndTransition(
|
|||||||
sm.mu.Lock()
|
sm.mu.Lock()
|
||||||
sm.waiters.Remove(elem)
|
sm.waiters.Remove(elem)
|
||||||
sm.mu.Unlock()
|
sm.mu.Unlock()
|
||||||
return ctx.Err()
|
return sm.GetState(), ctx.Err()
|
||||||
case err := <-w.done:
|
case err := <-w.done:
|
||||||
return err
|
// Transition completed (or failed)
|
||||||
|
return sm.GetState(), err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,27 +240,35 @@ func (sm *ConnStateMachine) notifyWaiters() {
|
|||||||
// Process waiters in FIFO order until no more can be processed
|
// Process waiters in FIFO order until no more can be processed
|
||||||
// We loop instead of recursing to avoid stack overflow and mutex issues
|
// We loop instead of recursing to avoid stack overflow and mutex issues
|
||||||
for {
|
for {
|
||||||
currentState := sm.GetState()
|
|
||||||
processed := false
|
processed := false
|
||||||
|
|
||||||
// Find the first waiter that can proceed
|
// Find the first waiter that can proceed
|
||||||
for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() {
|
for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() {
|
||||||
w := elem.Value.(*waiter)
|
w := elem.Value.(*waiter)
|
||||||
|
|
||||||
|
// Read current state inside the loop to get the latest value
|
||||||
|
currentState := sm.GetState()
|
||||||
|
|
||||||
// Check if current state is valid for this waiter
|
// Check if current state is valid for this waiter
|
||||||
if _, valid := w.validStates[currentState]; valid {
|
if _, valid := w.validStates[currentState]; valid {
|
||||||
// Remove from queue
|
// Remove from queue first
|
||||||
sm.waiters.Remove(elem)
|
sm.waiters.Remove(elem)
|
||||||
|
|
||||||
// Perform transition
|
// Use CAS to ensure state hasn't changed since we checked
|
||||||
sm.state.Store(uint32(w.targetState))
|
// This prevents race condition where another thread changes state
|
||||||
|
// between our check and our transition
|
||||||
// Notify waiter (non-blocking due to buffered channel)
|
if sm.state.CompareAndSwap(uint32(currentState), uint32(w.targetState)) {
|
||||||
w.done <- nil
|
// Successfully transitioned - notify waiter
|
||||||
|
w.done <- nil
|
||||||
// Mark that we processed a waiter and break to check for more
|
processed = true
|
||||||
processed = true
|
break
|
||||||
break
|
} else {
|
||||||
|
// State changed - re-add waiter to front of queue and retry
|
||||||
|
sm.waiters.PushFront(w)
|
||||||
|
// Continue to next iteration to re-read state
|
||||||
|
processed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func TestConnStateMachine_TryTransition(t *testing.T) {
|
|||||||
sm := NewConnStateMachine()
|
sm := NewConnStateMachine()
|
||||||
sm.Transition(tt.initialState)
|
sm.Transition(tt.initialState)
|
||||||
|
|
||||||
err := sm.TryTransition(tt.validStates, tt.targetState)
|
_, err := sm.TryTransition(tt.validStates, tt.targetState)
|
||||||
|
|
||||||
if tt.expectError && err == nil {
|
if tt.expectError && err == nil {
|
||||||
t.Error("expected error but got none")
|
t.Error("expected error but got none")
|
||||||
@@ -99,7 +99,7 @@ func TestConnStateMachine_AwaitAndTransition_FastPath(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// Fast path: already in valid state
|
// Fast path: already in valid state
|
||||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -117,7 +117,7 @@ func TestConnStateMachine_AwaitAndTransition_Timeout(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Wait for a state that will never come
|
// Wait for a state that will never come
|
||||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected timeout error but got none")
|
t.Error("expected timeout error but got none")
|
||||||
}
|
}
|
||||||
@@ -150,7 +150,7 @@ func TestConnStateMachine_AwaitAndTransition_FIFO(t *testing.T) {
|
|||||||
startBarrier.Wait()
|
startBarrier.Wait()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateIdle)
|
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateIdle)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("waiter %d got error: %v", waiterID, err)
|
t.Errorf("waiter %d got error: %v", waiterID, err)
|
||||||
return
|
return
|
||||||
@@ -206,7 +206,7 @@ func TestConnStateMachine_ConcurrentAccess(t *testing.T) {
|
|||||||
|
|
||||||
for j := 0; j < numIterations; j++ {
|
for j := 0; j < numIterations; j++ {
|
||||||
// Try to transition from READY to REAUTH_IN_PROGRESS
|
// Try to transition from READY to REAUTH_IN_PROGRESS
|
||||||
err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
_, err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
successCount.Add(1)
|
successCount.Add(1)
|
||||||
// Transition back to READY
|
// Transition back to READY
|
||||||
@@ -287,7 +287,7 @@ func TestConnStateMachine_PreventsConcurrentInitialization(t *testing.T) {
|
|||||||
startBarrier.Wait()
|
startBarrier.Wait()
|
||||||
|
|
||||||
// Try to transition to INITIALIZING
|
// Try to transition to INITIALIZING
|
||||||
err := sm.TryTransition([]ConnState{StateIdle}, StateInitializing)
|
_, err := sm.TryTransition([]ConnState{StateIdle}, StateInitializing)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
successCount.Add(1)
|
successCount.Add(1)
|
||||||
|
|
||||||
@@ -353,7 +353,7 @@ func TestConnStateMachine_AwaitAndTransitionWaitsForInitialization(t *testing.T)
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// Try to transition to INITIALIZING - should wait if another is initializing
|
// Try to transition to INITIALIZING - should wait if another is initializing
|
||||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
||||||
return
|
return
|
||||||
@@ -419,7 +419,7 @@ func TestConnStateMachine_FIFOOrdering(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// This should queue in FIFO order
|
// This should queue in FIFO order
|
||||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
||||||
return
|
return
|
||||||
@@ -482,7 +482,7 @@ func TestConnStateMachine_FIFOWithFastPath(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// This might use fast path (CAS) or slow path (queue)
|
// This might use fast path (CAS) or slow path (queue)
|
||||||
err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
||||||
return
|
return
|
||||||
@@ -528,7 +528,7 @@ func BenchmarkConnStateMachine_TryTransition(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_ = sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
_, _ = sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||||
sm.Transition(StateIdle)
|
sm.Transition(StateIdle)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -543,7 +543,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) {
|
|||||||
sm.Transition(StateIdle)
|
sm.Transition(StateIdle)
|
||||||
|
|
||||||
// Test IDLE → IN_USE transition
|
// Test IDLE → IN_USE transition
|
||||||
err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
|
_, err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to transition from IDLE to IN_USE: %v", err)
|
t.Errorf("failed to transition from IDLE to IN_USE: %v", err)
|
||||||
}
|
}
|
||||||
@@ -552,7 +552,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test IN_USE → IDLE transition
|
// Test IN_USE → IDLE transition
|
||||||
err = sm.TryTransition([]ConnState{StateInUse}, StateIdle)
|
_, err = sm.TryTransition([]ConnState{StateInUse}, StateIdle)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to transition from IN_USE to IDLE: %v", err)
|
t.Errorf("failed to transition from IN_USE to IDLE: %v", err)
|
||||||
}
|
}
|
||||||
@@ -570,7 +570,7 @@ func TestConnStateMachine_IdleInUseTransitions(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
|
_, err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
successCount.Add(1)
|
successCount.Add(1)
|
||||||
}
|
}
|
||||||
@@ -641,7 +641,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
|
|||||||
sm.Transition(StateIdle)
|
sm.Transition(StateIdle)
|
||||||
|
|
||||||
// Test IDLE → UNUSABLE transition (for background operations)
|
// Test IDLE → UNUSABLE transition (for background operations)
|
||||||
err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
_, err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to transition from IDLE to UNUSABLE: %v", err)
|
t.Errorf("failed to transition from IDLE to UNUSABLE: %v", err)
|
||||||
}
|
}
|
||||||
@@ -650,7 +650,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test UNUSABLE → IDLE transition (after background operation completes)
|
// Test UNUSABLE → IDLE transition (after background operation completes)
|
||||||
err = sm.TryTransition([]ConnState{StateUnusable}, StateIdle)
|
_, err = sm.TryTransition([]ConnState{StateUnusable}, StateIdle)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to transition from UNUSABLE to IDLE: %v", err)
|
t.Errorf("failed to transition from UNUSABLE to IDLE: %v", err)
|
||||||
}
|
}
|
||||||
@@ -661,7 +661,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
|
|||||||
// Test that we can transition from IN_USE to UNUSABLE if needed
|
// Test that we can transition from IN_USE to UNUSABLE if needed
|
||||||
// (e.g., for urgent handoff while connection is in use)
|
// (e.g., for urgent handoff while connection is in use)
|
||||||
sm.Transition(StateInUse)
|
sm.Transition(StateInUse)
|
||||||
err = sm.TryTransition([]ConnState{StateInUse}, StateUnusable)
|
_, err = sm.TryTransition([]ConnState{StateInUse}, StateUnusable)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to transition from IN_USE to UNUSABLE: %v", err)
|
t.Errorf("failed to transition from IN_USE to UNUSABLE: %v", err)
|
||||||
}
|
}
|
||||||
@@ -672,7 +672,7 @@ func TestConnStateMachine_UnusableState(t *testing.T) {
|
|||||||
// Test UNUSABLE → INITIALIZING transition (for handoff)
|
// Test UNUSABLE → INITIALIZING transition (for handoff)
|
||||||
sm.Transition(StateIdle)
|
sm.Transition(StateIdle)
|
||||||
sm.Transition(StateUnusable)
|
sm.Transition(StateUnusable)
|
||||||
err = sm.TryTransition([]ConnState{StateUnusable}, StateInitializing)
|
_, err = sm.TryTransition([]ConnState{StateUnusable}, StateInitializing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to transition from UNUSABLE to INITIALIZING: %v", err)
|
t.Errorf("failed to transition from UNUSABLE to INITIALIZING: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -244,9 +244,9 @@ func (p *ConnPool) addIdleConn() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark connection as usable after successful creation
|
// NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn()
|
||||||
// This is essential for normal pool operations
|
// when first acquired from the pool. Do NOT transition to IDLE here - that happens
|
||||||
cn.SetUsable(true)
|
// after initialization completes.
|
||||||
|
|
||||||
p.connsMu.Lock()
|
p.connsMu.Lock()
|
||||||
defer p.connsMu.Unlock()
|
defer p.connsMu.Unlock()
|
||||||
@@ -286,9 +286,9 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark connection as usable after successful creation
|
// NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn()
|
||||||
// This is essential for normal pool operations
|
// when first used. Do NOT transition to IDLE here - that happens after initialization completes.
|
||||||
cn.SetUsable(true)
|
// The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success)
|
||||||
|
|
||||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
|
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
|
||||||
_ = cn.Close()
|
_ = cn.Close()
|
||||||
@@ -584,15 +584,17 @@ func (p *ConnPool) popIdle() (*Conn, error) {
|
|||||||
}
|
}
|
||||||
attempts++
|
attempts++
|
||||||
|
|
||||||
if cn.CompareAndSwapUsed(false, true) {
|
// Try to atomically transition to IN_USE using state machine
|
||||||
if cn.IsUsable() {
|
// Accept both CREATED (uninitialized) and IDLE (initialized) states
|
||||||
p.idleConnsLen.Add(-1)
|
_, err := cn.GetStateMachine().TryTransition([]ConnState{StateCreated, StateIdle}, StateInUse)
|
||||||
break
|
if err == nil {
|
||||||
}
|
// Successfully acquired the connection
|
||||||
cn.SetUsed(false)
|
p.idleConnsLen.Add(-1)
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connection is not usable, put it back in the pool
|
// Connection is not in a valid state (might be UNUSABLE for re-auth, INITIALIZING, etc.)
|
||||||
|
// Put it back in the pool
|
||||||
if p.cfg.PoolFIFO {
|
if p.cfg.PoolFIFO {
|
||||||
// FIFO: put at end (will be picked up last since we pop from front)
|
// FIFO: put at end (will be picked up last since we pop from front)
|
||||||
p.idleConns = append(p.idleConns, cn)
|
p.idleConns = append(p.idleConns, cn)
|
||||||
@@ -661,6 +663,11 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
|||||||
var shouldCloseConn bool
|
var shouldCloseConn bool
|
||||||
|
|
||||||
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
|
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
|
||||||
|
// Transition to IDLE state BEFORE adding to pool
|
||||||
|
// This prevents race condition where another goroutine could acquire
|
||||||
|
// a connection that's still in IN_USE state
|
||||||
|
cn.GetStateMachine().Transition(StateIdle)
|
||||||
|
|
||||||
// unusable conns are expected to become usable at some point (background process is reconnecting them)
|
// unusable conns are expected to become usable at some point (background process is reconnecting them)
|
||||||
// put them at the opposite end of the queue
|
// put them at the opposite end of the queue
|
||||||
if !cn.IsUsable() {
|
if !cn.IsUsable() {
|
||||||
@@ -684,11 +691,6 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
|||||||
shouldCloseConn = true
|
shouldCloseConn = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the connection is not going to be closed, mark it as not used
|
|
||||||
if !shouldCloseConn {
|
|
||||||
cn.SetUsed(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.freeTurn()
|
p.freeTurn()
|
||||||
|
|
||||||
if shouldCloseConn {
|
if shouldCloseConn {
|
||||||
|
|||||||
@@ -44,6 +44,13 @@ func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) {
|
|||||||
if p.cn == nil {
|
if p.cn == nil {
|
||||||
return nil, ErrClosed
|
return nil, ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: SingleConnPool is NOT thread-safe by design and is used in special scenarios:
|
||||||
|
// - During initialization (connection is in INITIALIZING state)
|
||||||
|
// - During re-authentication (connection is in UNUSABLE state)
|
||||||
|
// - For transactions (connection might be in various states)
|
||||||
|
// We use SetUsed() which forces the transition, rather than TryTransition() which
|
||||||
|
// would fail if the connection is not in IDLE/CREATED state.
|
||||||
p.cn.SetUsed(true)
|
p.cn.SetUsed(true)
|
||||||
p.cn.SetUsedAt(time.Now())
|
p.cn.SetUsedAt(time.Now())
|
||||||
return p.cn, nil
|
return p.cn, nil
|
||||||
|
|||||||
133
redis.go
133
redis.go
@@ -366,28 +366,82 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||||
if !cn.Inited.CompareAndSwap(false, true) {
|
// This function is called in two scenarios:
|
||||||
|
// 1. First-time init: Connection is in CREATED state (from pool.Get())
|
||||||
|
// - We need to transition CREATED → INITIALIZING and do the initialization
|
||||||
|
// - If another goroutine is already initializing, we WAIT for it to finish
|
||||||
|
// 2. Re-initialization: Connection is in INITIALIZING state (from SetNetConnAndInitConn())
|
||||||
|
// - We're already in INITIALIZING, so just proceed with initialization
|
||||||
|
|
||||||
|
currentState := cn.GetStateMachine().GetState()
|
||||||
|
|
||||||
|
// Fast path: Check if already initialized (IDLE or IN_USE)
|
||||||
|
if currentState == pool.StateIdle || currentState == pool.StateInUse {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var err error
|
|
||||||
|
// If in CREATED state, try to transition to INITIALIZING
|
||||||
|
if currentState == pool.StateCreated {
|
||||||
|
finalState, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateCreated}, pool.StateInitializing)
|
||||||
|
if err != nil {
|
||||||
|
// Another goroutine is initializing or connection is in unexpected state
|
||||||
|
// Check what state we're in now
|
||||||
|
if finalState == pool.StateIdle || finalState == pool.StateInUse {
|
||||||
|
// Already initialized by another goroutine
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalState == pool.StateInitializing {
|
||||||
|
// Another goroutine is initializing - WAIT for it to complete
|
||||||
|
// Use AwaitAndTransition to wait for IDLE or IN_USE state
|
||||||
|
// Add 1ms timeout to prevent indefinite blocking
|
||||||
|
waitCtx, cancel := context.WithTimeout(ctx, time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
finalState, err := cn.GetStateMachine().AwaitAndTransition(
|
||||||
|
waitCtx,
|
||||||
|
[]pool.ConnState{pool.StateIdle, pool.StateInUse},
|
||||||
|
pool.StateIdle, // Target is IDLE (but we're already there, so this is a no-op)
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Verify we're now initialized
|
||||||
|
if finalState == pool.StateIdle || finalState == pool.StateInUse {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Unexpected state after waiting
|
||||||
|
return fmt.Errorf("connection in unexpected state after initialization: %s", finalState)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unexpected state (CLOSED, UNUSABLE, etc.)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point, we're in INITIALIZING state and we own the initialization
|
||||||
|
// If we fail, we must transition to CLOSED
|
||||||
|
var initErr error
|
||||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||||
conn := newConn(c.opt, connPool, &c.hooksMixin)
|
conn := newConn(c.opt, connPool, &c.hooksMixin)
|
||||||
|
|
||||||
username, password := "", ""
|
username, password := "", ""
|
||||||
if c.opt.StreamingCredentialsProvider != nil {
|
if c.opt.StreamingCredentialsProvider != nil {
|
||||||
credListener, err := c.streamingCredentialsManager.Listener(
|
credListener, initErr := c.streamingCredentialsManager.Listener(
|
||||||
cn,
|
cn,
|
||||||
c.reAuthConnection(),
|
c.reAuthConnection(),
|
||||||
c.onAuthenticationErr(),
|
c.onAuthenticationErr(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if initErr != nil {
|
||||||
return fmt.Errorf("failed to create credentials listener: %w", err)
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
|
return fmt.Errorf("failed to create credentials listener: %w", initErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
credentials, unsubscribeFromCredentialsProvider, initErr := c.opt.StreamingCredentialsProvider.
|
||||||
Subscribe(credListener)
|
Subscribe(credListener)
|
||||||
if err != nil {
|
if initErr != nil {
|
||||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
|
return fmt.Errorf("failed to subscribe to streaming credentials: %w", initErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
|
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
|
||||||
@@ -395,9 +449,10 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
|
|
||||||
username, password = credentials.BasicAuth()
|
username, password = credentials.BasicAuth()
|
||||||
} else if c.opt.CredentialsProviderContext != nil {
|
} else if c.opt.CredentialsProviderContext != nil {
|
||||||
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
username, password, initErr = c.opt.CredentialsProviderContext(ctx)
|
||||||
if err != nil {
|
if initErr != nil {
|
||||||
return fmt.Errorf("failed to get credentials from context provider: %w", err)
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
|
return fmt.Errorf("failed to get credentials from context provider: %w", initErr)
|
||||||
}
|
}
|
||||||
} else if c.opt.CredentialsProvider != nil {
|
} else if c.opt.CredentialsProvider != nil {
|
||||||
username, password = c.opt.CredentialsProvider()
|
username, password = c.opt.CredentialsProvider()
|
||||||
@@ -407,9 +462,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
|
|
||||||
// for redis-server versions that do not support the HELLO command,
|
// for redis-server versions that do not support the HELLO command,
|
||||||
// RESP2 will continue to be used.
|
// RESP2 will continue to be used.
|
||||||
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
if initErr = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); initErr == nil {
|
||||||
// Authentication successful with HELLO command
|
// Authentication successful with HELLO command
|
||||||
} else if !isRedisError(err) {
|
} else if !isRedisError(initErr) {
|
||||||
// When the server responds with the RESP protocol and the result is not a normal
|
// When the server responds with the RESP protocol and the result is not a normal
|
||||||
// execution result of the HELLO command, we consider it to be an indication that
|
// execution result of the HELLO command, we consider it to be an indication that
|
||||||
// the server does not support the HELLO command.
|
// the server does not support the HELLO command.
|
||||||
@@ -417,20 +472,22 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
// or it could be DragonflyDB or a third-party redis-proxy. They all respond
|
// or it could be DragonflyDB or a third-party redis-proxy. They all respond
|
||||||
// with different error string results for unsupported commands, making it
|
// with different error string results for unsupported commands, making it
|
||||||
// difficult to rely on error strings to determine all results.
|
// difficult to rely on error strings to determine all results.
|
||||||
return err
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
|
return initErr
|
||||||
} else if password != "" {
|
} else if password != "" {
|
||||||
// Try legacy AUTH command if HELLO failed
|
// Try legacy AUTH command if HELLO failed
|
||||||
if username != "" {
|
if username != "" {
|
||||||
err = conn.AuthACL(ctx, username, password).Err()
|
initErr = conn.AuthACL(ctx, username, password).Err()
|
||||||
} else {
|
} else {
|
||||||
err = conn.Auth(ctx, password).Err()
|
initErr = conn.Auth(ctx, password).Err()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if initErr != nil {
|
||||||
return fmt.Errorf("failed to authenticate: %w", err)
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
|
return fmt.Errorf("failed to authenticate: %w", initErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
_, initErr = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
||||||
if c.opt.DB > 0 {
|
if c.opt.DB > 0 {
|
||||||
pipe.Select(ctx, c.opt.DB)
|
pipe.Select(ctx, c.opt.DB)
|
||||||
}
|
}
|
||||||
@@ -445,8 +502,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if initErr != nil {
|
||||||
return fmt.Errorf("failed to initialize connection options: %w", err)
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
|
return fmt.Errorf("failed to initialize connection options: %w", initErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enable maintnotifications if maintnotifications are configured
|
// Enable maintnotifications if maintnotifications are configured
|
||||||
@@ -465,6 +523,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
if maintNotifHandshakeErr != nil {
|
if maintNotifHandshakeErr != nil {
|
||||||
if !isRedisError(maintNotifHandshakeErr) {
|
if !isRedisError(maintNotifHandshakeErr) {
|
||||||
// if not redis error, fail the connection
|
// if not redis error, fail the connection
|
||||||
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
return maintNotifHandshakeErr
|
return maintNotifHandshakeErr
|
||||||
}
|
}
|
||||||
c.optLock.Lock()
|
c.optLock.Lock()
|
||||||
@@ -473,15 +532,16 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
case maintnotifications.ModeEnabled:
|
case maintnotifications.ModeEnabled:
|
||||||
// enabled mode, fail the connection
|
// enabled mode, fail the connection
|
||||||
c.optLock.Unlock()
|
c.optLock.Unlock()
|
||||||
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr)
|
return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr)
|
||||||
default: // will handle auto and any other
|
default: // will handle auto and any other
|
||||||
internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr)
|
internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr)
|
||||||
c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled
|
c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled
|
||||||
c.optLock.Unlock()
|
c.optLock.Unlock()
|
||||||
// auto mode, disable maintnotifications and continue
|
// auto mode, disable maintnotifications and continue
|
||||||
if err := c.disableMaintNotificationsUpgrades(); err != nil {
|
if initErr := c.disableMaintNotificationsUpgrades(); initErr != nil {
|
||||||
// Log error but continue - auto mode should be resilient
|
// Log error but continue - auto mode should be resilient
|
||||||
internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", err)
|
internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", initErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -505,22 +565,31 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
p.ClientSetInfo(ctx, WithLibraryVersion(libVer))
|
p.ClientSetInfo(ctx, WithLibraryVersion(libVer))
|
||||||
// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
|
// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
|
||||||
// out of order responses later on.
|
// out of order responses later on.
|
||||||
if _, err = p.Exec(ctx); err != nil && !isRedisError(err) {
|
if _, initErr = p.Exec(ctx); initErr != nil && !isRedisError(initErr) {
|
||||||
return err
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
|
return initErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mark the connection as usable and inited
|
|
||||||
// once returned to the pool as idle, this connection can be used by other clients
|
|
||||||
cn.SetUsable(true)
|
|
||||||
cn.SetUsed(false)
|
|
||||||
cn.Inited.Store(true)
|
|
||||||
|
|
||||||
// Set the connection initialization function for potential reconnections
|
// Set the connection initialization function for potential reconnections
|
||||||
|
// This must be set before transitioning to IDLE so that handoff/reauth can use it
|
||||||
cn.SetInitConnFunc(c.createInitConnFunc())
|
cn.SetInitConnFunc(c.createInitConnFunc())
|
||||||
|
|
||||||
|
// Initialization succeeded - transition to IDLE state
|
||||||
|
// This marks the connection as initialized and ready for use
|
||||||
|
// NOTE: The connection is still owned by the calling goroutine at this point
|
||||||
|
// and won't be available to other goroutines until it's Put() back into the pool
|
||||||
|
cn.GetStateMachine().Transition(pool.StateIdle)
|
||||||
|
|
||||||
|
// Call OnConnect hook if configured
|
||||||
|
// The connection is in IDLE state but still owned by this goroutine
|
||||||
|
// If OnConnect needs to send commands, it can use the connection safely
|
||||||
if c.opt.OnConnect != nil {
|
if c.opt.OnConnect != nil {
|
||||||
return c.opt.OnConnect(ctx, conn)
|
if initErr = c.opt.OnConnect(ctx, conn); initErr != nil {
|
||||||
|
// OnConnect failed - transition to closed
|
||||||
|
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||||
|
return initErr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
Reference in New Issue
Block a user