mirror of
https://github.com/redis/go-redis.git
synced 2025-10-18 22:08:50 +03:00
fix nil listener
This commit is contained in:
@@ -59,19 +59,5 @@ func (c *ConnReAuthCredentialsListener) OnError(err error) {
|
||||
c.onErr(c.conn, err)
|
||||
}
|
||||
|
||||
// newConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener.
|
||||
// Implements the auth.CredentialsListener interface.
|
||||
func newConnReAuthCredentialsListener(
|
||||
conn *pool.Conn,
|
||||
reAuth func(conn *pool.Conn, credentials auth.Credentials) error,
|
||||
onErr func(conn *pool.Conn, err error),
|
||||
) *ConnReAuthCredentialsListener {
|
||||
return &ConnReAuthCredentialsListener{
|
||||
conn: conn,
|
||||
reAuth: reAuth,
|
||||
onErr: onErr,
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
|
||||
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
@@ -29,19 +30,24 @@ func (m *Manager) Listener(
|
||||
poolCn *pool.Conn,
|
||||
reAuth func(*pool.Conn, auth.Credentials) error,
|
||||
onErr func(*pool.Conn, error),
|
||||
) auth.CredentialsListener {
|
||||
) (auth.CredentialsListener, error) {
|
||||
if poolCn == nil {
|
||||
return nil, errors.New("poolCn cannot be nil")
|
||||
}
|
||||
connID := poolCn.GetID()
|
||||
listener, ok := m.credentialsListeners.Get(connID)
|
||||
if !ok {
|
||||
newCredListener := newConnReAuthCredentialsListener(
|
||||
poolCn,
|
||||
reAuth,
|
||||
onErr,
|
||||
)
|
||||
newCredListener.manager = m
|
||||
m.credentialsListeners.Add(connID, newCredListener)
|
||||
if !ok || listener == nil {
|
||||
newCredListener := &ConnReAuthCredentialsListener{
|
||||
conn: poolCn,
|
||||
reAuth: reAuth,
|
||||
onErr: onErr,
|
||||
manager: m,
|
||||
}
|
||||
return listener
|
||||
|
||||
m.credentialsListeners.Add(connID, newCredListener)
|
||||
listener = newCredListener
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
|
||||
|
5
redis.go
5
redis.go
@@ -373,11 +373,14 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
|
||||
username, password := "", ""
|
||||
if c.opt.StreamingCredentialsProvider != nil {
|
||||
credListener := c.streamingCredentialsManager.Listener(
|
||||
credListener, err := c.streamingCredentialsManager.Listener(
|
||||
cn,
|
||||
c.reAuthConnection(),
|
||||
c.onAuthenticationErr(),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create credentials listener: %w", err)
|
||||
}
|
||||
|
||||
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||
Subscribe(credListener)
|
||||
|
@@ -883,14 +883,36 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
|
||||
return nil, nil, m.err
|
||||
}
|
||||
|
||||
if listener == nil {
|
||||
return nil, nil, errors.New("listener cannot be nil")
|
||||
}
|
||||
|
||||
// Create a done channel to stop the goroutine
|
||||
done := make(chan struct{})
|
||||
|
||||
// Start goroutine to handle updates
|
||||
go func() {
|
||||
for creds := range m.updates {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// this is just a mock:
|
||||
// allow panics to be caught without crashing
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case creds, ok := <-m.updates:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.credentials = creds
|
||||
m.mu.Unlock()
|
||||
listener.OnNext(creds)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
m.mu.RLock()
|
||||
@@ -904,6 +926,7 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
|
||||
// allow multiple closes from multiple listeners
|
||||
}
|
||||
}()
|
||||
close(done)
|
||||
return
|
||||
}, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user