mirror of
https://github.com/redis/go-redis.git
synced 2025-07-28 06:42:00 +03:00
add tests
This commit is contained in:
302
auth/auth_test.go
Normal file
302
auth/auth_test.go
Normal file
@ -0,0 +1,302 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockStreamingProvider struct {
|
||||||
|
credentials Credentials
|
||||||
|
err error
|
||||||
|
updates chan Credentials
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
|
||||||
|
return &mockStreamingProvider{
|
||||||
|
credentials: initialCreds,
|
||||||
|
updates: make(chan Credentials, 10),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, nil, m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send initial credentials
|
||||||
|
listener.OnNext(m.credentials)
|
||||||
|
|
||||||
|
// Start goroutine to handle updates
|
||||||
|
go func() {
|
||||||
|
for creds := range m.updates {
|
||||||
|
listener.OnNext(creds)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return m.credentials, func() error {
|
||||||
|
close(m.updates)
|
||||||
|
return nil
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamingCredentialsProvider(t *testing.T) {
|
||||||
|
t.Run("successful subscription", func(t *testing.T) {
|
||||||
|
initialCreds := NewBasicCredentials("user1", "pass1")
|
||||||
|
provider := newMockStreamingProvider(initialCreds)
|
||||||
|
|
||||||
|
var receivedCreds []Credentials
|
||||||
|
var receivedErrors []error
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
receivedCreds = append(receivedCreds, creds)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
receivedErrors = append(receivedErrors, err)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
creds, cancel, err := provider.Subscribe(listener)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cancel == nil {
|
||||||
|
t.Fatal("expected cancel function to be non-nil")
|
||||||
|
}
|
||||||
|
if creds != initialCreds {
|
||||||
|
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
||||||
|
}
|
||||||
|
if len(receivedCreds) != 1 {
|
||||||
|
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
|
||||||
|
}
|
||||||
|
if receivedCreds[0] != initialCreds {
|
||||||
|
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
|
||||||
|
}
|
||||||
|
if len(receivedErrors) != 0 {
|
||||||
|
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send an update
|
||||||
|
newCreds := NewBasicCredentials("user2", "pass2")
|
||||||
|
provider.updates <- newCreds
|
||||||
|
|
||||||
|
// Wait for update to be processed
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
if len(receivedCreds) != 2 {
|
||||||
|
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
|
||||||
|
}
|
||||||
|
if receivedCreds[1] != newCreds {
|
||||||
|
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel subscription
|
||||||
|
if err := cancel(); err != nil {
|
||||||
|
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("subscription error", func(t *testing.T) {
|
||||||
|
provider := &mockStreamingProvider{
|
||||||
|
err: errors.New("subscription failed"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var receivedCreds []Credentials
|
||||||
|
var receivedErrors []error
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
receivedCreds = append(receivedCreds, creds)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
receivedErrors = append(receivedErrors, err)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
creds, cancel, err := provider.Subscribe(listener)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if cancel != nil {
|
||||||
|
t.Fatal("expected cancel function to be nil")
|
||||||
|
}
|
||||||
|
if creds != nil {
|
||||||
|
t.Fatalf("expected nil credentials, got %v", creds)
|
||||||
|
}
|
||||||
|
if len(receivedCreds) != 0 {
|
||||||
|
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
|
||||||
|
}
|
||||||
|
if len(receivedErrors) != 0 {
|
||||||
|
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("re-auth error", func(t *testing.T) {
|
||||||
|
initialCreds := NewBasicCredentials("user1", "pass1")
|
||||||
|
provider := newMockStreamingProvider(initialCreds)
|
||||||
|
|
||||||
|
reauthErr := errors.New("re-auth failed")
|
||||||
|
var receivedErrors []error
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
return reauthErr
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
receivedErrors = append(receivedErrors, err)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
creds, cancel, err := provider.Subscribe(listener)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cancel == nil {
|
||||||
|
t.Fatal("expected cancel function to be non-nil")
|
||||||
|
}
|
||||||
|
if creds != initialCreds {
|
||||||
|
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
||||||
|
}
|
||||||
|
if len(receivedErrors) != 1 {
|
||||||
|
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
|
||||||
|
}
|
||||||
|
if receivedErrors[0] != reauthErr {
|
||||||
|
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cancel(); err != nil {
|
||||||
|
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicCredentials(t *testing.T) {
|
||||||
|
t.Run("basic auth", func(t *testing.T) {
|
||||||
|
creds := NewBasicCredentials("user1", "pass1")
|
||||||
|
username, password := creds.BasicAuth()
|
||||||
|
if username != "user1" {
|
||||||
|
t.Fatalf("expected username 'user1', got '%s'", username)
|
||||||
|
}
|
||||||
|
if password != "pass1" {
|
||||||
|
t.Fatalf("expected password 'pass1', got '%s'", password)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("raw credentials", func(t *testing.T) {
|
||||||
|
creds := NewBasicCredentials("user1", "pass1")
|
||||||
|
raw := creds.RawCredentials()
|
||||||
|
expected := "user1:pass1"
|
||||||
|
if raw != expected {
|
||||||
|
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty username", func(t *testing.T) {
|
||||||
|
creds := NewBasicCredentials("", "pass1")
|
||||||
|
username, password := creds.BasicAuth()
|
||||||
|
if username != "" {
|
||||||
|
t.Fatalf("expected empty username, got '%s'", username)
|
||||||
|
}
|
||||||
|
if password != "pass1" {
|
||||||
|
t.Fatalf("expected password 'pass1', got '%s'", password)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReAuthCredentialsListener(t *testing.T) {
|
||||||
|
t.Run("successful re-auth", func(t *testing.T) {
|
||||||
|
var reAuthCalled bool
|
||||||
|
var onErrCalled bool
|
||||||
|
var receivedCreds Credentials
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
reAuthCalled = true
|
||||||
|
receivedCreds = creds
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
onErrCalled = true
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
creds := NewBasicCredentials("user1", "pass1")
|
||||||
|
listener.OnNext(creds)
|
||||||
|
|
||||||
|
if !reAuthCalled {
|
||||||
|
t.Fatal("expected reAuth to be called")
|
||||||
|
}
|
||||||
|
if onErrCalled {
|
||||||
|
t.Fatal("expected onErr not to be called")
|
||||||
|
}
|
||||||
|
if receivedCreds != creds {
|
||||||
|
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("re-auth error", func(t *testing.T) {
|
||||||
|
var reAuthCalled bool
|
||||||
|
var onErrCalled bool
|
||||||
|
var receivedErr error
|
||||||
|
expectedErr := errors.New("re-auth failed")
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
reAuthCalled = true
|
||||||
|
return expectedErr
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
onErrCalled = true
|
||||||
|
receivedErr = err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
creds := NewBasicCredentials("user1", "pass1")
|
||||||
|
listener.OnNext(creds)
|
||||||
|
|
||||||
|
if !reAuthCalled {
|
||||||
|
t.Fatal("expected reAuth to be called")
|
||||||
|
}
|
||||||
|
if !onErrCalled {
|
||||||
|
t.Fatal("expected onErr to be called")
|
||||||
|
}
|
||||||
|
if receivedErr != expectedErr {
|
||||||
|
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("on error", func(t *testing.T) {
|
||||||
|
var onErrCalled bool
|
||||||
|
var receivedErr error
|
||||||
|
expectedErr := errors.New("provider error")
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
onErrCalled = true
|
||||||
|
receivedErr = err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
listener.OnError(expectedErr)
|
||||||
|
|
||||||
|
if !onErrCalled {
|
||||||
|
t.Fatal("expected onErr to be called")
|
||||||
|
}
|
||||||
|
if receivedErr != expectedErr {
|
||||||
|
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil callbacks", func(t *testing.T) {
|
||||||
|
listener := NewReAuthCredentialsListener(nil, nil)
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
listener.OnNext(NewBasicCredentials("user1", "pass1"))
|
||||||
|
listener.OnError(errors.New("test error"))
|
||||||
|
})
|
||||||
|
}
|
86
command_recorder_test.go
Normal file
86
command_recorder_test.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package redis_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// commandRecorder records the last N commands executed by a Redis client.
|
||||||
|
type commandRecorder struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
commands []string
|
||||||
|
maxSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCommandRecorder creates a new command recorder with the specified maximum size.
|
||||||
|
func newCommandRecorder(maxSize int) *commandRecorder {
|
||||||
|
return &commandRecorder{
|
||||||
|
commands: make([]string, 0, maxSize),
|
||||||
|
maxSize: maxSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record adds a command to the recorder.
|
||||||
|
func (r *commandRecorder) Record(cmd string) {
|
||||||
|
cmd = strings.ToLower(cmd)
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
r.commands = append(r.commands, cmd)
|
||||||
|
if len(r.commands) > r.maxSize {
|
||||||
|
r.commands = r.commands[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastCommands returns a copy of the recorded commands.
|
||||||
|
func (r *commandRecorder) LastCommands() []string {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
return append([]string(nil), r.commands...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains checks if the recorder contains a specific command.
|
||||||
|
func (r *commandRecorder) Contains(cmd string) bool {
|
||||||
|
cmd = strings.ToLower(cmd)
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
for _, c := range r.commands {
|
||||||
|
if strings.Contains(c, cmd) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hook returns a Redis hook that records commands.
|
||||||
|
func (r *commandRecorder) Hook() redis.Hook {
|
||||||
|
return &commandHook{recorder: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
// commandHook implements the redis.Hook interface to record commands.
|
||||||
|
type commandHook struct {
|
||||||
|
recorder *commandRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
|
||||||
|
return func(ctx context.Context, cmd redis.Cmder) error {
|
||||||
|
h.recorder.Record(cmd.String())
|
||||||
|
return next(ctx, cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||||
|
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||||
|
for _, cmd := range cmds {
|
||||||
|
h.recorder.Record(cmd.String())
|
||||||
|
}
|
||||||
|
return next(ctx, cmds)
|
||||||
|
}
|
||||||
|
}
|
@ -6,6 +6,8 @@ import (
|
|||||||
"github.com/redis/go-redis/v9/internal/rand"
|
"github.com/redis/go-redis/v9/internal/rand"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ParentHooksMixinKey struct{}
|
||||||
|
|
||||||
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
|
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
|
||||||
if retry < 0 {
|
if retry < 0 {
|
||||||
panic("not reached")
|
panic("not reached")
|
||||||
|
@ -89,6 +89,9 @@ func (s *clusterScenario) newClusterClient(
|
|||||||
func (s *clusterScenario) Close() error {
|
func (s *clusterScenario) Close() error {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
for _, master := range s.masters() {
|
for _, master := range s.masters() {
|
||||||
|
if master == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
err := master.FlushAll(ctx).Err()
|
err := master.FlushAll(ctx).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -298,7 +298,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("should CFCount", Label("cuckoo", "cfcount"), func() {
|
It("should CFCount", Label("cuckoo", "cfcount"), func() {
|
||||||
err := client.CFAdd(ctx, "testcf1", "item1").Err()
|
client.CFAdd(ctx, "testcf1", "item1")
|
||||||
cnt, err := client.CFCount(ctx, "testcf1", "item1").Result()
|
cnt, err := client.CFCount(ctx, "testcf1", "item1").Result()
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(cnt).To(BeEquivalentTo(int64(1)))
|
Expect(cnt).To(BeEquivalentTo(int64(1)))
|
||||||
@ -394,7 +394,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
|
|||||||
NoCreate: true,
|
NoCreate: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
_, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
|
|
||||||
args = &redis.CFInsertOptions{
|
args = &redis.CFInsertOptions{
|
||||||
@ -402,7 +402,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
|
|||||||
NoCreate: false,
|
NoCreate: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err = client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(len(result)).To(BeEquivalentTo(3))
|
Expect(len(result)).To(BeEquivalentTo(3))
|
||||||
})
|
})
|
||||||
|
79
redis.go
79
redis.go
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -308,8 +309,15 @@ func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err
|
|||||||
// we can get it from the *Conn and remove it from the clients pool.
|
// we can get it from the *Conn and remove it from the clients pool.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isBadConn(err, false, c.opt.Addr) {
|
if isBadConn(err, false, c.opt.Addr) {
|
||||||
poolCn, _ := cn.connPool.Get(ctx)
|
poolCn, getErr := cn.connPool.Get(ctx)
|
||||||
c.connPool.Remove(ctx, poolCn, err)
|
if getErr == nil {
|
||||||
|
c.connPool.Remove(ctx, poolCn, err)
|
||||||
|
} else {
|
||||||
|
// if we can't get the pool connection, we can only close the connection
|
||||||
|
if err := cn.Close(); err != nil {
|
||||||
|
log.Printf("failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -344,7 +352,20 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
var err error
|
var err error
|
||||||
cn.Inited = true
|
cn.Inited = true
|
||||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||||
conn := newConn(c.opt, connPool)
|
var parentHooks hooksMixin
|
||||||
|
pH := ctx.Value(internal.ParentHooksMixinKey{})
|
||||||
|
switch pH := pH.(type) {
|
||||||
|
case nil:
|
||||||
|
parentHooks = hooksMixin{}
|
||||||
|
case hooksMixin:
|
||||||
|
parentHooks = pH.clone()
|
||||||
|
case *hooksMixin:
|
||||||
|
parentHooks = (*pH).clone()
|
||||||
|
default:
|
||||||
|
parentHooks = hooksMixin{}
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := newConn(c.opt, connPool, parentHooks)
|
||||||
|
|
||||||
protocol := c.opt.Protocol
|
protocol := c.opt.Protocol
|
||||||
// By default, use RESP3 in current version.
|
// By default, use RESP3 in current version.
|
||||||
@ -352,28 +373,30 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
protocol = 3
|
protocol = 3
|
||||||
}
|
}
|
||||||
|
|
||||||
var authenticated bool
|
username, password := "", ""
|
||||||
username, password := c.opt.Username, c.opt.Password
|
|
||||||
if c.opt.StreamingCredentialsProvider != nil {
|
if c.opt.StreamingCredentialsProvider != nil {
|
||||||
credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||||
Subscribe(c.newReAuthCredentialsListener(ctx, conn))
|
Subscribe(c.newReAuthCredentialsListener(ctx, conn))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
||||||
}
|
}
|
||||||
c.onClose = c.wrappedOnClose(cancelCredentialsProvider)
|
c.onClose = c.wrappedOnClose(cancelCredentialsProvider)
|
||||||
username, password = credentials.BasicAuth()
|
username, password = credentials.BasicAuth()
|
||||||
} else if c.opt.CredentialsProviderContext != nil {
|
} else if c.opt.CredentialsProviderContext != nil {
|
||||||
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil {
|
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get credentials from context provider: %w", err)
|
||||||
}
|
}
|
||||||
} else if c.opt.CredentialsProvider != nil {
|
} else if c.opt.CredentialsProvider != nil {
|
||||||
username, password = c.opt.CredentialsProvider()
|
username, password = c.opt.CredentialsProvider()
|
||||||
|
} else if c.opt.Username != "" || c.opt.Password != "" {
|
||||||
|
username, password = c.opt.Username, c.opt.Password
|
||||||
}
|
}
|
||||||
|
|
||||||
// for redis-server versions that do not support the HELLO command,
|
// for redis-server versions that do not support the HELLO command,
|
||||||
// RESP2 will continue to be used.
|
// RESP2 will continue to be used.
|
||||||
if err = conn.Hello(ctx, protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
if err = conn.Hello(ctx, protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
||||||
authenticated = true
|
// Authentication successful with HELLO command
|
||||||
} else if !isRedisError(err) {
|
} else if !isRedisError(err) {
|
||||||
// When the server responds with the RESP protocol and the result is not a normal
|
// When the server responds with the RESP protocol and the result is not a normal
|
||||||
// execution result of the HELLO command, we consider it to be an indication that
|
// execution result of the HELLO command, we consider it to be an indication that
|
||||||
@ -382,15 +405,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
// or it could be DragonflyDB or a third-party redis-proxy. They all respond
|
// or it could be DragonflyDB or a third-party redis-proxy. They all respond
|
||||||
// with different error string results for unsupported commands, making it
|
// with different error string results for unsupported commands, making it
|
||||||
// difficult to rely on error strings to determine all results.
|
// difficult to rely on error strings to determine all results.
|
||||||
return err
|
return fmt.Errorf("failed to initialize connection: %w", err)
|
||||||
}
|
} else if password != "" {
|
||||||
|
// Try legacy AUTH command if HELLO failed
|
||||||
if !authenticated && password != "" {
|
|
||||||
err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password))
|
err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to authenticate: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
||||||
if c.opt.DB > 0 {
|
if c.opt.DB > 0 {
|
||||||
pipe.Select(ctx, c.opt.DB)
|
pipe.Select(ctx, c.opt.DB)
|
||||||
@ -407,7 +430,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to initialize connection options: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
|
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
|
||||||
@ -422,13 +445,14 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
|
// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
|
||||||
// out of order responses later on.
|
// out of order responses later on.
|
||||||
if _, err = p.Exec(ctx); err != nil && !isRedisError(err) {
|
if _, err = p.Exec(ctx); err != nil && !isRedisError(err) {
|
||||||
return err
|
return fmt.Errorf("failed to set client identity: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.opt.OnConnect != nil {
|
if c.opt.OnConnect != nil {
|
||||||
return c.opt.OnConnect(ctx, conn)
|
return c.opt.OnConnect(ctx, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -547,6 +571,16 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
|
|||||||
return c.opt.ReadTimeout
|
return c.opt.ReadTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// context returns the context for the current connection.
|
||||||
|
// If the context timeout is enabled, it returns the original context.
|
||||||
|
// Otherwise, it returns a new background context.
|
||||||
|
func (c *baseClient) context(ctx context.Context) context.Context {
|
||||||
|
if c.opt.ContextTimeoutEnabled {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
// Close closes the client, releasing any open resources.
|
// Close closes the client, releasing any open resources.
|
||||||
//
|
//
|
||||||
// It is rare to Close a Client, as the Client is meant to be
|
// It is rare to Close a Client, as the Client is meant to be
|
||||||
@ -699,13 +733,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseClient) context(ctx context.Context) context.Context {
|
|
||||||
if c.opt.ContextTimeoutEnabled {
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
return context.Background()
|
|
||||||
}
|
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
// Client is a Redis client representing a pool of zero or more underlying connections.
|
// Client is a Redis client representing a pool of zero or more underlying connections.
|
||||||
@ -752,7 +779,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Conn() *Conn {
|
func (c *Client) Conn() *Conn {
|
||||||
return newConn(c.opt, pool.NewStickyConnPool(c.connPool))
|
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), c.hooksMixin.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do create a Cmd from the args and processes the cmd.
|
// Do create a Cmd from the args and processes the cmd.
|
||||||
@ -763,6 +790,7 @@ func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Process(ctx context.Context, cmd Cmder) error {
|
func (c *Client) Process(ctx context.Context, cmd Cmder) error {
|
||||||
|
ctx = context.WithValue(ctx, internal.ParentHooksMixinKey{}, c.hooksMixin)
|
||||||
err := c.processHook(ctx, cmd)
|
err := c.processHook(ctx, cmd)
|
||||||
cmd.SetErr(err)
|
cmd.SetErr(err)
|
||||||
return err
|
return err
|
||||||
@ -888,7 +916,7 @@ type Conn struct {
|
|||||||
hooksMixin
|
hooksMixin
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(opt *Options, connPool pool.Pooler) *Conn {
|
func newConn(opt *Options, connPool pool.Pooler, parentHooks hooksMixin) *Conn {
|
||||||
c := Conn{
|
c := Conn{
|
||||||
baseClient: baseClient{
|
baseClient: baseClient{
|
||||||
opt: opt,
|
opt: opt,
|
||||||
@ -898,6 +926,7 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
|
|||||||
|
|
||||||
c.cmdable = c.Process
|
c.cmdable = c.Process
|
||||||
c.statefulCmdable = c.Process
|
c.statefulCmdable = c.Process
|
||||||
|
c.hooksMixin = parentHooks
|
||||||
c.initHooks(hooks{
|
c.initHooks(hooks{
|
||||||
dial: c.baseClient.dial,
|
dial: c.baseClient.dial,
|
||||||
process: c.baseClient.process,
|
process: c.baseClient.process,
|
||||||
|
169
redis_test.go
169
redis_test.go
@ -14,6 +14,7 @@ import (
|
|||||||
. "github.com/bsm/gomega"
|
. "github.com/bsm/gomega"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/redis/go-redis/v9/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
type redisHookError struct{}
|
type redisHookError struct{}
|
||||||
@ -727,3 +728,171 @@ var _ = Describe("Dialer connection timeouts", func() {
|
|||||||
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
|
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
var _ = Describe("Credentials Provider Priority", func() {
|
||||||
|
var client *redis.Client
|
||||||
|
var opt *redis.Options
|
||||||
|
var recorder *commandRecorder
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
recorder = newCommandRecorder(10)
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
if client != nil {
|
||||||
|
Expect(client.Close()).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use streaming provider when available", func() {
|
||||||
|
streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
|
||||||
|
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
|
||||||
|
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||||
|
|
||||||
|
opt = &redis.Options{
|
||||||
|
Username: "field_user",
|
||||||
|
Password: "field_pass",
|
||||||
|
CredentialsProvider: func() (string, string) {
|
||||||
|
username, password := providerCreds.BasicAuth()
|
||||||
|
return username, password
|
||||||
|
},
|
||||||
|
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
|
||||||
|
username, password := ctxCreds.BasicAuth()
|
||||||
|
return username, password, nil
|
||||||
|
},
|
||||||
|
StreamingCredentialsProvider: &mockStreamingProvider{
|
||||||
|
credentials: streamingCreds,
|
||||||
|
updates: make(chan auth.Credentials, 1),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use context provider when streaming provider is not available", func() {
|
||||||
|
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
|
||||||
|
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||||
|
|
||||||
|
opt = &redis.Options{
|
||||||
|
Username: "field_user",
|
||||||
|
Password: "field_pass",
|
||||||
|
CredentialsProvider: func() (string, string) {
|
||||||
|
username, password := providerCreds.BasicAuth()
|
||||||
|
return username, password
|
||||||
|
},
|
||||||
|
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
|
||||||
|
username, password := ctxCreds.BasicAuth()
|
||||||
|
return username, password, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use regular provider when streaming and context providers are not available", func() {
|
||||||
|
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||||
|
|
||||||
|
opt = &redis.Options{
|
||||||
|
Username: "field_user",
|
||||||
|
Password: "field_pass",
|
||||||
|
CredentialsProvider: func() (string, string) {
|
||||||
|
username, password := providerCreds.BasicAuth()
|
||||||
|
return username, password
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use username/password fields when no providers are set", func() {
|
||||||
|
opt = &redis.Options{
|
||||||
|
Username: "field_user",
|
||||||
|
Password: "field_pass",
|
||||||
|
}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use empty credentials when nothing is set", func() {
|
||||||
|
opt = &redis.Options{}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// no pass, ok
|
||||||
|
Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH")).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should handle credential updates from streaming provider", func() {
|
||||||
|
initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
|
||||||
|
updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
|
||||||
|
|
||||||
|
opt = &redis.Options{
|
||||||
|
StreamingCredentialsProvider: &mockStreamingProvider{
|
||||||
|
credentials: initialCreds,
|
||||||
|
updates: make(chan auth.Credentials, 1),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
|
||||||
|
|
||||||
|
// Update credentials
|
||||||
|
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
type mockStreamingProvider struct {
|
||||||
|
credentials auth.Credentials
|
||||||
|
err error
|
||||||
|
updates chan auth.Credentials
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, nil, m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send initial credentials
|
||||||
|
listener.OnNext(m.credentials)
|
||||||
|
|
||||||
|
// Start goroutine to handle updates
|
||||||
|
go func() {
|
||||||
|
for creds := range m.updates {
|
||||||
|
listener.OnNext(creds)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return m.credentials, func() (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
// this is just a mock:
|
||||||
|
// allow multiple closes from multiple listeners
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
close(m.updates)
|
||||||
|
return
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user