diff --git a/internal/auth/streaming/conn_reauth_credentials_listener.go b/internal/auth/streaming/conn_reauth_credentials_listener.go
index d0ac8a84..68274ef9 100644
--- a/internal/auth/streaming/conn_reauth_credentials_listener.go
+++ b/internal/auth/streaming/conn_reauth_credentials_listener.go
@@ -27,11 +27,7 @@ type ConnReAuthCredentialsListener struct {
 // It calls the reAuth function with the new credentials.
 // If the reAuth function returns an error, it calls the onErr function with the error.
 func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
-	if c.conn.IsClosed() {
-		return
-	}
-
-	if c.reAuth == nil {
+	if c.conn == nil || c.conn.IsClosed() || c.manager == nil || c.reAuth == nil {
 		return
 	}
 
@@ -41,17 +37,20 @@ func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
 	// The connection pool hook will re-authenticate the connection when it is
 	// returned to the pool in a clean, idle state.
 	c.manager.MarkForReAuth(c.conn, func(err error) {
+		// err is from connection acquisition (timeout, etc.)
 		if err != nil {
+			// Log the error
			c.OnError(err)
 			return
 		}
+		// err is from reauth command execution
 		err = c.reAuth(c.conn, credentials)
 		if err != nil {
+			// Log the error
 			c.OnError(err)
 			return
 		}
 	})
-
 }
 
 // OnError is called when an error occurs.
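Note on the OnNext change: the old code dereferenced c.conn before ever checking it and guarded reAuth in a separate branch; the new single guard also covers a nil conn and a nil manager, so a stale or half-wired listener ignores a credentials update instead of panicking. A minimal, self-contained sketch of that guard-clause pattern, with stub types standing in for the real pool.Conn, Manager, and auth.Credentials:

    package main

    import (
    	"errors"
    	"fmt"
    )

    type conn struct{ closed bool }

    func (c *conn) IsClosed() bool { return c.closed }

    type manager struct{}

    type reAuthListener struct {
    	conn    *conn
    	manager *manager
    	reAuth  func(c *conn, creds string) error
    	lastErr error
    }

    // OnNext bails out unless every collaborator is present and the
    // connection is still open - the same shape as the patched guard.
    func (l *reAuthListener) OnNext(creds string) {
    	if l.conn == nil || l.conn.IsClosed() || l.manager == nil || l.reAuth == nil {
    		return
    	}
    	if err := l.reAuth(l.conn, creds); err != nil {
    		l.lastErr = err // stands in for OnError
    	}
    }

    func main() {
    	l := &reAuthListener{}     // nothing wired up yet
    	l.OnNext("default:secret") // safe no-op rather than a nil-pointer panic
    	l.conn = &conn{}
    	l.manager = &manager{}
    	l.reAuth = func(*conn, string) error { return errors.New("AUTH failed") }
    	l.OnNext("default:secret")
    	fmt.Println(l.lastErr) // AUTH failed
    }

The new comments in the callback also make the two failure points explicit: the error handed to the callback comes from acquiring the connection (for example a timeout), while the error from c.reAuth comes from running the re-auth command itself; both funnel into OnError.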
diff --git a/internal/auth/streaming/manager.go b/internal/auth/streaming/manager.go
index 3f529d15..375bf994 100644
--- a/internal/auth/streaming/manager.go
+++ b/internal/auth/streaming/manager.go
@@ -15,11 +15,13 @@ type Manager struct {
 }
 
 func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager {
-	return &Manager{
+	m := &Manager{
 		pool:                 pl,
 		poolHookRef:          NewReAuthPoolHook(pl.Size(), reAuthTimeout),
 		credentialsListeners: NewCredentialsListeners(),
 	}
+	m.poolHookRef.manager = m
+	return m
 }
 
 func (m *Manager) PoolHook() pool.PoolHook {
@@ -35,6 +37,10 @@ func (m *Manager) Listener(
 		return nil, errors.New("poolCn cannot be nil")
 	}
 	connID := poolCn.GetID()
+	// if we reconnect the underlying network connection, the streaming credentials listener will continue to work,
+	// so we can get the old listener from the cache and use it.
+
+	// subscribing the same (an already subscribed) listener for a StreamingCredentialsProvider SHOULD be a no-op
 	listener, ok := m.credentialsListeners.Get(connID)
 	if !ok || listener == nil {
 		newCredListener := &ConnReAuthCredentialsListener{
@@ -54,3 +60,7 @@ func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
 	connID := poolCn.GetID()
 	m.poolHookRef.MarkForReAuth(connID, reAuthFn)
 }
+
+func (m *Manager) RemoveListener(connID uint64) {
+	m.credentialsListeners.Remove(connID)
+}
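Note on the Listener change: caching listeners by connection ID is what makes the two new comments hold - a reconnect of the underlying network connection keeps the same ID and therefore the same listener, and subscribing an already-subscribed connection hands back the cached listener instead of creating a second one. RemoveListener adds the matching cleanup path. A sketch of the assumed get-or-create semantics, using a plain mutex-guarded map rather than the real CredentialsListeners type:

    package main

    import (
    	"fmt"
    	"sync"
    )

    type listener struct{ connID uint64 }

    type listeners struct {
    	mu sync.RWMutex
    	m  map[uint64]*listener
    }

    func newListeners() *listeners { return &listeners{m: make(map[uint64]*listener)} }

    // getOrCreate returns the cached listener for connID if one exists, so
    // subscribing the same connection twice yields the same listener; only
    // a miss creates and caches a new one.
    func (ls *listeners) getOrCreate(connID uint64) *listener {
    	ls.mu.RLock()
    	l, ok := ls.m[connID]
    	ls.mu.RUnlock()
    	if ok && l != nil {
    		return l
    	}
    	ls.mu.Lock()
    	defer ls.mu.Unlock()
    	if l, ok := ls.m[connID]; ok && l != nil { // re-check under the write lock
    		return l
    	}
    	l = &listener{connID: connID}
    	ls.m[connID] = l
    	return l
    }

    // remove drops the cached listener, mirroring Manager.RemoveListener.
    func (ls *listeners) remove(connID uint64) {
    	ls.mu.Lock()
    	defer ls.mu.Unlock()
    	delete(ls.m, connID)
    }

    func main() {
    	ls := newListeners()
    	a := ls.getOrCreate(42)
    	b := ls.getOrCreate(42)
    	fmt.Println("same listener:", a == b) // true: second subscribe is a no-op
    	ls.remove(42)
    	fmt.Println("recreated:", ls.getOrCreate(42) != a) // true
    }

The re-check under the write lock is the usual double-checked pattern: two goroutines subscribing the same connection concurrently still end up sharing a single listener.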
diff --git a/internal/auth/streaming/pool_hook.go b/internal/auth/streaming/pool_hook.go
index 46751ee2..db99cb17 100644
--- a/internal/auth/streaming/pool_hook.go
+++ b/internal/auth/streaming/pool_hook.go
@@ -5,6 +5,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/redis/go-redis/v9/internal"
 	"github.com/redis/go-redis/v9/internal/pool"
 )
 
@@ -17,6 +18,9 @@ type ReAuthPoolHook struct {
 	// conn id -> bool
 	scheduledReAuth map[uint64]bool
 	scheduledLock   sync.RWMutex
+
+	// for cleanup
+	manager *Manager
 }
 
 func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
@@ -32,7 +36,6 @@ func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHoo
 		workers:       workers,
 		reAuthTimeout: reAuthTimeout,
 	}
-
 }
 
 func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
@@ -41,27 +44,22 @@ func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
 	r.shouldReAuth[connID] = reAuthFn
 }
 
-func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) {
-	r.shouldReAuthLock.Lock()
-	defer r.shouldReAuthLock.Unlock()
-	delete(r.shouldReAuth, connID)
-}
-
 func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
+	connID := conn.GetID()
 	r.shouldReAuthLock.RLock()
-	_, ok := r.shouldReAuth[conn.GetID()]
+	_, shouldReAuth := r.shouldReAuth[connID]
 	r.shouldReAuthLock.RUnlock()
 	// This connection was marked for reauth while in the pool,
 	// reject the connection
-	if ok {
+	if shouldReAuth {
 		// simply reject the connection, it will be re-authenticated in OnPut
 		return false, nil
 	}
 	r.scheduledLock.RLock()
-	hasScheduled, ok := r.scheduledReAuth[conn.GetID()]
+	_, hasScheduled := r.scheduledReAuth[connID]
 	r.scheduledLock.RUnlock()
 	// has scheduled reauth, reject the connection
-	if ok && hasScheduled {
+	if hasScheduled {
 		// simply reject the connection, it currently has a reauth scheduled
 		// and the worker is waiting for slot to execute the reauth
 		return false, nil
@@ -70,22 +68,38 @@
 }
 
 func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
+	if conn == nil {
+		// noop
+		return true, false, nil
+	}
+	connID := conn.GetID()
+
 	// Check if reauth is needed and get the function with proper locking
 	r.shouldReAuthLock.RLock()
-	reAuthFn, ok := r.shouldReAuth[conn.GetID()]
+	reAuthFn, ok := r.shouldReAuth[connID]
 	r.shouldReAuthLock.RUnlock()
 	if ok {
+		r.shouldReAuthLock.Lock()
 		r.scheduledLock.Lock()
-		r.scheduledReAuth[conn.GetID()] = true
+		r.scheduledReAuth[connID] = true
+		delete(r.shouldReAuth, connID)
 		r.scheduledLock.Unlock()
-		// Clear the mark immediately to prevent duplicate reauth attempts
-		r.ClearReAuthMark(conn.GetID())
+		r.shouldReAuthLock.Unlock()
 		go func() {
 			<-r.workers
+			// safety first
+			if conn == nil || (conn != nil && conn.IsClosed()) {
+				r.workers <- struct{}{}
+				return
+			}
 			defer func() {
+				if rec := recover(); rec != nil {
+					// once again - safety first
+					internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec)
+				}
 				r.scheduledLock.Lock()
-				delete(r.scheduledReAuth, conn.GetID())
+				delete(r.scheduledReAuth, connID)
 				r.scheduledLock.Unlock()
 				r.workers <- struct{}{}
 			}()
@@ -96,7 +110,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
 			// Try to acquire the connection
 			// We need to ensure the connection is both Usable and not Used
 			// to prevent data races with concurrent operations
-			const baseDelay = time.Microsecond
+			const baseDelay = 10 * time.Microsecond
 			acquired := false
 			attempt := 0
 			for !acquired {
@@ -108,36 +122,33 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
 					return
 				default:
 					// Try to acquire: set Usable=false, then check Used
-					if conn.Usable.CompareAndSwap(true, false) {
-						if !conn.Used.Load() {
+					if conn.CompareAndSwapUsable(true, false) {
+						if !conn.IsUsed() {
 							acquired = true
 						} else {
 							// Release Usable and retry with exponential backoff
-							conn.Usable.Store(true)
-							if attempt > 0 {
-								// Exponential backoff: 1, 2, 4, 8... up to 512 microseconds
-								delay := baseDelay * time.Duration(1<<min(attempt-1, 9))
-								time.Sleep(delay)
-							}
-							attempt++
 						}
 					} else {
-						if attempt > 0 {
-							// Exponential backoff: 1, 2, 4, 8... up to 512 microseconds
-							delay := baseDelay * time.Duration(1<<min(attempt-1, 9))
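Note on the acquire loop at the tail of OnPut (the hunk above is cut off mid-diff in the source): before the worker may run the re-auth command it must own the connection, meaning the Usable flag was flipped by CAS and the Used flag is clear; otherwise it backs off exponentially until the re-auth timeout fires. A self-contained sketch of that loop; the atomic flag layout, the shift cap of 9, and the use of the built-in min are assumptions rather than the exact go-redis internals:

    package main

    import (
    	"context"
    	"fmt"
    	"sync/atomic"
    	"time"
    )

    // conn mimics the usable/used flags the pool connection exposes.
    type conn struct {
    	usable atomic.Bool
    	used   atomic.Bool
    }

    // acquire spins until it owns the connection (usable flipped to false,
    // not currently used) or the context times out.
    func acquire(ctx context.Context, c *conn) error {
    	const baseDelay = 10 * time.Microsecond
    	attempt := 0
    	for {
    		select {
    		case <-ctx.Done():
    			return ctx.Err() // reported to the reAuth callback as an acquisition error
    		default:
    			if c.usable.CompareAndSwap(true, false) {
    				if !c.used.Load() {
    					return nil // acquired: usable=false and not in use
    				}
    				c.usable.Store(true) // still in use: release and retry
    			}
    			// Exponential backoff, capped at 512*baseDelay.
    			time.Sleep(baseDelay * time.Duration(1<<min(attempt, 9)))
    			attempt++
    		}
    	}
    }

    func main() {
    	c := &conn{}
    	c.usable.Store(true)
    	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    	defer cancel()
    	fmt.Println("acquired:", acquire(ctx, c) == nil) // true
    }

Rejecting marked or scheduled connections in OnGet (earlier in the diff) is what keeps this loop short in practice: no new caller can grab the connection while the worker is waiting, so the backoff normally terminates within a few iterations.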