diff --git a/internal/auth/streaming/pool_hook.go b/internal/auth/streaming/pool_hook.go index 7318589c..4bffb2c4 100644 --- a/internal/auth/streaming/pool_hook.go +++ b/internal/auth/streaming/pool_hook.go @@ -10,10 +10,13 @@ import ( type ReAuthPoolHook struct { // conn id -> func() reauth func with error handling - shouldReAuth map[uint64]func(error) - lock sync.RWMutex - workers chan struct{} + shouldReAuth map[uint64]func(error) + shouldReAuthLock sync.RWMutex + workers chan struct{} reAuthTimeout time.Duration + // conn id -> bool + scheduledReAuth map[uint64]bool + scheduledLock sync.RWMutex } func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook { @@ -24,52 +27,65 @@ func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHoo } return &ReAuthPoolHook{ - shouldReAuth: make(map[uint64]func(error)), - workers: workers, - reAuthTimeout: reAuthTimeout, + shouldReAuth: make(map[uint64]func(error)), + scheduledReAuth: make(map[uint64]bool), + workers: workers, + reAuthTimeout: reAuthTimeout, } } func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) { - r.lock.Lock() - defer r.lock.Unlock() + r.shouldReAuthLock.Lock() + defer r.shouldReAuthLock.Unlock() r.shouldReAuth[connID] = reAuthFn } func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) { - r.lock.Lock() - defer r.lock.Unlock() + r.shouldReAuthLock.Lock() + defer r.shouldReAuthLock.Unlock() delete(r.shouldReAuth, connID) } -func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) error { +func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) { + r.shouldReAuthLock.RLock() + _, ok := r.shouldReAuth[conn.GetID()] + r.shouldReAuthLock.RUnlock() // This connection was marked for reauth while in the pool, - // so we need to reauth it before returning it to the user. 
- r.lock.RLock() - reAuthFn, ok := r.shouldReAuth[conn.GetID()] - r.lock.RUnlock() + // reject the connection if ok { - // Clear the mark immediately to prevent duplicate reauth attempts - r.ClearReAuthMark(conn.GetID()) - reAuthFn(nil) + // simply reject the connection, it will be re-authenticated in OnPut + return false, nil } - return nil + r.scheduledLock.RLock() + hasScheduled, ok := r.scheduledReAuth[conn.GetID()] + r.scheduledLock.RUnlock() + // has scheduled reauth, reject the connection + if ok && hasScheduled { + // simply reject the connection, it will be re-authenticated in OnPut + return false, nil + } + return true, nil } func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) { // Check if reauth is needed and get the function with proper locking - r.lock.RLock() + r.shouldReAuthLock.RLock() reAuthFn, ok := r.shouldReAuth[conn.GetID()] - r.lock.RUnlock() + r.shouldReAuthLock.RUnlock() if ok { + r.scheduledLock.Lock() + r.scheduledReAuth[conn.GetID()] = true + r.scheduledLock.Unlock() // Clear the mark immediately to prevent duplicate reauth attempts r.ClearReAuthMark(conn.GetID()) - go func() { <-r.workers defer func() { + r.scheduledLock.Lock() + delete(r.scheduledReAuth, conn.GetID()) + r.scheduledLock.Unlock() r.workers <- struct{}{} }() @@ -103,6 +119,12 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, } func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) { + r.scheduledLock.Lock() + delete(r.scheduledReAuth, conn.GetID()) + r.scheduledLock.Unlock() + r.shouldReAuthLock.Lock() + delete(r.shouldReAuth, conn.GetID()) + r.shouldReAuthLock.Unlock() r.ClearReAuthMark(conn.GetID()) } diff --git a/internal/pool/hooks.go b/internal/pool/hooks.go index 205f25fb..cb988b9f 100644 --- a/internal/pool/hooks.go +++ b/internal/pool/hooks.go @@ -9,9 +9,13 @@ import ( type PoolHook interface { // OnGet is called when a connection is retrieved from the pool. 
// It can modify the connection or return an error to prevent its use. + // The accept flag can be used to prevent the connection from being used. + // On Accept = false the connection is rejected and returned to the pool. + // The error can be used to prevent the connection from being used and returned to the pool. + // On Errors, the connection is removed from the pool. // It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool) // The flag can be used for gathering metrics on pool hit/miss ratio. - OnGet(ctx context.Context, conn *Conn, isNewConn bool) error + OnGet(ctx context.Context, conn *Conn, isNewConn bool) (accept bool, err error) // OnPut is called when a connection is returned to the pool. // It returns whether the connection should be pooled and whether it should be removed. @@ -60,16 +64,21 @@ func (phm *PoolHookManager) RemoveHook(hook PoolHook) { // ProcessOnGet calls all OnGet hooks in order. // If any hook returns an error, processing stops and the error is returned. -func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error { +func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) (acceptConn bool, err error) { phm.hooksMu.RLock() defer phm.hooksMu.RUnlock() for _, hook := range phm.hooks { - if err := hook.OnGet(ctx, conn, isNewConn); err != nil { - return err + acceptConn, err := hook.OnGet(ctx, conn, isNewConn) + if err != nil { + return false, err + } + + if !acceptConn { + return false, nil } } - return nil + return true, nil } // ProcessOnPut calls all OnPut hooks in order. 
diff --git a/internal/pool/hooks_test.go b/internal/pool/hooks_test.go index 2a14c9c2..02597225 100644 --- a/internal/pool/hooks_test.go +++ b/internal/pool/hooks_test.go @@ -56,10 +56,13 @@ func TestPoolHookManager(t *testing.T) { ctx := context.Background() conn := &Conn{} // Mock connection - err := manager.ProcessOnGet(ctx, conn, false) + accept, err := manager.ProcessOnGet(ctx, conn, false) if err != nil { t.Errorf("ProcessOnGet should not error: %v", err) } + if !accept { + t.Error("Expected accept to be true") + } if hook1.OnGetCalled != 1 { t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled) @@ -117,10 +120,13 @@ func TestHookErrorHandling(t *testing.T) { conn := &Conn{} // Test that error stops processing - err := manager.ProcessOnGet(ctx, conn, false) + accept, err := manager.ProcessOnGet(ctx, conn, false) if err == nil { t.Error("Expected error from ProcessOnGet") } + if accept { + t.Error("Expected accept to be false") + } if errorHook.OnGetCalled != 1 { t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index eda92d55..3541b0cb 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -434,6 +434,12 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { now := time.Now() attempts := 0 + + // get hooks manager + p.hookManagerMu.RLock() + hookManager := p.hookManager + p.hookManagerMu.RUnlock() + for { if attempts >= getAttempts { internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts) @@ -460,17 +466,19 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } // Process connection using the hooks system - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() - if hookManager != nil { - if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil { + acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false) + if 
err != nil { internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) - // Failed to process connection, discard it _ = p.CloseConn(cn) continue } + if !acceptConn { + internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) + p.Put(ctx, cn) + cn = nil + continue + } } atomic.AddUint32(&p.stats.Hits, 1) @@ -486,14 +494,13 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } // Process connection using the hooks system - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() - if hookManager != nil { - if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil { + acceptConn, err := hookManager.ProcessOnGet(ctx, newcn, true) + // both errors and accept=false mean a hook rejected the connection + // this should not happen with a new connection, but we handle it gracefully + if err != nil || !acceptConn { // Failed to process connection, discard it - internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err) + internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err) _ = p.CloseConn(newcn) return nil, err } diff --git a/maintnotifications/pool_hook.go b/maintnotifications/pool_hook.go index f8093524..9fd24b4a 100644 --- a/maintnotifications/pool_hook.go +++ b/maintnotifications/pool_hook.go @@ -116,22 +116,22 @@ func (ph *PoolHook) ResetCircuitBreakers() { } // OnGet is called when a connection is retrieved from the pool -func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error { +func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) { // NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is // in a handoff state at the moment. 
// Check if connection is usable (not in a handoff state) // Should not happen since the pool will not return a connection that is not usable. if !conn.IsUsable() { - return ErrConnectionMarkedForHandoff + return false, ErrConnectionMarkedForHandoff } // Check if connection is marked for handoff, which means it will be queued for handoff on put. if conn.ShouldHandoff() { - return ErrConnectionMarkedForHandoff + return false, ErrConnectionMarkedForHandoff } - return nil + return true, nil } // OnPut is called when a connection is returned to the pool diff --git a/maintnotifications/pool_hook_test.go b/maintnotifications/pool_hook_test.go index f2f4f433..51e73c3e 100644 --- a/maintnotifications/pool_hook_test.go +++ b/maintnotifications/pool_hook_test.go @@ -360,10 +360,13 @@ func TestConnectionHook(t *testing.T) { conn := createMockPoolConnection() ctx := context.Background() - err := processor.OnGet(ctx, conn, false) + acceptCon, err := processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("OnGet should not error for normal connection: %v", err) } + if !acceptCon { + t.Error("Connection should be accepted for normal connection") + } }) t.Run("OnGetWithPendingHandoff", func(t *testing.T) { @@ -385,10 +388,13 @@ func TestConnectionHook(t *testing.T) { conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) ctx := context.Background() - err := processor.OnGet(ctx, conn, false) + acceptCon, err := processor.OnGet(ctx, conn, false) if err != ErrConnectionMarkedForHandoff { t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) } + if acceptCon { + t.Error("Connection should not be accepted when marked for handoff") + } // Clean up processor.GetPendingMap().Delete(conn) @@ -416,10 +422,13 @@ func TestConnectionHook(t *testing.T) { // Test OnGet with pending handoff ctx := context.Background() - err := processor.OnGet(ctx, conn, false) + acceptCon, err := processor.OnGet(ctx, conn, false) if err != ErrConnectionMarkedForHandoff { t.Error("Should 
return ErrConnectionMarkedForHandoff for pending connection") } + if acceptCon { + t.Error("Should not accept connection with pending handoff") + } // Test removing from pending map and clearing handoff state processor.GetPendingMap().Delete(conn) @@ -432,10 +441,13 @@ func TestConnectionHook(t *testing.T) { conn.SetUsable(true) // Make connection usable again // Test OnGet without pending handoff - err = processor.OnGet(ctx, conn, false) + acceptCon, err = processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("Should not return error for non-pending connection: %v", err) } + if !acceptCon { + t.Error("Should accept connection without pending handoff") + } }) t.Run("EventDrivenQueueOptimization", func(t *testing.T) { @@ -628,11 +640,15 @@ func TestConnectionHook(t *testing.T) { } // OnGet should succeed for usable connection - err := processor.OnGet(ctx, conn, false) + acceptConn, err := processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("OnGet should succeed for usable connection: %v", err) } + if !acceptConn { + t.Error("Connection should be accepted when usable") + } + // Mark connection for handoff if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { t.Fatalf("Failed to mark connection for handoff: %v", err) @@ -652,13 +668,17 @@ func TestConnectionHook(t *testing.T) { } // OnGet should fail for connection marked for handoff - err = processor.OnGet(ctx, conn, false) + acceptConn, err = processor.OnGet(ctx, conn, false) if err == nil { t.Error("OnGet should fail for connection marked for handoff") } + if err != ErrConnectionMarkedForHandoff { t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) } + if acceptConn { + t.Error("Connection should not be accepted when marked for handoff") + } // Process the connection to trigger handoff shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) @@ -678,11 +698,15 @@ func TestConnectionHook(t *testing.T) { } // OnGet should succeed again - err = processor.OnGet(ctx, conn, false) 
+ acceptConn, err = processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("OnGet should succeed after handoff completion: %v", err) } + if !acceptConn { + t.Error("Connection should be accepted after handoff completion") + } + t.Logf("Usable flag behavior test completed successfully") })