mirror of
https://github.com/redis/go-redis.git
synced 2025-10-20 09:52:25 +03:00
sync and async reauth based on conn lifecycle
This commit is contained in:
@@ -35,6 +35,25 @@ func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
|
||||
return
|
||||
}
|
||||
|
||||
// this connection is not in use, so we can re-authenticate it
|
||||
if c.conn.Used.CompareAndSwap(false, true) {
|
||||
// try to acquire the connection for background operation
|
||||
if c.conn.Usable.CompareAndSwap(true, false) {
|
||||
err := c.reAuth(c.conn, credentials)
|
||||
if err != nil {
|
||||
c.OnError(err)
|
||||
}
|
||||
c.conn.Usable.Store(true)
|
||||
c.conn.Used.Store(false)
|
||||
return
|
||||
}
|
||||
c.conn.Used.Store(false)
|
||||
}
|
||||
// else if the connection is in use, mark it for re-authentication
|
||||
// and connection pool hook will re-authenticate it when it is returned to the pool
|
||||
// or in case the connection WAS in the pool, but handoff is in progress, the pool hook
|
||||
// will re-authenticate it when the handoff is complete
|
||||
// and the connection is acquired from the pool
|
||||
c.manager.MarkForReAuth(c.conn, func(err error) {
|
||||
if err != nil {
|
||||
c.OnError(err)
|
||||
|
101
internal/auth/streaming/manager_test.go
Normal file
101
internal/auth/streaming/manager_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// Test that Listener returns the newly created listener, not nil
|
||||
func TestManager_Listener_ReturnsNewListener(t *testing.T) {
|
||||
// Create a mock pool
|
||||
mockPool := &mockPooler{}
|
||||
|
||||
// Create manager
|
||||
manager := NewManager(mockPool, time.Second)
|
||||
|
||||
// Create a mock connection
|
||||
conn := &pool.Conn{}
|
||||
|
||||
// Mock functions
|
||||
reAuth := func(cn *pool.Conn, creds auth.Credentials) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
onErr := func(cn *pool.Conn, err error) {
|
||||
}
|
||||
|
||||
// Get listener - this should create a new one
|
||||
listener, err := manager.Listener(conn, reAuth, onErr)
|
||||
|
||||
// Verify no error
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify listener is not nil (this was the bug!)
|
||||
if listener == nil {
|
||||
t.Fatal("Expected listener to be non-nil, but got nil")
|
||||
}
|
||||
|
||||
// Verify it's the correct type
|
||||
if _, ok := listener.(*ConnReAuthCredentialsListener); !ok {
|
||||
t.Fatalf("Expected listener to be *ConnReAuthCredentialsListener, got %T", listener)
|
||||
}
|
||||
|
||||
// Get the same listener again - should return the existing one
|
||||
listener2, err := manager.Listener(conn, reAuth, onErr)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error on second call, got: %v", err)
|
||||
}
|
||||
|
||||
if listener2 == nil {
|
||||
t.Fatal("Expected listener2 to be non-nil")
|
||||
}
|
||||
|
||||
// Should be the same instance
|
||||
if listener != listener2 {
|
||||
t.Error("Expected to get the same listener instance on second call")
|
||||
}
|
||||
}
|
||||
|
||||
// Test that Listener returns error when conn is nil
|
||||
func TestManager_Listener_NilConn(t *testing.T) {
|
||||
mockPool := &mockPooler{}
|
||||
manager := NewManager(mockPool, time.Second)
|
||||
|
||||
listener, err := manager.Listener(nil, nil, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when conn is nil, got nil")
|
||||
}
|
||||
|
||||
if listener != nil {
|
||||
t.Error("Expected listener to be nil when error occurs")
|
||||
}
|
||||
|
||||
expectedErr := "poolCn cannot be nil"
|
||||
if err.Error() != expectedErr {
|
||||
t.Errorf("Expected error message %q, got %q", expectedErr, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Mock pooler for testing
|
||||
type mockPooler struct{}
|
||||
|
||||
func (m *mockPooler) NewConn(ctx context.Context) (*pool.Conn, error) { return nil, nil }
|
||||
func (m *mockPooler) CloseConn(*pool.Conn) error { return nil }
|
||||
func (m *mockPooler) Get(ctx context.Context) (*pool.Conn, error) { return nil, nil }
|
||||
func (m *mockPooler) Put(ctx context.Context, conn *pool.Conn) {}
|
||||
func (m *mockPooler) Remove(ctx context.Context, conn *pool.Conn, reason error) {}
|
||||
func (m *mockPooler) Len() int { return 0 }
|
||||
func (m *mockPooler) IdleLen() int { return 0 }
|
||||
func (m *mockPooler) Stats() *pool.Stats { return &pool.Stats{} }
|
||||
func (m *mockPooler) Size() int { return 10 }
|
||||
func (m *mockPooler) AddPoolHook(hook pool.PoolHook) {}
|
||||
func (m *mockPooler) RemovePoolHook(hook pool.PoolHook) {}
|
||||
func (m *mockPooler) Close() error { return nil }
|
||||
|
@@ -17,9 +17,15 @@ type ReAuthPoolHook struct {
|
||||
}
|
||||
|
||||
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
|
||||
workers := make(chan struct{}, poolSize)
|
||||
// Initialize the workers channel with tokens (semaphore pattern)
|
||||
for i := 0; i < poolSize; i++ {
|
||||
workers <- struct{}{}
|
||||
}
|
||||
|
||||
return &ReAuthPoolHook{
|
||||
shouldReAuth: make(map[uint64]func(error)),
|
||||
workers: make(chan struct{}, poolSize),
|
||||
workers: workers,
|
||||
reAuthTimeout: reAuthTimeout,
|
||||
}
|
||||
|
||||
@@ -31,44 +37,63 @@ func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
|
||||
r.shouldReAuth[connID] = reAuthFn
|
||||
}
|
||||
|
||||
func (r *ReAuthPoolHook) ShouldReAuth(connID uint64) bool {
|
||||
r.lock.RLock()
|
||||
defer r.lock.RUnlock()
|
||||
_, ok := r.shouldReAuth[connID]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
delete(r.shouldReAuth, connID)
|
||||
}
|
||||
|
||||
func (r *ReAuthPoolHook) OnGet(_ context.Context, _ *pool.Conn, _ bool) error {
|
||||
// noop
|
||||
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) error {
|
||||
// This connection was marked for reauth while in the pool,
|
||||
// so we need to reauth it before returning it to the user.
|
||||
r.lock.RLock()
|
||||
reAuthFn, ok := r.shouldReAuth[conn.GetID()]
|
||||
r.lock.RUnlock()
|
||||
if ok {
|
||||
// Clear the mark immediately to prevent duplicate reauth attempts
|
||||
r.ClearReAuthMark(conn.GetID())
|
||||
reAuthFn(nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
|
||||
if reAuthFn, ok := r.shouldReAuth[conn.GetID()]; ok {
|
||||
// Check if reauth is needed and get the function with proper locking
|
||||
r.lock.RLock()
|
||||
reAuthFn, ok := r.shouldReAuth[conn.GetID()]
|
||||
r.lock.RUnlock()
|
||||
|
||||
if ok {
|
||||
// Clear the mark immediately to prevent duplicate reauth attempts
|
||||
r.ClearReAuthMark(conn.GetID())
|
||||
|
||||
go func() {
|
||||
<-r.workers
|
||||
defer func() {
|
||||
r.workers <- struct{}{}
|
||||
}()
|
||||
|
||||
var err error
|
||||
timeout := time.After(r.reAuthTimeout)
|
||||
|
||||
// Try to acquire the connection (set Usable to false)
|
||||
for !conn.Usable.CompareAndSwap(true, false) {
|
||||
select {
|
||||
case <-timeout:
|
||||
// Timeout occurred, cannot acquire connection
|
||||
err = pool.ErrConnUnusableTimeout
|
||||
reAuthFn(err)
|
||||
return
|
||||
default:
|
||||
time.Sleep(time.Millisecond)
|
||||
// connection closed, cannot re-authenticate
|
||||
}
|
||||
}
|
||||
|
||||
reAuthFn(err)
|
||||
// Successfully acquired the connection, perform reauth
|
||||
reAuthFn(nil)
|
||||
|
||||
// Release the connection
|
||||
conn.Usable.Store(true)
|
||||
r.workers <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user