From 4bc6d335b8729eedb2a8d433e0e5ade6c3ec2a6c Mon Sep 17 00:00:00 2001
From: Nedyalko Dyakov
Date: Fri, 17 Oct 2025 15:53:52 +0300
Subject: [PATCH] sync and async reauth based on conn lifecycle

---
 .../conn_reauth_credentials_listener.go |  19 ++++
 internal/auth/streaming/manager_test.go | 101 ++++++++++++++++++
 internal/auth/streaming/pool_hook.go    |  53 ++++++---
 3 files changed, 159 insertions(+), 14 deletions(-)
 create mode 100644 internal/auth/streaming/manager_test.go

diff --git a/internal/auth/streaming/conn_reauth_credentials_listener.go b/internal/auth/streaming/conn_reauth_credentials_listener.go
index 0f17818c..8bda93af 100644
--- a/internal/auth/streaming/conn_reauth_credentials_listener.go
+++ b/internal/auth/streaming/conn_reauth_credentials_listener.go
@@ -35,6 +35,25 @@ func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
 		return
 	}
 
+	// This connection is not in use, so we can re-authenticate it synchronously.
+	if c.conn.Used.CompareAndSwap(false, true) {
+		// Try to acquire the connection for a background operation.
+		if c.conn.Usable.CompareAndSwap(true, false) {
+			err := c.reAuth(c.conn, credentials)
+			if err != nil {
+				c.OnError(err)
+			}
+			c.conn.Usable.Store(true)
+			c.conn.Used.Store(false)
+			return
+		}
+		c.conn.Used.Store(false)
+	}
+	// Otherwise the connection is in use: mark it for re-authentication, and the
+	// pool hook will re-authenticate it when it is returned to the pool. If the
+	// connection was in the pool but a handoff is in progress, the pool hook
+	// re-authenticates it once the handoff completes and the connection is
+	// acquired from the pool again.
 	c.manager.MarkForReAuth(c.conn, func(err error) {
 		if err != nil {
 			c.OnError(err)
diff --git a/internal/auth/streaming/manager_test.go b/internal/auth/streaming/manager_test.go
new file mode 100644
index 00000000..e4ff813e
--- /dev/null
+++ b/internal/auth/streaming/manager_test.go
@@ -0,0 +1,101 @@
+package streaming
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/redis/go-redis/v9/auth"
+	"github.com/redis/go-redis/v9/internal/pool"
+)
+
+// Test that Listener returns the newly created listener, not nil
+func TestManager_Listener_ReturnsNewListener(t *testing.T) {
+	// Create a mock pool
+	mockPool := &mockPooler{}
+
+	// Create manager
+	manager := NewManager(mockPool, time.Second)
+
+	// Create a mock connection
+	conn := &pool.Conn{}
+
+	// Mock functions
+	reAuth := func(cn *pool.Conn, creds auth.Credentials) error {
+		return nil
+	}
+
+	onErr := func(cn *pool.Conn, err error) {
+	}
+
+	// Get listener - this should create a new one
+	listener, err := manager.Listener(conn, reAuth, onErr)
+
+	// Verify no error
+	if err != nil {
+		t.Fatalf("Expected no error, got: %v", err)
+	}
+
+	// Verify listener is not nil (this was the bug!)
+	if listener == nil {
+		t.Fatal("Expected listener to be non-nil, but got nil")
+	}
+
+	// Verify it's the correct type
+	if _, ok := listener.(*ConnReAuthCredentialsListener); !ok {
+		t.Fatalf("Expected listener to be *ConnReAuthCredentialsListener, got %T", listener)
+	}
+
+	// Get the same listener again - should return the existing one
+	listener2, err := manager.Listener(conn, reAuth, onErr)
+	if err != nil {
+		t.Fatalf("Expected no error on second call, got: %v", err)
+	}
+
+	if listener2 == nil {
+		t.Fatal("Expected listener2 to be non-nil")
+	}
+
+	// Should be the same instance
+	if listener != listener2 {
+		t.Error("Expected to get the same listener instance on second call")
+	}
+}
+
+// Test that Listener returns error when conn is nil
+func TestManager_Listener_NilConn(t *testing.T) {
+	mockPool := &mockPooler{}
+	manager := NewManager(mockPool, time.Second)
+
+	listener, err := manager.Listener(nil, nil, nil)
+
+	if err == nil {
+		t.Fatal("Expected error when conn is nil, got nil")
+	}
+
+	if listener != nil {
+		t.Error("Expected listener to be nil when error occurs")
+	}
+
+	expectedErr := "poolCn cannot be nil"
+	if err.Error() != expectedErr {
+		t.Errorf("Expected error message %q, got %q", expectedErr, err.Error())
+	}
+}
+
+// Mock pooler for testing
+type mockPooler struct{}
+
+func (m *mockPooler) NewConn(ctx context.Context) (*pool.Conn, error) { return nil, nil }
+func (m *mockPooler) CloseConn(*pool.Conn) error { return nil }
+func (m *mockPooler) Get(ctx context.Context) (*pool.Conn, error) { return nil, nil }
+func (m *mockPooler) Put(ctx context.Context, conn *pool.Conn) {}
+func (m *mockPooler) Remove(ctx context.Context, conn *pool.Conn, reason error) {}
+func (m *mockPooler) Len() int { return 0 }
+func (m *mockPooler) IdleLen() int { return 0 }
+func (m *mockPooler) Stats() *pool.Stats { return &pool.Stats{} }
+func (m *mockPooler) Size() int { return 10 }
+func (m *mockPooler) AddPoolHook(hook pool.PoolHook) {}
+func (m *mockPooler) RemovePoolHook(hook pool.PoolHook) {}
+func (m *mockPooler) Close() error { return nil }
+
diff --git a/internal/auth/streaming/pool_hook.go b/internal/auth/streaming/pool_hook.go
index 1b2813e8..7318589c 100644
--- a/internal/auth/streaming/pool_hook.go
+++ b/internal/auth/streaming/pool_hook.go
@@ -17,9 +17,15 @@ type ReAuthPoolHook struct {
 }
 
 func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
+	workers := make(chan struct{}, poolSize)
+	// Initialize the workers channel with tokens (semaphore pattern)
+	for i := 0; i < poolSize; i++ {
+		workers <- struct{}{}
+	}
+
 	return &ReAuthPoolHook{
 		shouldReAuth:  make(map[uint64]func(error)),
-		workers:       make(chan struct{}, poolSize),
+		workers:       workers,
 		reAuthTimeout: reAuthTimeout,
 	}
 }
 
@@ -31,44 +37,63 @@ func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
 	r.shouldReAuth[connID] = reAuthFn
 }
 
-func (r *ReAuthPoolHook) ShouldReAuth(connID uint64) bool {
-	r.lock.RLock()
-	defer r.lock.RUnlock()
-	_, ok := r.shouldReAuth[connID]
-	return ok
-}
-
 func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) {
 	r.lock.Lock()
 	defer r.lock.Unlock()
 	delete(r.shouldReAuth, connID)
 }
 
-func (r *ReAuthPoolHook) OnGet(_ context.Context, _ *pool.Conn, _ bool) error {
-	// noop
+func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) error {
+	// This connection was marked for reauth while in the pool,
+	// so we need to reauth it before returning it to the user.
+	r.lock.RLock()
+	reAuthFn, ok := r.shouldReAuth[conn.GetID()]
+	r.lock.RUnlock()
+	if ok {
+		// Clear the mark immediately to prevent duplicate reauth attempts
+		r.ClearReAuthMark(conn.GetID())
+		reAuthFn(nil)
+	}
 	return nil
 }
 
 func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
-	if reAuthFn, ok := r.shouldReAuth[conn.GetID()]; ok {
+	// Check if reauth is needed and get the function with proper locking
+	r.lock.RLock()
+	reAuthFn, ok := r.shouldReAuth[conn.GetID()]
+	r.lock.RUnlock()
+
+	if ok {
+		// Clear the mark immediately to prevent duplicate reauth attempts
+		r.ClearReAuthMark(conn.GetID())
+
 		go func() {
 			<-r.workers
+			defer func() {
+				r.workers <- struct{}{}
+			}()
+
 			var err error
 			timeout := time.After(r.reAuthTimeout)
+
+			// Try to acquire the connection (set Usable to false)
 			for !conn.Usable.CompareAndSwap(true, false) {
 				select {
 				case <-timeout:
+					// Timeout occurred, cannot acquire connection
 					err = pool.ErrConnUnusableTimeout
+					reAuthFn(err)
+					return
 				default:
 					time.Sleep(time.Millisecond)
-					// connection closed, cannot re-authenticate
 				}
 			}
-			reAuthFn(err)
+			// Successfully acquired the connection, perform reauth
+			reAuthFn(nil)
+			// Release the connection
 			conn.Usable.Store(true)
-			r.workers <- struct{}{}
 		}()
 	}
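
The OnNext change above takes a synchronous fast path: when the connection is idle it is claimed with two CompareAndSwap calls (Used, then Usable), re-authenticated inline, and released; only a busy connection falls back to MarkForReAuth. Below is a minimal standalone sketch of that decision, assuming a hypothetical conn type with used/usable flags and a tryInlineReAuth helper rather than the real pool.Conn API:

package main

import (
	"fmt"
	"sync/atomic"
)

// conn carries the two lifecycle flags the listener checks; this type is
// only for illustration and is not the go-redis pool.Conn.
type conn struct {
	used   atomic.Bool // true while a command is executing on the connection
	usable atomic.Bool // true while the connection may be taken for background work
}

// tryInlineReAuth mirrors the OnNext fast path: if the connection is idle,
// claim it with two CompareAndSwap calls, re-authenticate synchronously,
// and release it. It reports false when the caller should instead mark the
// connection for asynchronous re-auth via the pool hook.
func tryInlineReAuth(c *conn, reAuth func() error, onErr func(error)) bool {
	if c.used.CompareAndSwap(false, true) {
		if c.usable.CompareAndSwap(true, false) {
			if err := reAuth(); err != nil {
				onErr(err)
			}
			c.usable.Store(true)
			c.used.Store(false)
			return true
		}
		// Not usable right now (e.g. a handoff is in progress): undo the claim.
		c.used.Store(false)
	}
	return false
}

func main() {
	c := &conn{}
	c.usable.Store(true)

	ok := tryInlineReAuth(c,
		func() error { fmt.Println("re-authenticating inline"); return nil },
		func(err error) { fmt.Println("re-auth failed:", err) },
	)
	if !ok {
		fmt.Println("connection busy: would mark it for re-auth instead")
	}
}

The second CAS failing (Usable already false) corresponds to the handoff-in-progress case described in the patch comment, where the listener backs off and leaves the work to the pool hook.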
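
The OnPut path instead defers the work to a goroutine that first takes a token from the pre-filled workers channel (a counting semaphore sized to the pool) and then spins until it can flip Usable from true to false, giving up after reAuthTimeout. A self-contained sketch of that acquire-with-timeout pattern follows; the names conn, acquireForReAuth, and errAcquireTimeout are hypothetical stand-ins for the pool types and error:

package main

import (
	"errors"
	"fmt"
	"sync/atomic"
	"time"
)

// errAcquireTimeout is a stand-in for pool.ErrConnUnusableTimeout.
var errAcquireTimeout = errors.New("timed out waiting for connection to become usable")

// conn mimics the lifecycle flag the hook relies on.
type conn struct {
	usable atomic.Bool // true while the connection may be taken for background work
}

// acquireForReAuth spins (with a small sleep) until it flips usable from
// true to false, or gives up after reAuthTimeout - the same CompareAndSwap
// loop the OnPut worker goroutine uses.
func acquireForReAuth(c *conn, reAuthTimeout time.Duration, reAuth func(error)) {
	timeout := time.After(reAuthTimeout)
	for !c.usable.CompareAndSwap(true, false) {
		select {
		case <-timeout:
			reAuth(errAcquireTimeout) // report the failure to the callback
			return
		default:
			time.Sleep(time.Millisecond)
		}
	}
	reAuth(nil)          // connection is exclusively ours: safe to re-authenticate
	c.usable.Store(true) // release it back to normal use
}

func main() {
	poolSize := 2
	// Buffered channel used as a counting semaphore, pre-filled with one
	// token per pool slot, mirroring what NewReAuthPoolHook does.
	workers := make(chan struct{}, poolSize)
	for i := 0; i < poolSize; i++ {
		workers <- struct{}{}
	}

	c := &conn{}
	c.usable.Store(true)

	done := make(chan struct{})
	go func() {
		defer close(done)
		<-workers                                // take a worker token
		defer func() { workers <- struct{}{} }() // always give it back
		acquireForReAuth(c, 100*time.Millisecond, func(err error) {
			if err != nil {
				fmt.Println("re-auth skipped:", err)
				return
			}
			fmt.Println("re-auth performed while connection was held")
		})
	}()
	<-done
}

Pre-filling the channel in the constructor makes token acquisition a plain receive, and the deferred send in the patch guarantees the token is returned even when the goroutine exits early on timeout.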