1
0
mirror of https://github.com/redis/go-redis.git synced 2025-10-18 22:08:50 +03:00

fix nil listener

This commit is contained in:
Nedyalko Dyakov
2025-10-17 14:07:26 +03:00
parent 0e10cd7cd2
commit afba8c285f
4 changed files with 47 additions and 29 deletions

View File

@@ -59,19 +59,5 @@ func (c *ConnReAuthCredentialsListener) OnError(err error) {
c.onErr(c.conn, err)
}
// newConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener.
// Implements the auth.CredentialsListener interface.
func newConnReAuthCredentialsListener(
conn *pool.Conn,
reAuth func(conn *pool.Conn, credentials auth.Credentials) error,
onErr func(conn *pool.Conn, err error),
) *ConnReAuthCredentialsListener {
return &ConnReAuthCredentialsListener{
conn: conn,
reAuth: reAuth,
onErr: onErr,
}
}
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)

View File

@@ -1,6 +1,7 @@
package streaming
import (
"errors"
"time"
"github.com/redis/go-redis/v9/auth"
@@ -29,19 +30,24 @@ func (m *Manager) Listener(
poolCn *pool.Conn,
reAuth func(*pool.Conn, auth.Credentials) error,
onErr func(*pool.Conn, error),
) auth.CredentialsListener {
) (auth.CredentialsListener, error) {
if poolCn == nil {
return nil, errors.New("poolCn cannot be nil")
}
connID := poolCn.GetID()
listener, ok := m.credentialsListeners.Get(connID)
if !ok {
newCredListener := newConnReAuthCredentialsListener(
poolCn,
reAuth,
onErr,
)
newCredListener.manager = m
m.credentialsListeners.Add(connID, newCredListener)
if !ok || listener == nil {
newCredListener := &ConnReAuthCredentialsListener{
conn: poolCn,
reAuth: reAuth,
onErr: onErr,
manager: m,
}
return listener
m.credentialsListeners.Add(connID, newCredListener)
listener = newCredListener
}
return listener, nil
}
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {

View File

@@ -373,11 +373,14 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
username, password := "", ""
if c.opt.StreamingCredentialsProvider != nil {
credListener := c.streamingCredentialsManager.Listener(
credListener, err := c.streamingCredentialsManager.Listener(
cn,
c.reAuthConnection(),
c.onAuthenticationErr(),
)
if err != nil {
return fmt.Errorf("failed to create credentials listener: %w", err)
}
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
Subscribe(credListener)

View File

@@ -883,14 +883,36 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
return nil, nil, m.err
}
if listener == nil {
return nil, nil, errors.New("listener cannot be nil")
}
// Create a done channel to stop the goroutine
done := make(chan struct{})
// Start goroutine to handle updates
go func() {
for creds := range m.updates {
defer func() {
if r := recover(); r != nil {
// this is just a mock:
// allow panics to be caught without crashing
}
}()
for {
select {
case <-done:
return
case creds, ok := <-m.updates:
if !ok {
return
}
m.mu.Lock()
m.credentials = creds
m.mu.Unlock()
listener.OnNext(creds)
}
}
}()
m.mu.RLock()
@@ -904,6 +926,7 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
// allow multiple closes from multiple listeners
}
}()
close(done)
return
}, nil
}