From afba8c285fc06f5f757fe68fea9a26fb05bcfe21 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Fri, 17 Oct 2025 14:07:26 +0300 Subject: [PATCH] fix nil listener --- .../conn_reauth_credentials_listener.go | 14 -------- internal/auth/streaming/manager.go | 24 +++++++++----- redis.go | 5 ++- redis_test.go | 33 ++++++++++++++++--- 4 files changed, 47 insertions(+), 29 deletions(-) diff --git a/internal/auth/streaming/conn_reauth_credentials_listener.go b/internal/auth/streaming/conn_reauth_credentials_listener.go index a13a3be4..0f17818c 100644 --- a/internal/auth/streaming/conn_reauth_credentials_listener.go +++ b/internal/auth/streaming/conn_reauth_credentials_listener.go @@ -59,19 +59,5 @@ func (c *ConnReAuthCredentialsListener) OnError(err error) { c.onErr(c.conn, err) } -// newConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener. -// Implements the auth.CredentialsListener interface. -func newConnReAuthCredentialsListener( - conn *pool.Conn, - reAuth func(conn *pool.Conn, credentials auth.Credentials) error, - onErr func(conn *pool.Conn, err error), -) *ConnReAuthCredentialsListener { - return &ConnReAuthCredentialsListener{ - conn: conn, - reAuth: reAuth, - onErr: onErr, - } -} - // Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface. var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil) diff --git a/internal/auth/streaming/manager.go b/internal/auth/streaming/manager.go index e9834728..3f529d15 100644 --- a/internal/auth/streaming/manager.go +++ b/internal/auth/streaming/manager.go @@ -1,6 +1,7 @@ package streaming import ( + "errors" "time" "github.com/redis/go-redis/v9/auth" @@ -29,19 +30,24 @@ func (m *Manager) Listener( poolCn *pool.Conn, reAuth func(*pool.Conn, auth.Credentials) error, onErr func(*pool.Conn, error), -) auth.CredentialsListener { +) (auth.CredentialsListener, error) { + if poolCn == nil { + return nil, errors.New("poolCn cannot be nil") + } connID := poolCn.GetID() listener, ok := m.credentialsListeners.Get(connID) - if !ok { - newCredListener := newConnReAuthCredentialsListener( - poolCn, - reAuth, - onErr, - ) - newCredListener.manager = m + if !ok || listener == nil { + newCredListener := &ConnReAuthCredentialsListener{ + conn: poolCn, + reAuth: reAuth, + onErr: onErr, + manager: m, + } + m.credentialsListeners.Add(connID, newCredListener) + listener = newCredListener } - return listener + return listener, nil } func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) { diff --git a/redis.go b/redis.go index d0667b50..93cdf2b0 100644 --- a/redis.go +++ b/redis.go @@ -373,11 +373,14 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { username, password := "", "" if c.opt.StreamingCredentialsProvider != nil { - credListener := c.streamingCredentialsManager.Listener( + credListener, err := c.streamingCredentialsManager.Listener( cn, c.reAuthConnection(), c.onAuthenticationErr(), ) + if err != nil { + return fmt.Errorf("failed to create credentials listener: %w", err) + } credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider. Subscribe(credListener) diff --git a/redis_test.go b/redis_test.go index 0e1646d6..69489e78 100644 --- a/redis_test.go +++ b/redis_test.go @@ -883,13 +883,35 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au return nil, nil, m.err } + if listener == nil { + return nil, nil, errors.New("listener cannot be nil") + } + + // Create a done channel to stop the goroutine + done := make(chan struct{}) + // Start goroutine to handle updates go func() { - for creds := range m.updates { - m.mu.Lock() - m.credentials = creds - m.mu.Unlock() - listener.OnNext(creds) + defer func() { + if r := recover(); r != nil { + // this is just a mock: + // allow panics to be caught without crashing + } + }() + + for { + select { + case <-done: + return + case creds, ok := <-m.updates: + if !ok { + return + } + m.mu.Lock() + m.credentials = creds + m.mu.Unlock() + listener.OnNext(creds) + } } }() @@ -904,6 +926,7 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au // allow multiple closes from multiple listeners } }() + close(done) return }, nil }