mirror of
https://github.com/redis/go-redis.git
synced 2025-10-21 20:53:41 +03:00
be able to reject connection OnGet
This commit is contained in:
@@ -11,9 +11,12 @@ import (
|
||||
type ReAuthPoolHook struct {
|
||||
// conn id -> func() reauth func with error handling
|
||||
shouldReAuth map[uint64]func(error)
|
||||
lock sync.RWMutex
|
||||
shouldReAuthLock sync.RWMutex
|
||||
workers chan struct{}
|
||||
reAuthTimeout time.Duration
|
||||
// conn id -> bool
|
||||
scheduledReAuth map[uint64]bool
|
||||
scheduledLock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
|
||||
@@ -25,6 +28,7 @@ func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHoo
|
||||
|
||||
return &ReAuthPoolHook{
|
||||
shouldReAuth: make(map[uint64]func(error)),
|
||||
scheduledReAuth: make(map[uint64]bool),
|
||||
workers: workers,
|
||||
reAuthTimeout: reAuthTimeout,
|
||||
}
|
||||
@@ -32,44 +36,56 @@ func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHoo
|
||||
}
|
||||
|
||||
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
r.shouldReAuthLock.Lock()
|
||||
defer r.shouldReAuthLock.Unlock()
|
||||
r.shouldReAuth[connID] = reAuthFn
|
||||
}
|
||||
|
||||
func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
r.shouldReAuthLock.Lock()
|
||||
defer r.shouldReAuthLock.Unlock()
|
||||
delete(r.shouldReAuth, connID)
|
||||
}
|
||||
|
||||
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) error {
|
||||
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
|
||||
r.shouldReAuthLock.RLock()
|
||||
_, ok := r.shouldReAuth[conn.GetID()]
|
||||
r.shouldReAuthLock.RUnlock()
|
||||
// This connection was marked for reauth while in the pool,
|
||||
// so we need to reauth it before returning it to the user.
|
||||
r.lock.RLock()
|
||||
reAuthFn, ok := r.shouldReAuth[conn.GetID()]
|
||||
r.lock.RUnlock()
|
||||
// reject the connection
|
||||
if ok {
|
||||
// Clear the mark immediately to prevent duplicate reauth attempts
|
||||
r.ClearReAuthMark(conn.GetID())
|
||||
reAuthFn(nil)
|
||||
// simply reject the connection, it will be re-authenticated in OnPut
|
||||
return false, nil
|
||||
}
|
||||
return nil
|
||||
r.scheduledLock.RLock()
|
||||
hasScheduled, ok := r.scheduledReAuth[conn.GetID()]
|
||||
r.scheduledLock.RUnlock()
|
||||
// has scheduled reauth, reject the connection
|
||||
if ok && hasScheduled {
|
||||
// simply reject the connection, it will be re-authenticated in OnPut
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
|
||||
// Check if reauth is needed and get the function with proper locking
|
||||
r.lock.RLock()
|
||||
r.shouldReAuthLock.RLock()
|
||||
reAuthFn, ok := r.shouldReAuth[conn.GetID()]
|
||||
r.lock.RUnlock()
|
||||
r.shouldReAuthLock.RUnlock()
|
||||
|
||||
if ok {
|
||||
r.scheduledLock.Lock()
|
||||
r.scheduledReAuth[conn.GetID()] = true
|
||||
r.scheduledLock.Unlock()
|
||||
// Clear the mark immediately to prevent duplicate reauth attempts
|
||||
r.ClearReAuthMark(conn.GetID())
|
||||
|
||||
go func() {
|
||||
<-r.workers
|
||||
defer func() {
|
||||
r.scheduledLock.Lock()
|
||||
delete(r.scheduledReAuth, conn.GetID())
|
||||
r.scheduledLock.Unlock()
|
||||
r.workers <- struct{}{}
|
||||
}()
|
||||
|
||||
@@ -103,6 +119,12 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
|
||||
}
|
||||
|
||||
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
|
||||
r.scheduledLock.Lock()
|
||||
delete(r.scheduledReAuth, conn.GetID())
|
||||
r.scheduledLock.Unlock()
|
||||
r.shouldReAuthLock.Lock()
|
||||
delete(r.shouldReAuth, conn.GetID())
|
||||
r.shouldReAuthLock.Unlock()
|
||||
r.ClearReAuthMark(conn.GetID())
|
||||
}
|
||||
|
||||
|
@@ -9,9 +9,13 @@ import (
|
||||
type PoolHook interface {
|
||||
// OnGet is called when a connection is retrieved from the pool.
|
||||
// It can modify the connection or return an error to prevent its use.
|
||||
// The accept flag can be used to prevent the connection from being used.
|
||||
// On Accept = false the connection is rejected and returned to the pool.
|
||||
// The error can be used to prevent the connection from being used and returned to the pool.
|
||||
// On Errors, the connection is removed from the pool.
|
||||
// It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool)
|
||||
// The flag can be used for gathering metrics on pool hit/miss ratio.
|
||||
OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
|
||||
OnGet(ctx context.Context, conn *Conn, isNewConn bool) (accept bool, err error)
|
||||
|
||||
// OnPut is called when a connection is returned to the pool.
|
||||
// It returns whether the connection should be pooled and whether it should be removed.
|
||||
@@ -60,16 +64,21 @@ func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
|
||||
|
||||
// ProcessOnGet calls all OnGet hooks in order.
|
||||
// If any hook returns an error, processing stops and the error is returned.
|
||||
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
|
||||
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) (acceptConn bool, err error) {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
for _, hook := range phm.hooks {
|
||||
if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
|
||||
return err
|
||||
acceptConn, err := hook.OnGet(ctx, conn, isNewConn)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !acceptConn {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ProcessOnPut calls all OnPut hooks in order.
|
||||
|
@@ -56,10 +56,13 @@ func TestPoolHookManager(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
conn := &Conn{} // Mock connection
|
||||
|
||||
err := manager.ProcessOnGet(ctx, conn, false)
|
||||
accept, err := manager.ProcessOnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessOnGet should not error: %v", err)
|
||||
}
|
||||
if !accept {
|
||||
t.Error("Expected accept to be true")
|
||||
}
|
||||
|
||||
if hook1.OnGetCalled != 1 {
|
||||
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
|
||||
@@ -117,10 +120,13 @@ func TestHookErrorHandling(t *testing.T) {
|
||||
conn := &Conn{}
|
||||
|
||||
// Test that error stops processing
|
||||
err := manager.ProcessOnGet(ctx, conn, false)
|
||||
accept, err := manager.ProcessOnGet(ctx, conn, false)
|
||||
if err == nil {
|
||||
t.Error("Expected error from ProcessOnGet")
|
||||
}
|
||||
if accept {
|
||||
t.Error("Expected accept to be false")
|
||||
}
|
||||
|
||||
if errorHook.OnGetCalled != 1 {
|
||||
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
|
||||
|
@@ -434,6 +434,12 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
|
||||
now := time.Now()
|
||||
attempts := 0
|
||||
|
||||
// get hooks manager
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
for {
|
||||
if attempts >= getAttempts {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
|
||||
@@ -460,17 +466,19 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
|
||||
acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
|
||||
// Failed to process connection, discard it
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
if !acceptConn {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID())
|
||||
p.Put(ctx, cn)
|
||||
cn = nil
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddUint32(&p.stats.Hits, 1)
|
||||
@@ -486,14 +494,13 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
|
||||
acceptConn, err := hookManager.ProcessOnGet(ctx, newcn, true)
|
||||
// both errors and accept=false mean a hook rejected the connection
|
||||
// this should not happen with a new connection, but we handle it gracefully
|
||||
if err != nil || !acceptConn {
|
||||
// Failed to process connection, discard it
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err)
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accpet=%v, err=%v", newcn.GetID(), acceptConn, err)
|
||||
_ = p.CloseConn(newcn)
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -116,22 +116,22 @@ func (ph *PoolHook) ResetCircuitBreakers() {
|
||||
}
|
||||
|
||||
// OnGet is called when a connection is retrieved from the pool
|
||||
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error {
|
||||
func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
|
||||
// NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is
|
||||
// in a handoff state at the moment.
|
||||
|
||||
// Check if connection is usable (not in a handoff state)
|
||||
// Should not happen since the pool will not return a connection that is not usable.
|
||||
if !conn.IsUsable() {
|
||||
return ErrConnectionMarkedForHandoff
|
||||
return false, ErrConnectionMarkedForHandoff
|
||||
}
|
||||
|
||||
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
|
||||
if conn.ShouldHandoff() {
|
||||
return ErrConnectionMarkedForHandoff
|
||||
return false, ErrConnectionMarkedForHandoff
|
||||
}
|
||||
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// OnPut is called when a connection is returned to the pool
|
||||
|
@@ -360,10 +360,13 @@ func TestConnectionHook(t *testing.T) {
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
acceptCon, err := processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should not error for normal connection: %v", err)
|
||||
}
|
||||
if !acceptCon {
|
||||
t.Error("Connection should be accepted for normal connection")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
|
||||
@@ -385,10 +388,13 @@ func TestConnectionHook(t *testing.T) {
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
acceptCon, err := processor.OnGet(ctx, conn, false)
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||
}
|
||||
if acceptCon {
|
||||
t.Error("Connection should not be accepted when marked for handoff")
|
||||
}
|
||||
|
||||
// Clean up
|
||||
processor.GetPendingMap().Delete(conn)
|
||||
@@ -416,10 +422,13 @@ func TestConnectionHook(t *testing.T) {
|
||||
|
||||
// Test OnGet with pending handoff
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
acceptCon, err := processor.OnGet(ctx, conn, false)
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
|
||||
}
|
||||
if acceptCon {
|
||||
t.Error("Should not accept connection with pending handoff")
|
||||
}
|
||||
|
||||
// Test removing from pending map and clearing handoff state
|
||||
processor.GetPendingMap().Delete(conn)
|
||||
@@ -432,10 +441,13 @@ func TestConnectionHook(t *testing.T) {
|
||||
conn.SetUsable(true) // Make connection usable again
|
||||
|
||||
// Test OnGet without pending handoff
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
acceptCon, err = processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("Should not return error for non-pending connection: %v", err)
|
||||
}
|
||||
if !acceptCon {
|
||||
t.Error("Should accept connection without pending handoff")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
|
||||
@@ -628,11 +640,15 @@ func TestConnectionHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// OnGet should succeed for usable connection
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
acceptConn, err := processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should succeed for usable connection: %v", err)
|
||||
}
|
||||
|
||||
if !acceptConn {
|
||||
t.Error("Connection should be accepted when usable")
|
||||
}
|
||||
|
||||
// Mark connection for handoff
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
@@ -652,13 +668,17 @@ func TestConnectionHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// OnGet should fail for connection marked for handoff
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
acceptConn, err = processor.OnGet(ctx, conn, false)
|
||||
if err == nil {
|
||||
t.Error("OnGet should fail for connection marked for handoff")
|
||||
}
|
||||
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||
}
|
||||
if acceptConn {
|
||||
t.Error("Connection should not be accepted when marked for handoff")
|
||||
}
|
||||
|
||||
// Process the connection to trigger handoff
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
@@ -678,11 +698,15 @@ func TestConnectionHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// OnGet should succeed again
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
acceptConn, err = processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should succeed after handoff completion: %v", err)
|
||||
}
|
||||
|
||||
if !acceptConn {
|
||||
t.Error("Connection should be accepted after handoff completion")
|
||||
}
|
||||
|
||||
t.Logf("Usable flag behavior test completed successfully")
|
||||
})
|
||||
|
||||
|
Reference in New Issue
Block a user