mirror of
https://github.com/redis/go-redis.git
synced 2025-10-21 20:53:41 +03:00
be able to reject connection OnGet
This commit is contained in:
@@ -11,9 +11,12 @@ import (
|
|||||||
type ReAuthPoolHook struct {
|
type ReAuthPoolHook struct {
|
||||||
// conn id -> func() reauth func with error handling
|
// conn id -> func() reauth func with error handling
|
||||||
shouldReAuth map[uint64]func(error)
|
shouldReAuth map[uint64]func(error)
|
||||||
lock sync.RWMutex
|
shouldReAuthLock sync.RWMutex
|
||||||
workers chan struct{}
|
workers chan struct{}
|
||||||
reAuthTimeout time.Duration
|
reAuthTimeout time.Duration
|
||||||
|
// conn id -> bool
|
||||||
|
scheduledReAuth map[uint64]bool
|
||||||
|
scheduledLock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
|
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
|
||||||
@@ -25,6 +28,7 @@ func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHoo
|
|||||||
|
|
||||||
return &ReAuthPoolHook{
|
return &ReAuthPoolHook{
|
||||||
shouldReAuth: make(map[uint64]func(error)),
|
shouldReAuth: make(map[uint64]func(error)),
|
||||||
|
scheduledReAuth: make(map[uint64]bool),
|
||||||
workers: workers,
|
workers: workers,
|
||||||
reAuthTimeout: reAuthTimeout,
|
reAuthTimeout: reAuthTimeout,
|
||||||
}
|
}
|
||||||
@@ -32,44 +36,56 @@ func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHoo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
|
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
|
||||||
r.lock.Lock()
|
r.shouldReAuthLock.Lock()
|
||||||
defer r.lock.Unlock()
|
defer r.shouldReAuthLock.Unlock()
|
||||||
r.shouldReAuth[connID] = reAuthFn
|
r.shouldReAuth[connID] = reAuthFn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) {
|
func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) {
|
||||||
r.lock.Lock()
|
r.shouldReAuthLock.Lock()
|
||||||
defer r.lock.Unlock()
|
defer r.shouldReAuthLock.Unlock()
|
||||||
delete(r.shouldReAuth, connID)
|
delete(r.shouldReAuth, connID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) error {
|
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
|
||||||
|
r.shouldReAuthLock.RLock()
|
||||||
|
_, ok := r.shouldReAuth[conn.GetID()]
|
||||||
|
r.shouldReAuthLock.RUnlock()
|
||||||
// This connection was marked for reauth while in the pool,
|
// This connection was marked for reauth while in the pool,
|
||||||
// so we need to reauth it before returning it to the user.
|
// reject the connection
|
||||||
r.lock.RLock()
|
|
||||||
reAuthFn, ok := r.shouldReAuth[conn.GetID()]
|
|
||||||
r.lock.RUnlock()
|
|
||||||
if ok {
|
if ok {
|
||||||
// Clear the mark immediately to prevent duplicate reauth attempts
|
// simply reject the connection, it will be re-authenticated in OnPut
|
||||||
r.ClearReAuthMark(conn.GetID())
|
return false, nil
|
||||||
reAuthFn(nil)
|
|
||||||
}
|
}
|
||||||
return nil
|
r.scheduledLock.RLock()
|
||||||
|
hasScheduled, ok := r.scheduledReAuth[conn.GetID()]
|
||||||
|
r.scheduledLock.RUnlock()
|
||||||
|
// has scheduled reauth, reject the connection
|
||||||
|
if ok && hasScheduled {
|
||||||
|
// simply reject the connection, it will be re-authenticated in OnPut
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
|
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
|
||||||
// Check if reauth is needed and get the function with proper locking
|
// Check if reauth is needed and get the function with proper locking
|
||||||
r.lock.RLock()
|
r.shouldReAuthLock.RLock()
|
||||||
reAuthFn, ok := r.shouldReAuth[conn.GetID()]
|
reAuthFn, ok := r.shouldReAuth[conn.GetID()]
|
||||||
r.lock.RUnlock()
|
r.shouldReAuthLock.RUnlock()
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
|
r.scheduledLock.Lock()
|
||||||
|
r.scheduledReAuth[conn.GetID()] = true
|
||||||
|
r.scheduledLock.Unlock()
|
||||||
// Clear the mark immediately to prevent duplicate reauth attempts
|
// Clear the mark immediately to prevent duplicate reauth attempts
|
||||||
r.ClearReAuthMark(conn.GetID())
|
r.ClearReAuthMark(conn.GetID())
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-r.workers
|
<-r.workers
|
||||||
defer func() {
|
defer func() {
|
||||||
|
r.scheduledLock.Lock()
|
||||||
|
delete(r.scheduledReAuth, conn.GetID())
|
||||||
|
r.scheduledLock.Unlock()
|
||||||
r.workers <- struct{}{}
|
r.workers <- struct{}{}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -103,6 +119,12 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
|
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
|
||||||
|
r.scheduledLock.Lock()
|
||||||
|
delete(r.scheduledReAuth, conn.GetID())
|
||||||
|
r.scheduledLock.Unlock()
|
||||||
|
r.shouldReAuthLock.Lock()
|
||||||
|
delete(r.shouldReAuth, conn.GetID())
|
||||||
|
r.shouldReAuthLock.Unlock()
|
||||||
r.ClearReAuthMark(conn.GetID())
|
r.ClearReAuthMark(conn.GetID())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -9,9 +9,13 @@ import (
|
|||||||
type PoolHook interface {
|
type PoolHook interface {
|
||||||
// OnGet is called when a connection is retrieved from the pool.
|
// OnGet is called when a connection is retrieved from the pool.
|
||||||
// It can modify the connection or return an error to prevent its use.
|
// It can modify the connection or return an error to prevent its use.
|
||||||
|
// The accept flag can be used to prevent the connection from being used.
|
||||||
|
// On Accept = false the connection is rejected and returned to the pool.
|
||||||
|
// The error can be used to prevent the connection from being used and returned to the pool.
|
||||||
|
// On Errors, the connection is removed from the pool.
|
||||||
// It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool)
|
// It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool)
|
||||||
// The flag can be used for gathering metrics on pool hit/miss ratio.
|
// The flag can be used for gathering metrics on pool hit/miss ratio.
|
||||||
OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
|
OnGet(ctx context.Context, conn *Conn, isNewConn bool) (accept bool, err error)
|
||||||
|
|
||||||
// OnPut is called when a connection is returned to the pool.
|
// OnPut is called when a connection is returned to the pool.
|
||||||
// It returns whether the connection should be pooled and whether it should be removed.
|
// It returns whether the connection should be pooled and whether it should be removed.
|
||||||
@@ -60,16 +64,21 @@ func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
|
|||||||
|
|
||||||
// ProcessOnGet calls all OnGet hooks in order.
|
// ProcessOnGet calls all OnGet hooks in order.
|
||||||
// If any hook returns an error, processing stops and the error is returned.
|
// If any hook returns an error, processing stops and the error is returned.
|
||||||
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
|
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) (acceptConn bool, err error) {
|
||||||
phm.hooksMu.RLock()
|
phm.hooksMu.RLock()
|
||||||
defer phm.hooksMu.RUnlock()
|
defer phm.hooksMu.RUnlock()
|
||||||
|
|
||||||
for _, hook := range phm.hooks {
|
for _, hook := range phm.hooks {
|
||||||
if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
|
acceptConn, err := hook.OnGet(ctx, conn, isNewConn)
|
||||||
return err
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !acceptConn {
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessOnPut calls all OnPut hooks in order.
|
// ProcessOnPut calls all OnPut hooks in order.
|
||||||
|
@@ -56,10 +56,13 @@ func TestPoolHookManager(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
conn := &Conn{} // Mock connection
|
conn := &Conn{} // Mock connection
|
||||||
|
|
||||||
err := manager.ProcessOnGet(ctx, conn, false)
|
accept, err := manager.ProcessOnGet(ctx, conn, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("ProcessOnGet should not error: %v", err)
|
t.Errorf("ProcessOnGet should not error: %v", err)
|
||||||
}
|
}
|
||||||
|
if !accept {
|
||||||
|
t.Error("Expected accept to be true")
|
||||||
|
}
|
||||||
|
|
||||||
if hook1.OnGetCalled != 1 {
|
if hook1.OnGetCalled != 1 {
|
||||||
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
|
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
|
||||||
@@ -117,10 +120,13 @@ func TestHookErrorHandling(t *testing.T) {
|
|||||||
conn := &Conn{}
|
conn := &Conn{}
|
||||||
|
|
||||||
// Test that error stops processing
|
// Test that error stops processing
|
||||||
err := manager.ProcessOnGet(ctx, conn, false)
|
accept, err := manager.ProcessOnGet(ctx, conn, false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error from ProcessOnGet")
|
t.Error("Expected error from ProcessOnGet")
|
||||||
}
|
}
|
||||||
|
if accept {
|
||||||
|
t.Error("Expected accept to be false")
|
||||||
|
}
|
||||||
|
|
||||||
if errorHook.OnGetCalled != 1 {
|
if errorHook.OnGetCalled != 1 {
|
||||||
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
|
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
|
||||||
|
@@ -434,6 +434,12 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
|||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
attempts := 0
|
attempts := 0
|
||||||
|
|
||||||
|
// get hooks manager
|
||||||
|
p.hookManagerMu.RLock()
|
||||||
|
hookManager := p.hookManager
|
||||||
|
p.hookManagerMu.RUnlock()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if attempts >= getAttempts {
|
if attempts >= getAttempts {
|
||||||
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
|
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
|
||||||
@@ -460,17 +466,19 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process connection using the hooks system
|
// Process connection using the hooks system
|
||||||
p.hookManagerMu.RLock()
|
|
||||||
hookManager := p.hookManager
|
|
||||||
p.hookManagerMu.RUnlock()
|
|
||||||
|
|
||||||
if hookManager != nil {
|
if hookManager != nil {
|
||||||
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
|
acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false)
|
||||||
|
if err != nil {
|
||||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
|
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
|
||||||
// Failed to process connection, discard it
|
|
||||||
_ = p.CloseConn(cn)
|
_ = p.CloseConn(cn)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !acceptConn {
|
||||||
|
internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID())
|
||||||
|
p.Put(ctx, cn)
|
||||||
|
cn = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
atomic.AddUint32(&p.stats.Hits, 1)
|
atomic.AddUint32(&p.stats.Hits, 1)
|
||||||
@@ -486,14 +494,13 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process connection using the hooks system
|
// Process connection using the hooks system
|
||||||
p.hookManagerMu.RLock()
|
|
||||||
hookManager := p.hookManager
|
|
||||||
p.hookManagerMu.RUnlock()
|
|
||||||
|
|
||||||
if hookManager != nil {
|
if hookManager != nil {
|
||||||
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
|
acceptConn, err := hookManager.ProcessOnGet(ctx, newcn, true)
|
||||||
|
// both errors and accept=false mean a hook rejected the connection
|
||||||
|
// this should not happen with a new connection, but we handle it gracefully
|
||||||
|
if err != nil || !acceptConn {
|
||||||
// Failed to process connection, discard it
|
// Failed to process connection, discard it
|
||||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err)
|
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accpet=%v, err=%v", newcn.GetID(), acceptConn, err)
|
||||||
_ = p.CloseConn(newcn)
|
_ = p.CloseConn(newcn)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -116,22 +116,22 @@ func (ph *PoolHook) ResetCircuitBreakers() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OnGet is called when a connection is retrieved from the pool
|
// OnGet is called when a connection is retrieved from the pool
|
||||||
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error {
|
func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
|
||||||
// NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is
|
// NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is
|
||||||
// in a handoff state at the moment.
|
// in a handoff state at the moment.
|
||||||
|
|
||||||
// Check if connection is usable (not in a handoff state)
|
// Check if connection is usable (not in a handoff state)
|
||||||
// Should not happen since the pool will not return a connection that is not usable.
|
// Should not happen since the pool will not return a connection that is not usable.
|
||||||
if !conn.IsUsable() {
|
if !conn.IsUsable() {
|
||||||
return ErrConnectionMarkedForHandoff
|
return false, ErrConnectionMarkedForHandoff
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
|
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
|
||||||
if conn.ShouldHandoff() {
|
if conn.ShouldHandoff() {
|
||||||
return ErrConnectionMarkedForHandoff
|
return false, ErrConnectionMarkedForHandoff
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPut is called when a connection is returned to the pool
|
// OnPut is called when a connection is returned to the pool
|
||||||
|
@@ -360,10 +360,13 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
conn := createMockPoolConnection()
|
conn := createMockPoolConnection()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
err := processor.OnGet(ctx, conn, false)
|
acceptCon, err := processor.OnGet(ctx, conn, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("OnGet should not error for normal connection: %v", err)
|
t.Errorf("OnGet should not error for normal connection: %v", err)
|
||||||
}
|
}
|
||||||
|
if !acceptCon {
|
||||||
|
t.Error("Connection should be accepted for normal connection")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
|
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
|
||||||
@@ -385,10 +388,13 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
err := processor.OnGet(ctx, conn, false)
|
acceptCon, err := processor.OnGet(ctx, conn, false)
|
||||||
if err != ErrConnectionMarkedForHandoff {
|
if err != ErrConnectionMarkedForHandoff {
|
||||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||||
}
|
}
|
||||||
|
if acceptCon {
|
||||||
|
t.Error("Connection should not be accepted when marked for handoff")
|
||||||
|
}
|
||||||
|
|
||||||
// Clean up
|
// Clean up
|
||||||
processor.GetPendingMap().Delete(conn)
|
processor.GetPendingMap().Delete(conn)
|
||||||
@@ -416,10 +422,13 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
|
|
||||||
// Test OnGet with pending handoff
|
// Test OnGet with pending handoff
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
err := processor.OnGet(ctx, conn, false)
|
acceptCon, err := processor.OnGet(ctx, conn, false)
|
||||||
if err != ErrConnectionMarkedForHandoff {
|
if err != ErrConnectionMarkedForHandoff {
|
||||||
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
|
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
|
||||||
}
|
}
|
||||||
|
if acceptCon {
|
||||||
|
t.Error("Should not accept connection with pending handoff")
|
||||||
|
}
|
||||||
|
|
||||||
// Test removing from pending map and clearing handoff state
|
// Test removing from pending map and clearing handoff state
|
||||||
processor.GetPendingMap().Delete(conn)
|
processor.GetPendingMap().Delete(conn)
|
||||||
@@ -432,10 +441,13 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
conn.SetUsable(true) // Make connection usable again
|
conn.SetUsable(true) // Make connection usable again
|
||||||
|
|
||||||
// Test OnGet without pending handoff
|
// Test OnGet without pending handoff
|
||||||
err = processor.OnGet(ctx, conn, false)
|
acceptCon, err = processor.OnGet(ctx, conn, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Should not return error for non-pending connection: %v", err)
|
t.Errorf("Should not return error for non-pending connection: %v", err)
|
||||||
}
|
}
|
||||||
|
if !acceptCon {
|
||||||
|
t.Error("Should accept connection without pending handoff")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
|
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
|
||||||
@@ -628,11 +640,15 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OnGet should succeed for usable connection
|
// OnGet should succeed for usable connection
|
||||||
err := processor.OnGet(ctx, conn, false)
|
acceptConn, err := processor.OnGet(ctx, conn, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("OnGet should succeed for usable connection: %v", err)
|
t.Errorf("OnGet should succeed for usable connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !acceptConn {
|
||||||
|
t.Error("Connection should be accepted when usable")
|
||||||
|
}
|
||||||
|
|
||||||
// Mark connection for handoff
|
// Mark connection for handoff
|
||||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||||
@@ -652,13 +668,17 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OnGet should fail for connection marked for handoff
|
// OnGet should fail for connection marked for handoff
|
||||||
err = processor.OnGet(ctx, conn, false)
|
acceptConn, err = processor.OnGet(ctx, conn, false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("OnGet should fail for connection marked for handoff")
|
t.Error("OnGet should fail for connection marked for handoff")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != ErrConnectionMarkedForHandoff {
|
if err != ErrConnectionMarkedForHandoff {
|
||||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||||
}
|
}
|
||||||
|
if acceptConn {
|
||||||
|
t.Error("Connection should not be accepted when marked for handoff")
|
||||||
|
}
|
||||||
|
|
||||||
// Process the connection to trigger handoff
|
// Process the connection to trigger handoff
|
||||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
@@ -678,11 +698,15 @@ func TestConnectionHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OnGet should succeed again
|
// OnGet should succeed again
|
||||||
err = processor.OnGet(ctx, conn, false)
|
acceptConn, err = processor.OnGet(ctx, conn, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("OnGet should succeed after handoff completion: %v", err)
|
t.Errorf("OnGet should succeed after handoff completion: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !acceptConn {
|
||||||
|
t.Error("Connection should be accepted after handoff completion")
|
||||||
|
}
|
||||||
|
|
||||||
t.Logf("Usable flag behavior test completed successfully")
|
t.Logf("Usable flag behavior test completed successfully")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user