1
0
mirror of https://github.com/redis/go-redis.git synced 2025-10-18 22:08:50 +03:00

fix nil listener

This commit is contained in:
Nedyalko Dyakov
2025-10-17 14:07:26 +03:00
parent 0e10cd7cd2
commit afba8c285f
4 changed files with 47 additions and 29 deletions

View File

@@ -59,19 +59,5 @@ func (c *ConnReAuthCredentialsListener) OnError(err error) {
c.onErr(c.conn, err) c.onErr(c.conn, err)
} }
// newConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener.
// Implements the auth.CredentialsListener interface.
func newConnReAuthCredentialsListener(
conn *pool.Conn,
reAuth func(conn *pool.Conn, credentials auth.Credentials) error,
onErr func(conn *pool.Conn, err error),
) *ConnReAuthCredentialsListener {
return &ConnReAuthCredentialsListener{
conn: conn,
reAuth: reAuth,
onErr: onErr,
}
}
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface. // Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil) var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)

View File

@@ -1,6 +1,7 @@
package streaming package streaming
import ( import (
"errors"
"time" "time"
"github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/auth"
@@ -29,19 +30,24 @@ func (m *Manager) Listener(
poolCn *pool.Conn, poolCn *pool.Conn,
reAuth func(*pool.Conn, auth.Credentials) error, reAuth func(*pool.Conn, auth.Credentials) error,
onErr func(*pool.Conn, error), onErr func(*pool.Conn, error),
) auth.CredentialsListener { ) (auth.CredentialsListener, error) {
if poolCn == nil {
return nil, errors.New("poolCn cannot be nil")
}
connID := poolCn.GetID() connID := poolCn.GetID()
listener, ok := m.credentialsListeners.Get(connID) listener, ok := m.credentialsListeners.Get(connID)
if !ok { if !ok || listener == nil {
newCredListener := newConnReAuthCredentialsListener( newCredListener := &ConnReAuthCredentialsListener{
poolCn, conn: poolCn,
reAuth, reAuth: reAuth,
onErr, onErr: onErr,
) manager: m,
newCredListener.manager = m
m.credentialsListeners.Add(connID, newCredListener)
} }
return listener
m.credentialsListeners.Add(connID, newCredListener)
listener = newCredListener
}
return listener, nil
} }
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) { func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {

View File

@@ -373,11 +373,14 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
username, password := "", "" username, password := "", ""
if c.opt.StreamingCredentialsProvider != nil { if c.opt.StreamingCredentialsProvider != nil {
credListener := c.streamingCredentialsManager.Listener( credListener, err := c.streamingCredentialsManager.Listener(
cn, cn,
c.reAuthConnection(), c.reAuthConnection(),
c.onAuthenticationErr(), c.onAuthenticationErr(),
) )
if err != nil {
return fmt.Errorf("failed to create credentials listener: %w", err)
}
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider. credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
Subscribe(credListener) Subscribe(credListener)

View File

@@ -883,14 +883,36 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
return nil, nil, m.err return nil, nil, m.err
} }
if listener == nil {
return nil, nil, errors.New("listener cannot be nil")
}
// Create a done channel to stop the goroutine
done := make(chan struct{})
// Start goroutine to handle updates // Start goroutine to handle updates
go func() { go func() {
for creds := range m.updates { defer func() {
if r := recover(); r != nil {
// this is just a mock:
// allow panics to be caught without crashing
}
}()
for {
select {
case <-done:
return
case creds, ok := <-m.updates:
if !ok {
return
}
m.mu.Lock() m.mu.Lock()
m.credentials = creds m.credentials = creds
m.mu.Unlock() m.mu.Unlock()
listener.OnNext(creds) listener.OnNext(creds)
} }
}
}() }()
m.mu.RLock() m.mu.RLock()
@@ -904,6 +926,7 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
// allow multiple closes from multiple listeners // allow multiple closes from multiple listeners
} }
}() }()
close(done)
return return
}, nil }, nil
} }