mirror of
https://github.com/redis/go-redis.git
synced 2025-10-18 22:08:50 +03:00
fix nil listener
This commit is contained in:
@@ -59,19 +59,5 @@ func (c *ConnReAuthCredentialsListener) OnError(err error) {
|
|||||||
c.onErr(c.conn, err)
|
c.onErr(c.conn, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener.
|
|
||||||
// Implements the auth.CredentialsListener interface.
|
|
||||||
func newConnReAuthCredentialsListener(
|
|
||||||
conn *pool.Conn,
|
|
||||||
reAuth func(conn *pool.Conn, credentials auth.Credentials) error,
|
|
||||||
onErr func(conn *pool.Conn, err error),
|
|
||||||
) *ConnReAuthCredentialsListener {
|
|
||||||
return &ConnReAuthCredentialsListener{
|
|
||||||
conn: conn,
|
|
||||||
reAuth: reAuth,
|
|
||||||
onErr: onErr,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
|
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
|
||||||
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
|
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package streaming
|
package streaming
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9/auth"
|
"github.com/redis/go-redis/v9/auth"
|
||||||
@@ -29,19 +30,24 @@ func (m *Manager) Listener(
|
|||||||
poolCn *pool.Conn,
|
poolCn *pool.Conn,
|
||||||
reAuth func(*pool.Conn, auth.Credentials) error,
|
reAuth func(*pool.Conn, auth.Credentials) error,
|
||||||
onErr func(*pool.Conn, error),
|
onErr func(*pool.Conn, error),
|
||||||
) auth.CredentialsListener {
|
) (auth.CredentialsListener, error) {
|
||||||
|
if poolCn == nil {
|
||||||
|
return nil, errors.New("poolCn cannot be nil")
|
||||||
|
}
|
||||||
connID := poolCn.GetID()
|
connID := poolCn.GetID()
|
||||||
listener, ok := m.credentialsListeners.Get(connID)
|
listener, ok := m.credentialsListeners.Get(connID)
|
||||||
if !ok {
|
if !ok || listener == nil {
|
||||||
newCredListener := newConnReAuthCredentialsListener(
|
newCredListener := &ConnReAuthCredentialsListener{
|
||||||
poolCn,
|
conn: poolCn,
|
||||||
reAuth,
|
reAuth: reAuth,
|
||||||
onErr,
|
onErr: onErr,
|
||||||
)
|
manager: m,
|
||||||
newCredListener.manager = m
|
|
||||||
m.credentialsListeners.Add(connID, newCredListener)
|
|
||||||
}
|
}
|
||||||
return listener
|
|
||||||
|
m.credentialsListeners.Add(connID, newCredListener)
|
||||||
|
listener = newCredListener
|
||||||
|
}
|
||||||
|
return listener, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
|
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
|
||||||
|
5
redis.go
5
redis.go
@@ -373,11 +373,14 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
|
|
||||||
username, password := "", ""
|
username, password := "", ""
|
||||||
if c.opt.StreamingCredentialsProvider != nil {
|
if c.opt.StreamingCredentialsProvider != nil {
|
||||||
credListener := c.streamingCredentialsManager.Listener(
|
credListener, err := c.streamingCredentialsManager.Listener(
|
||||||
cn,
|
cn,
|
||||||
c.reAuthConnection(),
|
c.reAuthConnection(),
|
||||||
c.onAuthenticationErr(),
|
c.onAuthenticationErr(),
|
||||||
)
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create credentials listener: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||||
Subscribe(credListener)
|
Subscribe(credListener)
|
||||||
|
@@ -883,14 +883,36 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
|
|||||||
return nil, nil, m.err
|
return nil, nil, m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if listener == nil {
|
||||||
|
return nil, nil, errors.New("listener cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a done channel to stop the goroutine
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
// Start goroutine to handle updates
|
// Start goroutine to handle updates
|
||||||
go func() {
|
go func() {
|
||||||
for creds := range m.updates {
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
// this is just a mock:
|
||||||
|
// allow panics to be caught without crashing
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
case creds, ok := <-m.updates:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
m.credentials = creds
|
m.credentials = creds
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
listener.OnNext(creds)
|
listener.OnNext(creds)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
@@ -904,6 +926,7 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
|
|||||||
// allow multiple closes from multiple listeners
|
// allow multiple closes from multiple listeners
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
close(done)
|
||||||
return
|
return
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user