mirror of
https://github.com/redis/go-redis.git
synced 2025-10-18 22:08:50 +03:00
fix(pool): wip, pool reauth should not interfere with handoff
This commit is contained in:
87
auth/conn_reauth_credentials_listener.go
Normal file
87
auth/conn_reauth_credentials_listener.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9/internal/pool"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
|
||||||
|
// It is used to re-authenticate the credentials when they are updated.
|
||||||
|
// It holds reference to the connection to re-authenticate and will pass it to the reAuth and onErr callbacks.
|
||||||
|
// It contains:
|
||||||
|
// - reAuth: a function that takes the new credentials and returns an error if any.
|
||||||
|
// - onErr: a function that takes an error and handles it.
|
||||||
|
// - conn: the connection to re-authenticate.
|
||||||
|
type ConnReAuthCredentialsListener struct {
|
||||||
|
reAuth func(conn *pool.Conn, credentials Credentials) error
|
||||||
|
onErr func(conn *pool.Conn, err error)
|
||||||
|
conn *pool.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnNext is called when the credentials are updated.
|
||||||
|
// It calls the reAuth function with the new credentials.
|
||||||
|
// If the reAuth function returns an error, it calls the onErr function with the error.
|
||||||
|
func (c *ConnReAuthCredentialsListener) OnNext(credentials Credentials) {
|
||||||
|
if c.conn.IsClosed() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.reAuth == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
timeout := time.After(1 * time.Second)
|
||||||
|
// wait for the connection to be usable
|
||||||
|
// this is important because the connection pool may be in the process of reconnecting the connection
|
||||||
|
// and we don't want to interfere with that process
|
||||||
|
// but we also don't want to block for too long, so incorporate a timeout
|
||||||
|
for {
|
||||||
|
// we were able to mark the connection as unusable
|
||||||
|
if c.conn.Usable.CompareAndSwap(true, false) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
err = pool.ErrConnUnusableTimeout
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
c.OnError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// we set the usable flag, so restore it back to usable after we're done
|
||||||
|
defer c.conn.SetUsable(true)
|
||||||
|
|
||||||
|
err = c.reAuth(c.conn, credentials)
|
||||||
|
if err != nil {
|
||||||
|
c.OnError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnError is called when an error occurs.
|
||||||
|
// It can be called from both the credentials provider and the reAuth function.
|
||||||
|
func (c *ConnReAuthCredentialsListener) OnError(err error) {
|
||||||
|
if c.onErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.onErr(c.conn, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener.
|
||||||
|
// Implements the auth.CredentialsListener interface.
|
||||||
|
func NewConnReAuthCredentialsListener(conn *pool.Conn, reAuth func(conn *pool.Conn, credentials Credentials) error, onErr func(conn *pool.Conn, err error)) *ConnReAuthCredentialsListener {
|
||||||
|
return &ConnReAuthCredentialsListener{
|
||||||
|
conn: conn,
|
||||||
|
reAuth: reAuth,
|
||||||
|
onErr: onErr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
|
||||||
|
var _ CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
|
2
error.go
2
error.go
@@ -112,6 +112,8 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
|
|||||||
return false
|
return false
|
||||||
case context.Canceled, context.DeadlineExceeded:
|
case context.Canceled, context.DeadlineExceeded:
|
||||||
return true
|
return true
|
||||||
|
case pool.ErrConnUnusableTimeout:
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRedisError(err) {
|
if isRedisError(err) {
|
||||||
|
@@ -40,6 +40,9 @@ func generateConnID() uint64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
|
// Connection identifier for unique tracking
|
||||||
|
id uint64 // Unique numeric identifier for this connection
|
||||||
|
|
||||||
usedAt int64 // atomic
|
usedAt int64 // atomic
|
||||||
|
|
||||||
// Lock-free netConn access using atomic.Value
|
// Lock-free netConn access using atomic.Value
|
||||||
@@ -54,7 +57,9 @@ type Conn struct {
|
|||||||
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
|
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
|
||||||
readerMu sync.RWMutex
|
readerMu sync.RWMutex
|
||||||
|
|
||||||
Inited atomic.Bool
|
Usable atomic.Bool
|
||||||
|
Inited atomic.Bool
|
||||||
|
|
||||||
pooled bool
|
pooled bool
|
||||||
pubsub bool
|
pubsub bool
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
@@ -75,18 +80,14 @@ type Conn struct {
|
|||||||
// Connection initialization function for reconnections
|
// Connection initialization function for reconnections
|
||||||
initConnFunc func(context.Context, *Conn) error
|
initConnFunc func(context.Context, *Conn) error
|
||||||
|
|
||||||
// Connection identifier for unique tracking
|
|
||||||
id uint64 // Unique numeric identifier for this connection
|
|
||||||
|
|
||||||
// Handoff state - using atomic operations for lock-free access
|
// Handoff state - using atomic operations for lock-free access
|
||||||
usableAtomic atomic.Bool // Connection usability state
|
|
||||||
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
|
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
|
||||||
|
|
||||||
// Atomic handoff state to prevent race conditions
|
// Atomic handoff state to prevent race conditions
|
||||||
// Stores *HandoffState to ensure atomic updates of all handoff-related fields
|
// Stores *HandoffState to ensure atomic updates of all handoff-related fields
|
||||||
handoffStateAtomic atomic.Value // stores *HandoffState
|
handoffStateAtomic atomic.Value // stores *HandoffState
|
||||||
|
|
||||||
onClose func() error
|
onClose func() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(netConn net.Conn) *Conn {
|
func NewConn(netConn net.Conn) *Conn {
|
||||||
@@ -116,7 +117,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
|
|||||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||||
|
|
||||||
// Initialize atomic state
|
// Initialize atomic state
|
||||||
cn.usableAtomic.Store(false) // false initially, set to true after initialization
|
cn.Usable.Store(false) // false initially, set to true after initialization
|
||||||
cn.handoffRetriesAtomic.Store(0) // 0 initially
|
cn.handoffRetriesAtomic.Store(0) // 0 initially
|
||||||
|
|
||||||
// Initialize handoff state atomically
|
// Initialize handoff state atomically
|
||||||
@@ -162,12 +163,12 @@ func (cn *Conn) setNetConn(netConn net.Conn) {
|
|||||||
|
|
||||||
// isUsable returns true if the connection is safe to use (lock-free).
|
// isUsable returns true if the connection is safe to use (lock-free).
|
||||||
func (cn *Conn) isUsable() bool {
|
func (cn *Conn) isUsable() bool {
|
||||||
return cn.usableAtomic.Load()
|
return cn.Usable.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
// setUsable sets the usable flag atomically (lock-free).
|
// setUsable sets the usable flag atomically (lock-free).
|
||||||
func (cn *Conn) setUsable(usable bool) {
|
func (cn *Conn) setUsable(usable bool) {
|
||||||
cn.usableAtomic.Store(usable)
|
cn.Usable.Store(usable)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getHandoffState returns the current handoff state atomically (lock-free).
|
// getHandoffState returns the current handoff state atomically (lock-free).
|
||||||
@@ -456,6 +457,12 @@ func (cn *Conn) MarkQueuedForHandoff() error {
|
|||||||
const baseDelay = time.Microsecond
|
const baseDelay = time.Microsecond
|
||||||
|
|
||||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
// first we need to mark the connection as not usable
|
||||||
|
// to prevent the pool from returning it to the caller
|
||||||
|
if !cn.Usable.CompareAndSwap(true, false) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
currentState := cn.getHandoffState()
|
currentState := cn.getHandoffState()
|
||||||
|
|
||||||
// Check if marked for handoff
|
// Check if marked for handoff
|
||||||
@@ -472,7 +479,6 @@ func (cn *Conn) MarkQueuedForHandoff() error {
|
|||||||
|
|
||||||
// Atomic compare-and-swap to update state
|
// Atomic compare-and-swap to update state
|
||||||
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
||||||
cn.setUsable(false)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -24,6 +24,9 @@ var (
|
|||||||
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
|
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
|
||||||
ErrPoolTimeout = errors.New("redis: connection pool timeout")
|
ErrPoolTimeout = errors.New("redis: connection pool timeout")
|
||||||
|
|
||||||
|
// ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable.
|
||||||
|
ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable")
|
||||||
|
|
||||||
// popAttempts is the maximum number of attempts to find a usable connection
|
// popAttempts is the maximum number of attempts to find a usable connection
|
||||||
// when popping from the idle connection pool. This handles cases where connections
|
// when popping from the idle connection pool. This handles cases where connections
|
||||||
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).
|
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).
|
||||||
|
60
redis.go
60
redis.go
@@ -224,6 +224,9 @@ type baseClient struct {
|
|||||||
// Maintenance notifications manager
|
// Maintenance notifications manager
|
||||||
maintNotificationsManager *maintnotifications.Manager
|
maintNotificationsManager *maintnotifications.Manager
|
||||||
maintNotificationsManagerLock sync.RWMutex
|
maintNotificationsManagerLock sync.RWMutex
|
||||||
|
|
||||||
|
credListeners map[uint64]auth.CredentialsListener
|
||||||
|
credListenersLock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseClient) clone() *baseClient {
|
func (c *baseClient) clone() *baseClient {
|
||||||
@@ -237,6 +240,7 @@ func (c *baseClient) clone() *baseClient {
|
|||||||
onClose: c.onClose,
|
onClose: c.onClose,
|
||||||
pushProcessor: c.pushProcessor,
|
pushProcessor: c.pushProcessor,
|
||||||
maintNotificationsManager: maintNotificationsManager,
|
maintNotificationsManager: maintNotificationsManager,
|
||||||
|
credListeners: c.credListeners,
|
||||||
}
|
}
|
||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
@@ -296,18 +300,43 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
|||||||
return cn, nil
|
return cn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
|
// connReAuthCredentialsListener returns a credentials listener that can be used to re-authenticate the connection.
|
||||||
return auth.NewReAuthCredentialsListener(
|
// The credentials listener is stored in a map, so that it can be reused for multiple connections.
|
||||||
c.reAuthConnection(poolCn),
|
// The credentials listener is removed from the map when the connection is closed.
|
||||||
c.onAuthenticationErr(poolCn),
|
func (c *baseClient) connReAuthCredentialsListener(poolCn *pool.Conn) (auth.CredentialsListener, func()) {
|
||||||
|
c.credListenersLock.RLock()
|
||||||
|
credListener, ok := c.credListeners[poolCn.GetID()]
|
||||||
|
c.credListenersLock.RUnlock()
|
||||||
|
if ok {
|
||||||
|
return credListener.(auth.CredentialsListener), func() {
|
||||||
|
c.removeCredListener(poolCn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.credListenersLock.Lock()
|
||||||
|
defer c.credListenersLock.Unlock()
|
||||||
|
newCredListener := auth.NewConnReAuthCredentialsListener(
|
||||||
|
poolCn,
|
||||||
|
c.reAuthConnection(),
|
||||||
|
c.onAuthenticationErr(),
|
||||||
)
|
)
|
||||||
|
c.credListeners[poolCn.GetID()] = newCredListener
|
||||||
|
return newCredListener, func() {
|
||||||
|
c.removeCredListener(poolCn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
|
func (c *baseClient) removeCredListener(poolCn *pool.Conn) {
|
||||||
return func(credentials auth.Credentials) error {
|
c.credListenersLock.Lock()
|
||||||
|
defer c.credListenersLock.Unlock()
|
||||||
|
delete(c.credListeners, poolCn.GetID())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error {
|
||||||
|
return func(poolCn *pool.Conn, credentials auth.Credentials) error {
|
||||||
var err error
|
var err error
|
||||||
username, password := credentials.BasicAuth()
|
username, password := credentials.BasicAuth()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
|
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
|
||||||
// hooksMixin are intentionally empty here
|
// hooksMixin are intentionally empty here
|
||||||
cn := newConn(c.opt, connPool, nil)
|
cn := newConn(c.opt, connPool, nil)
|
||||||
@@ -320,8 +349,8 @@ func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.C
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
|
func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) {
|
||||||
return func(err error) {
|
return func(poolCn *pool.Conn, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isBadConn(err, false, c.opt.Addr) {
|
if isBadConn(err, false, c.opt.Addr) {
|
||||||
// Close the connection to force a reconnection.
|
// Close the connection to force a reconnection.
|
||||||
@@ -372,13 +401,20 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
|
|
||||||
username, password := "", ""
|
username, password := "", ""
|
||||||
if c.opt.StreamingCredentialsProvider != nil {
|
if c.opt.StreamingCredentialsProvider != nil {
|
||||||
|
credListener, removeCredListener := c.connReAuthCredentialsListener(cn)
|
||||||
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||||
Subscribe(c.newReAuthCredentialsListener(cn))
|
Subscribe(credListener)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
||||||
}
|
}
|
||||||
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
|
|
||||||
cn.SetOnClose(unsubscribeFromCredentialsProvider)
|
unsubscribe := func() error {
|
||||||
|
removeCredListener()
|
||||||
|
return unsubscribeFromCredentialsProvider()
|
||||||
|
}
|
||||||
|
c.onClose = c.wrappedOnClose(unsubscribe)
|
||||||
|
cn.SetOnClose(unsubscribe)
|
||||||
|
|
||||||
username, password = credentials.BasicAuth()
|
username, password = credentials.BasicAuth()
|
||||||
} else if c.opt.CredentialsProviderContext != nil {
|
} else if c.opt.CredentialsProviderContext != nil {
|
||||||
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
||||||
@@ -496,6 +532,8 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mark the connection as usable and inited
|
||||||
|
// once returned to the pool as idle, this connection can be used by other clients
|
||||||
cn.SetUsable(true)
|
cn.SetUsable(true)
|
||||||
cn.Inited.Store(true)
|
cn.Inited.Store(true)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user