mirror of
https://github.com/redis/go-redis.git
synced 2025-10-20 09:52:25 +03:00
fix(pool): wip, pool reauth should not interfere with handoff
This commit is contained in:
87
auth/conn_reauth_credentials_listener.go
Normal file
87
auth/conn_reauth_credentials_listener.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// ConnReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
|
||||
// It is used to re-authenticate the credentials when they are updated.
|
||||
// It holds reference to the connection to re-authenticate and will pass it to the reAuth and onErr callbacks.
|
||||
// It contains:
|
||||
// - reAuth: a function that takes the new credentials and returns an error if any.
|
||||
// - onErr: a function that takes an error and handles it.
|
||||
// - conn: the connection to re-authenticate.
|
||||
type ConnReAuthCredentialsListener struct {
|
||||
reAuth func(conn *pool.Conn, credentials Credentials) error
|
||||
onErr func(conn *pool.Conn, err error)
|
||||
conn *pool.Conn
|
||||
}
|
||||
|
||||
// OnNext is called when the credentials are updated.
|
||||
// It calls the reAuth function with the new credentials.
|
||||
// If the reAuth function returns an error, it calls the onErr function with the error.
|
||||
func (c *ConnReAuthCredentialsListener) OnNext(credentials Credentials) {
|
||||
if c.conn.IsClosed() {
|
||||
return
|
||||
}
|
||||
|
||||
if c.reAuth == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
timeout := time.After(1 * time.Second)
|
||||
// wait for the connection to be usable
|
||||
// this is important because the connection pool may be in the process of reconnecting the connection
|
||||
// and we don't want to interfere with that process
|
||||
// but we also don't want to block for too long, so incorporate a timeout
|
||||
for {
|
||||
// we were able to mark the connection as unusable
|
||||
if c.conn.Usable.CompareAndSwap(true, false) {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
err = pool.ErrConnUnusableTimeout
|
||||
break
|
||||
default:
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
c.OnError(err)
|
||||
return
|
||||
}
|
||||
// we set the usable flag, so restore it back to usable after we're done
|
||||
defer c.conn.SetUsable(true)
|
||||
|
||||
err = c.reAuth(c.conn, credentials)
|
||||
if err != nil {
|
||||
c.OnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// OnError is called when an error occurs.
|
||||
// It can be called from both the credentials provider and the reAuth function.
|
||||
func (c *ConnReAuthCredentialsListener) OnError(err error) {
|
||||
if c.onErr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.onErr(c.conn, err)
|
||||
}
|
||||
|
||||
// NewConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener.
|
||||
// Implements the auth.CredentialsListener interface.
|
||||
func NewConnReAuthCredentialsListener(conn *pool.Conn, reAuth func(conn *pool.Conn, credentials Credentials) error, onErr func(conn *pool.Conn, err error)) *ConnReAuthCredentialsListener {
|
||||
return &ConnReAuthCredentialsListener{
|
||||
conn: conn,
|
||||
reAuth: reAuth,
|
||||
onErr: onErr,
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
|
||||
var _ CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
|
2
error.go
2
error.go
@@ -112,6 +112,8 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
|
||||
return false
|
||||
case context.Canceled, context.DeadlineExceeded:
|
||||
return true
|
||||
case pool.ErrConnUnusableTimeout:
|
||||
return true
|
||||
}
|
||||
|
||||
if isRedisError(err) {
|
||||
|
@@ -40,6 +40,9 @@ func generateConnID() uint64 {
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
// Connection identifier for unique tracking
|
||||
id uint64 // Unique numeric identifier for this connection
|
||||
|
||||
usedAt int64 // atomic
|
||||
|
||||
// Lock-free netConn access using atomic.Value
|
||||
@@ -54,7 +57,9 @@ type Conn struct {
|
||||
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
|
||||
readerMu sync.RWMutex
|
||||
|
||||
Usable atomic.Bool
|
||||
Inited atomic.Bool
|
||||
|
||||
pooled bool
|
||||
pubsub bool
|
||||
closed atomic.Bool
|
||||
@@ -75,11 +80,7 @@ type Conn struct {
|
||||
// Connection initialization function for reconnections
|
||||
initConnFunc func(context.Context, *Conn) error
|
||||
|
||||
// Connection identifier for unique tracking
|
||||
id uint64 // Unique numeric identifier for this connection
|
||||
|
||||
// Handoff state - using atomic operations for lock-free access
|
||||
usableAtomic atomic.Bool // Connection usability state
|
||||
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
|
||||
|
||||
// Atomic handoff state to prevent race conditions
|
||||
@@ -116,7 +117,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
|
||||
// Initialize atomic state
|
||||
cn.usableAtomic.Store(false) // false initially, set to true after initialization
|
||||
cn.Usable.Store(false) // false initially, set to true after initialization
|
||||
cn.handoffRetriesAtomic.Store(0) // 0 initially
|
||||
|
||||
// Initialize handoff state atomically
|
||||
@@ -162,12 +163,12 @@ func (cn *Conn) setNetConn(netConn net.Conn) {
|
||||
|
||||
// isUsable returns true if the connection is safe to use (lock-free).
|
||||
func (cn *Conn) isUsable() bool {
|
||||
return cn.usableAtomic.Load()
|
||||
return cn.Usable.Load()
|
||||
}
|
||||
|
||||
// setUsable sets the usable flag atomically (lock-free).
|
||||
func (cn *Conn) setUsable(usable bool) {
|
||||
cn.usableAtomic.Store(usable)
|
||||
cn.Usable.Store(usable)
|
||||
}
|
||||
|
||||
// getHandoffState returns the current handoff state atomically (lock-free).
|
||||
@@ -456,6 +457,12 @@ func (cn *Conn) MarkQueuedForHandoff() error {
|
||||
const baseDelay = time.Microsecond
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
// first we need to mark the connection as not usable
|
||||
// to prevent the pool from returning it to the caller
|
||||
if !cn.Usable.CompareAndSwap(true, false) {
|
||||
continue
|
||||
}
|
||||
|
||||
currentState := cn.getHandoffState()
|
||||
|
||||
// Check if marked for handoff
|
||||
@@ -472,7 +479,6 @@ func (cn *Conn) MarkQueuedForHandoff() error {
|
||||
|
||||
// Atomic compare-and-swap to update state
|
||||
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
||||
cn.setUsable(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -24,6 +24,9 @@ var (
|
||||
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
|
||||
ErrPoolTimeout = errors.New("redis: connection pool timeout")
|
||||
|
||||
// ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable.
|
||||
ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable")
|
||||
|
||||
// popAttempts is the maximum number of attempts to find a usable connection
|
||||
// when popping from the idle connection pool. This handles cases where connections
|
||||
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).
|
||||
|
60
redis.go
60
redis.go
@@ -224,6 +224,9 @@ type baseClient struct {
|
||||
// Maintenance notifications manager
|
||||
maintNotificationsManager *maintnotifications.Manager
|
||||
maintNotificationsManagerLock sync.RWMutex
|
||||
|
||||
credListeners map[uint64]auth.CredentialsListener
|
||||
credListenersLock sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *baseClient) clone() *baseClient {
|
||||
@@ -237,6 +240,7 @@ func (c *baseClient) clone() *baseClient {
|
||||
onClose: c.onClose,
|
||||
pushProcessor: c.pushProcessor,
|
||||
maintNotificationsManager: maintNotificationsManager,
|
||||
credListeners: c.credListeners,
|
||||
}
|
||||
return clone
|
||||
}
|
||||
@@ -296,18 +300,43 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
|
||||
return auth.NewReAuthCredentialsListener(
|
||||
c.reAuthConnection(poolCn),
|
||||
c.onAuthenticationErr(poolCn),
|
||||
// connReAuthCredentialsListener returns a credentials listener that can be used to re-authenticate the connection.
|
||||
// The credentials listener is stored in a map, so that it can be reused for multiple connections.
|
||||
// The credentials listener is removed from the map when the connection is closed.
|
||||
func (c *baseClient) connReAuthCredentialsListener(poolCn *pool.Conn) (auth.CredentialsListener, func()) {
|
||||
c.credListenersLock.RLock()
|
||||
credListener, ok := c.credListeners[poolCn.GetID()]
|
||||
c.credListenersLock.RUnlock()
|
||||
if ok {
|
||||
return credListener.(auth.CredentialsListener), func() {
|
||||
c.removeCredListener(poolCn)
|
||||
}
|
||||
}
|
||||
c.credListenersLock.Lock()
|
||||
defer c.credListenersLock.Unlock()
|
||||
newCredListener := auth.NewConnReAuthCredentialsListener(
|
||||
poolCn,
|
||||
c.reAuthConnection(),
|
||||
c.onAuthenticationErr(),
|
||||
)
|
||||
c.credListeners[poolCn.GetID()] = newCredListener
|
||||
return newCredListener, func() {
|
||||
c.removeCredListener(poolCn)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
|
||||
return func(credentials auth.Credentials) error {
|
||||
func (c *baseClient) removeCredListener(poolCn *pool.Conn) {
|
||||
c.credListenersLock.Lock()
|
||||
defer c.credListenersLock.Unlock()
|
||||
delete(c.credListeners, poolCn.GetID())
|
||||
}
|
||||
|
||||
func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error {
|
||||
return func(poolCn *pool.Conn, credentials auth.Credentials) error {
|
||||
var err error
|
||||
username, password := credentials.BasicAuth()
|
||||
ctx := context.Background()
|
||||
|
||||
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
|
||||
// hooksMixin are intentionally empty here
|
||||
cn := newConn(c.opt, connPool, nil)
|
||||
@@ -320,8 +349,8 @@ func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.C
|
||||
return err
|
||||
}
|
||||
}
|
||||
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
|
||||
return func(err error) {
|
||||
func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) {
|
||||
return func(poolCn *pool.Conn, err error) {
|
||||
if err != nil {
|
||||
if isBadConn(err, false, c.opt.Addr) {
|
||||
// Close the connection to force a reconnection.
|
||||
@@ -372,13 +401,20 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
|
||||
username, password := "", ""
|
||||
if c.opt.StreamingCredentialsProvider != nil {
|
||||
credListener, removeCredListener := c.connReAuthCredentialsListener(cn)
|
||||
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||
Subscribe(c.newReAuthCredentialsListener(cn))
|
||||
Subscribe(credListener)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
||||
}
|
||||
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
|
||||
cn.SetOnClose(unsubscribeFromCredentialsProvider)
|
||||
|
||||
unsubscribe := func() error {
|
||||
removeCredListener()
|
||||
return unsubscribeFromCredentialsProvider()
|
||||
}
|
||||
c.onClose = c.wrappedOnClose(unsubscribe)
|
||||
cn.SetOnClose(unsubscribe)
|
||||
|
||||
username, password = credentials.BasicAuth()
|
||||
} else if c.opt.CredentialsProviderContext != nil {
|
||||
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
||||
@@ -496,6 +532,8 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
}
|
||||
}
|
||||
|
||||
// mark the connection as usable and inited
|
||||
// once returned to the pool as idle, this connection can be used by other clients
|
||||
cn.SetUsable(true)
|
||||
cn.Inited.Store(true)
|
||||
|
||||
|
Reference in New Issue
Block a user