From 44628c5dbd47868d8ad62626269a6cebfbba6bfd Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 22 Apr 2025 15:35:57 +0300 Subject: [PATCH] add tests --- auth/auth_test.go | 302 +++++++++++++++++++++++++++++++++++++++ command_recorder_test.go | 86 +++++++++++ internal/internal.go | 2 + osscluster_test.go | 3 + probabilistic_test.go | 6 +- redis.go | 79 ++++++---- redis_test.go | 169 ++++++++++++++++++++++ 7 files changed, 619 insertions(+), 28 deletions(-) create mode 100644 auth/auth_test.go create mode 100644 command_recorder_test.go diff --git a/auth/auth_test.go b/auth/auth_test.go new file mode 100644 index 00000000..88835a48 --- /dev/null +++ b/auth/auth_test.go @@ -0,0 +1,302 @@ +package auth + +import ( + "errors" + "testing" + "time" +) + +type mockStreamingProvider struct { + credentials Credentials + err error + updates chan Credentials +} + +func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider { + return &mockStreamingProvider{ + credentials: initialCreds, + updates: make(chan Credentials, 10), + } +} + +func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) { + if m.err != nil { + return nil, nil, m.err + } + + // Send initial credentials + listener.OnNext(m.credentials) + + // Start goroutine to handle updates + go func() { + for creds := range m.updates { + listener.OnNext(creds) + } + }() + + return m.credentials, func() error { + close(m.updates) + return nil + }, nil +} + +func TestStreamingCredentialsProvider(t *testing.T) { + t.Run("successful subscription", func(t *testing.T) { + initialCreds := NewBasicCredentials("user1", "pass1") + provider := newMockStreamingProvider(initialCreds) + + var receivedCreds []Credentials + var receivedErrors []error + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + receivedCreds = append(receivedCreds, creds) + return nil + }, + func(err error) { + receivedErrors = append(receivedErrors, err) + }, + ) + + creds, cancel, err := provider.Subscribe(listener) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cancel == nil { + t.Fatal("expected cancel function to be non-nil") + } + if creds != initialCreds { + t.Fatalf("expected credentials %v, got %v", initialCreds, creds) + } + if len(receivedCreds) != 1 { + t.Fatalf("expected 1 received credential, got %d", len(receivedCreds)) + } + if receivedCreds[0] != initialCreds { + t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0]) + } + if len(receivedErrors) != 0 { + t.Fatalf("expected no errors, got %d", len(receivedErrors)) + } + + // Send an update + newCreds := NewBasicCredentials("user2", "pass2") + provider.updates <- newCreds + + // Wait for update to be processed + time.Sleep(100 * time.Millisecond) + if len(receivedCreds) != 2 { + t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds)) + } + if receivedCreds[1] != newCreds { + t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1]) + } + + // Cancel subscription + if err := cancel(); err != nil { + t.Fatalf("unexpected error cancelling subscription: %v", err) + } + }) + + t.Run("subscription error", func(t *testing.T) { + provider := &mockStreamingProvider{ + err: errors.New("subscription failed"), + } + + var receivedCreds []Credentials + var receivedErrors []error + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + receivedCreds = append(receivedCreds, creds) + return nil + }, + func(err error) { + receivedErrors = append(receivedErrors, err) + }, + ) + + creds, cancel, err := provider.Subscribe(listener) + if err == nil { + t.Fatal("expected error, got nil") + } + if cancel != nil { + t.Fatal("expected cancel function to be nil") + } + if creds != nil { + t.Fatalf("expected nil credentials, got %v", creds) + } + if len(receivedCreds) != 0 { + t.Fatalf("expected no received credentials, got %d", len(receivedCreds)) + } + if len(receivedErrors) != 0 { + t.Fatalf("expected no errors, got %d", len(receivedErrors)) + } + }) + + t.Run("re-auth error", func(t *testing.T) { + initialCreds := NewBasicCredentials("user1", "pass1") + provider := newMockStreamingProvider(initialCreds) + + reauthErr := errors.New("re-auth failed") + var receivedErrors []error + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + return reauthErr + }, + func(err error) { + receivedErrors = append(receivedErrors, err) + }, + ) + + creds, cancel, err := provider.Subscribe(listener) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cancel == nil { + t.Fatal("expected cancel function to be non-nil") + } + if creds != initialCreds { + t.Fatalf("expected credentials %v, got %v", initialCreds, creds) + } + if len(receivedErrors) != 1 { + t.Fatalf("expected 1 error, got %d", len(receivedErrors)) + } + if receivedErrors[0] != reauthErr { + t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0]) + } + + if err := cancel(); err != nil { + t.Fatalf("unexpected error cancelling subscription: %v", err) + } + }) +} + +func TestBasicCredentials(t *testing.T) { + t.Run("basic auth", func(t *testing.T) { + creds := NewBasicCredentials("user1", "pass1") + username, password := creds.BasicAuth() + if username != "user1" { + t.Fatalf("expected username 'user1', got '%s'", username) + } + if password != "pass1" { + t.Fatalf("expected password 'pass1', got '%s'", password) + } + }) + + t.Run("raw credentials", func(t *testing.T) { + creds := NewBasicCredentials("user1", "pass1") + raw := creds.RawCredentials() + expected := "user1:pass1" + if raw != expected { + t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw) + } + }) + + t.Run("empty username", func(t *testing.T) { + creds := NewBasicCredentials("", "pass1") + username, password := creds.BasicAuth() + if username != "" { + t.Fatalf("expected empty username, got '%s'", username) + } + if password != "pass1" { + t.Fatalf("expected password 'pass1', got '%s'", password) + } + }) +} + +func TestReAuthCredentialsListener(t *testing.T) { + t.Run("successful re-auth", func(t *testing.T) { + var reAuthCalled bool + var onErrCalled bool + var receivedCreds Credentials + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + reAuthCalled = true + receivedCreds = creds + return nil + }, + func(err error) { + onErrCalled = true + }, + ) + + creds := NewBasicCredentials("user1", "pass1") + listener.OnNext(creds) + + if !reAuthCalled { + t.Fatal("expected reAuth to be called") + } + if onErrCalled { + t.Fatal("expected onErr not to be called") + } + if receivedCreds != creds { + t.Fatalf("expected credentials %v, got %v", creds, receivedCreds) + } + }) + + t.Run("re-auth error", func(t *testing.T) { + var reAuthCalled bool + var onErrCalled bool + var receivedErr error + expectedErr := errors.New("re-auth failed") + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + reAuthCalled = true + return expectedErr + }, + func(err error) { + onErrCalled = true + receivedErr = err + }, + ) + + creds := NewBasicCredentials("user1", "pass1") + listener.OnNext(creds) + + if !reAuthCalled { + t.Fatal("expected reAuth to be called") + } + if !onErrCalled { + t.Fatal("expected onErr to be called") + } + if receivedErr != expectedErr { + t.Fatalf("expected error %v, got %v", expectedErr, receivedErr) + } + }) + + t.Run("on error", func(t *testing.T) { + var onErrCalled bool + var receivedErr error + expectedErr := errors.New("provider error") + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + return nil + }, + func(err error) { + onErrCalled = true + receivedErr = err + }, + ) + + listener.OnError(expectedErr) + + if !onErrCalled { + t.Fatal("expected onErr to be called") + } + if receivedErr != expectedErr { + t.Fatalf("expected error %v, got %v", expectedErr, receivedErr) + } + }) + + t.Run("nil callbacks", func(t *testing.T) { + listener := NewReAuthCredentialsListener(nil, nil) + + // Should not panic + listener.OnNext(NewBasicCredentials("user1", "pass1")) + listener.OnError(errors.New("test error")) + }) +} diff --git a/command_recorder_test.go b/command_recorder_test.go new file mode 100644 index 00000000..2251df5e --- /dev/null +++ b/command_recorder_test.go @@ -0,0 +1,86 @@ +package redis_test + +import ( + "context" + "strings" + "sync" + + "github.com/redis/go-redis/v9" +) + +// commandRecorder records the last N commands executed by a Redis client. +type commandRecorder struct { + mu sync.Mutex + commands []string + maxSize int +} + +// newCommandRecorder creates a new command recorder with the specified maximum size. +func newCommandRecorder(maxSize int) *commandRecorder { + return &commandRecorder{ + commands: make([]string, 0, maxSize), + maxSize: maxSize, + } +} + +// Record adds a command to the recorder. +func (r *commandRecorder) Record(cmd string) { + cmd = strings.ToLower(cmd) + r.mu.Lock() + defer r.mu.Unlock() + + r.commands = append(r.commands, cmd) + if len(r.commands) > r.maxSize { + r.commands = r.commands[1:] + } +} + +// LastCommands returns a copy of the recorded commands. +func (r *commandRecorder) LastCommands() []string { + r.mu.Lock() + defer r.mu.Unlock() + return append([]string(nil), r.commands...) +} + +// Contains checks if the recorder contains a specific command. +func (r *commandRecorder) Contains(cmd string) bool { + cmd = strings.ToLower(cmd) + r.mu.Lock() + defer r.mu.Unlock() + for _, c := range r.commands { + if strings.Contains(c, cmd) { + return true + } + } + return false +} + +// Hook returns a Redis hook that records commands. +func (r *commandRecorder) Hook() redis.Hook { + return &commandHook{recorder: r} +} + +// commandHook implements the redis.Hook interface to record commands. +type commandHook struct { + recorder *commandRecorder +} + +func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook { + return next +} + +func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + h.recorder.Record(cmd.String()) + return next(ctx, cmd) + } +} + +func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + for _, cmd := range cmds { + h.recorder.Record(cmd.String()) + } + return next(ctx, cmds) + } +} diff --git a/internal/internal.go b/internal/internal.go index e783d139..dbf77e26 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -6,6 +6,8 @@ import ( "github.com/redis/go-redis/v9/internal/rand" ) +type ParentHooksMixinKey struct{} + func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration { if retry < 0 { panic("not reached") diff --git a/osscluster_test.go b/osscluster_test.go index ccf6daad..6e214a71 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -89,6 +89,9 @@ func (s *clusterScenario) newClusterClient( func (s *clusterScenario) Close() error { ctx := context.TODO() for _, master := range s.masters() { + if master == nil { + continue + } err := master.FlushAll(ctx).Err() if err != nil { return err diff --git a/probabilistic_test.go b/probabilistic_test.go index a0a050e2..0a3f1a15 100644 --- a/probabilistic_test.go +++ b/probabilistic_test.go @@ -298,7 +298,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() { }) It("should CFCount", Label("cuckoo", "cfcount"), func() { - err := client.CFAdd(ctx, "testcf1", "item1").Err() + client.CFAdd(ctx, "testcf1", "item1") cnt, err := client.CFCount(ctx, "testcf1", "item1").Result() Expect(err).NotTo(HaveOccurred()) Expect(cnt).To(BeEquivalentTo(int64(1))) @@ -394,7 +394,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() { NoCreate: true, } - result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result() + _, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result() Expect(err).To(HaveOccurred()) args = &redis.CFInsertOptions{ @@ -402,7 +402,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() { NoCreate: false, } - result, err = client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result() + result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result() Expect(err).NotTo(HaveOccurred()) Expect(len(result)).To(BeEquivalentTo(3)) }) diff --git a/redis.go b/redis.go index 94de3fc7..cbc8e60c 100644 --- a/redis.go +++ b/redis.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log" "net" "sync" "sync/atomic" @@ -308,8 +309,15 @@ func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err // we can get it from the *Conn and remove it from the clients pool. if err != nil { if isBadConn(err, false, c.opt.Addr) { - poolCn, _ := cn.connPool.Get(ctx) - c.connPool.Remove(ctx, poolCn, err) + poolCn, getErr := cn.connPool.Get(ctx) + if getErr == nil { + c.connPool.Remove(ctx, poolCn, err) + } else { + // if we can't get the pool connection, we can only close the connection + if err := cn.Close(); err != nil { + log.Printf("failed to close connection: %v", err) + } + } } } } @@ -344,7 +352,20 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { var err error cn.Inited = true connPool := pool.NewSingleConnPool(c.connPool, cn) - conn := newConn(c.opt, connPool) + var parentHooks hooksMixin + pH := ctx.Value(internal.ParentHooksMixinKey{}) + switch pH := pH.(type) { + case nil: + parentHooks = hooksMixin{} + case hooksMixin: + parentHooks = pH.clone() + case *hooksMixin: + parentHooks = (*pH).clone() + default: + parentHooks = hooksMixin{} + } + + conn := newConn(c.opt, connPool, parentHooks) protocol := c.opt.Protocol // By default, use RESP3 in current version. @@ -352,28 +373,30 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { protocol = 3 } - var authenticated bool - username, password := c.opt.Username, c.opt.Password + username, password := "", "" if c.opt.StreamingCredentialsProvider != nil { credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider. Subscribe(c.newReAuthCredentialsListener(ctx, conn)) if err != nil { - return err + return fmt.Errorf("failed to subscribe to streaming credentials: %w", err) } c.onClose = c.wrappedOnClose(cancelCredentialsProvider) username, password = credentials.BasicAuth() } else if c.opt.CredentialsProviderContext != nil { - if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil { - return err + username, password, err = c.opt.CredentialsProviderContext(ctx) + if err != nil { + return fmt.Errorf("failed to get credentials from context provider: %w", err) } } else if c.opt.CredentialsProvider != nil { username, password = c.opt.CredentialsProvider() + } else if c.opt.Username != "" || c.opt.Password != "" { + username, password = c.opt.Username, c.opt.Password } // for redis-server versions that do not support the HELLO command, // RESP2 will continue to be used. if err = conn.Hello(ctx, protocol, username, password, c.opt.ClientName).Err(); err == nil { - authenticated = true + // Authentication successful with HELLO command } else if !isRedisError(err) { // When the server responds with the RESP protocol and the result is not a normal // execution result of the HELLO command, we consider it to be an indication that @@ -382,15 +405,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { // or it could be DragonflyDB or a third-party redis-proxy. They all respond // with different error string results for unsupported commands, making it // difficult to rely on error strings to determine all results. - return err - } - - if !authenticated && password != "" { + return fmt.Errorf("failed to initialize connection: %w", err) + } else if password != "" { + // Try legacy AUTH command if HELLO failed err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password)) if err != nil { - return err + return fmt.Errorf("failed to authenticate: %w", err) } } + _, err = conn.Pipelined(ctx, func(pipe Pipeliner) error { if c.opt.DB > 0 { pipe.Select(ctx, c.opt.DB) @@ -407,7 +430,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return nil }) if err != nil { - return err + return fmt.Errorf("failed to initialize connection options: %w", err) } if !c.opt.DisableIdentity && !c.opt.DisableIndentity { @@ -422,13 +445,14 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { // Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid // out of order responses later on. if _, err = p.Exec(ctx); err != nil && !isRedisError(err) { - return err + return fmt.Errorf("failed to set client identity: %w", err) } } if c.opt.OnConnect != nil { return c.opt.OnConnect(ctx, conn) } + return nil } @@ -547,6 +571,16 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { return c.opt.ReadTimeout } +// context returns the context for the current connection. +// If the context timeout is enabled, it returns the original context. +// Otherwise, it returns a new background context. +func (c *baseClient) context(ctx context.Context) context.Context { + if c.opt.ContextTimeoutEnabled { + return ctx + } + return context.Background() +} + // Close closes the client, releasing any open resources. // // It is rare to Close a Client, as the Client is meant to be @@ -699,13 +733,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) return nil } -func (c *baseClient) context(ctx context.Context) context.Context { - if c.opt.ContextTimeoutEnabled { - return ctx - } - return context.Background() -} - //------------------------------------------------------------------------------ // Client is a Redis client representing a pool of zero or more underlying connections. @@ -752,7 +779,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client { } func (c *Client) Conn() *Conn { - return newConn(c.opt, pool.NewStickyConnPool(c.connPool)) + return newConn(c.opt, pool.NewStickyConnPool(c.connPool), c.hooksMixin.clone()) } // Do create a Cmd from the args and processes the cmd. @@ -763,6 +790,7 @@ func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd { } func (c *Client) Process(ctx context.Context, cmd Cmder) error { + ctx = context.WithValue(ctx, internal.ParentHooksMixinKey{}, c.hooksMixin) err := c.processHook(ctx, cmd) cmd.SetErr(err) return err @@ -888,7 +916,7 @@ type Conn struct { hooksMixin } -func newConn(opt *Options, connPool pool.Pooler) *Conn { +func newConn(opt *Options, connPool pool.Pooler, parentHooks hooksMixin) *Conn { c := Conn{ baseClient: baseClient{ opt: opt, @@ -898,6 +926,7 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn { c.cmdable = c.Process c.statefulCmdable = c.Process + c.hooksMixin = parentHooks c.initHooks(hooks{ dial: c.baseClient.dial, process: c.baseClient.process, diff --git a/redis_test.go b/redis_test.go index 7d9bf1ce..089973e0 100644 --- a/redis_test.go +++ b/redis_test.go @@ -14,6 +14,7 @@ import ( . "github.com/bsm/gomega" "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/auth" ) type redisHookError struct{} @@ -727,3 +728,171 @@ var _ = Describe("Dialer connection timeouts", func() { Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay)) }) }) + +var _ = Describe("Credentials Provider Priority", func() { + var client *redis.Client + var opt *redis.Options + var recorder *commandRecorder + + BeforeEach(func() { + recorder = newCommandRecorder(10) + }) + + AfterEach(func() { + if client != nil { + Expect(client.Close()).NotTo(HaveOccurred()) + } + }) + + It("should use streaming provider when available", func() { + streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass") + ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass") + providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass") + + opt = &redis.Options{ + Username: "field_user", + Password: "field_pass", + CredentialsProvider: func() (string, string) { + username, password := providerCreds.BasicAuth() + return username, password + }, + CredentialsProviderContext: func(ctx context.Context) (string, string, error) { + username, password := ctxCreds.BasicAuth() + return username, password, nil + }, + StreamingCredentialsProvider: &mockStreamingProvider{ + credentials: streamingCreds, + updates: make(chan auth.Credentials, 1), + }, + } + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue()) + }) + + It("should use context provider when streaming provider is not available", func() { + ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass") + providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass") + + opt = &redis.Options{ + Username: "field_user", + Password: "field_pass", + CredentialsProvider: func() (string, string) { + username, password := providerCreds.BasicAuth() + return username, password + }, + CredentialsProviderContext: func(ctx context.Context) (string, string, error) { + username, password := ctxCreds.BasicAuth() + return username, password, nil + }, + } + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue()) + }) + + It("should use regular provider when streaming and context providers are not available", func() { + providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass") + + opt = &redis.Options{ + Username: "field_user", + Password: "field_pass", + CredentialsProvider: func() (string, string) { + username, password := providerCreds.BasicAuth() + return username, password + }, + } + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH provider_user")).To(BeTrue()) + }) + + It("should use username/password fields when no providers are set", func() { + opt = &redis.Options{ + Username: "field_user", + Password: "field_pass", + } + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH field_user")).To(BeTrue()) + }) + + It("should use empty credentials when nothing is set", func() { + opt = &redis.Options{} + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // no pass, ok + Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred()) + Expect(recorder.Contains("AUTH")).To(BeFalse()) + }) + + It("should handle credential updates from streaming provider", func() { + initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass") + updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass") + + opt = &redis.Options{ + StreamingCredentialsProvider: &mockStreamingProvider{ + credentials: initialCreds, + updates: make(chan auth.Credentials, 1), + }, + } + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH initial_user")).To(BeTrue()) + + // Update credentials + opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH updated_user")).To(BeTrue()) + }) +}) + +type mockStreamingProvider struct { + credentials auth.Credentials + err error + updates chan auth.Credentials +} + +func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) { + if m.err != nil { + return nil, nil, m.err + } + + // Send initial credentials + listener.OnNext(m.credentials) + + // Start goroutine to handle updates + go func() { + for creds := range m.updates { + listener.OnNext(creds) + } + }() + + return m.credentials, func() (err error) { + defer func() { + if r := recover(); r != nil { + // this is just a mock: + // allow multiple closes from multiple listeners + } + }() + close(m.updates) + return + }, nil +}