1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-28 06:42:00 +03:00

feat: Introducing StreamingCredentialsProvider for token based authentication (#3320)

* wip

* update documentation

* add streamingcredentialsprovider in options

* fix: put back option in pool creation

* add package level comment

* Initial re authentication implementation

Introduces the StreamingCredentialsProvider as the CredentialsProvider
with the highest priority.

TODO: needs to be tested

* Change function type name

Change CancelProviderFunc to UnsubscribeFunc

* add tests

* fix race in tests

* fix example tests

* wip, hooks refactor

* fix build

* update README.md

* update wordlist

* update README.md

* refactor(auth): early returns in cred listener

* fix(doctest): simulate some delay

* feat(conn): add close hook on conn

* fix(tests): simulate start/stop in mock credentials provider

* fix(auth): don't double close the conn

* docs(README): mark streaming credentials provider as experimental

* fix(auth): streamline auth err proccess

* fix(auth): check err on close conn

* chore(entraid): use the repo under redis org
This commit is contained in:
Nedyalko Dyakov
2025-05-27 16:25:20 +03:00
committed by GitHub
parent 28a3c97409
commit 86d418f940
20 changed files with 1103 additions and 130 deletions

10
.github/wordlist.txt vendored
View File

@ -65,4 +65,12 @@ RedisGears
RedisTimeseries RedisTimeseries
RediSearch RediSearch
RawResult RawResult
RawVal RawVal
entra
EntraID
Entra
OAuth
Azure
StreamingCredentialsProvider
oauth
entraid

3
.gitignore vendored
View File

@ -7,4 +7,5 @@ testdata/*
redis8tests.sh redis8tests.sh
coverage.txt coverage.txt
**/coverage.txt **/coverage.txt
.vscode .vscode
tmp/*

121
README.md
View File

@ -68,6 +68,7 @@ key value NoSQL database that uses RocksDB as storage engine and is compatible w
- Redis commands except QUIT and SYNC. - Redis commands except QUIT and SYNC.
- Automatic connection pooling. - Automatic connection pooling.
- [StreamingCredentialsProvider (e.g. entra id, oauth)](#1-streaming-credentials-provider-highest-priority) (experimental)
- [Pub/Sub](https://redis.uptrace.dev/guide/go-redis-pubsub.html). - [Pub/Sub](https://redis.uptrace.dev/guide/go-redis-pubsub.html).
- [Pipelines and transactions](https://redis.uptrace.dev/guide/go-redis-pipelines.html). - [Pipelines and transactions](https://redis.uptrace.dev/guide/go-redis-pipelines.html).
- [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html). - [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html).
@ -136,17 +137,121 @@ func ExampleClient() {
} }
``` ```
The above can be modified to specify the version of the RESP protocol by adding the `protocol` ### Authentication
option to the `Options` struct:
The Redis client supports multiple ways to provide authentication credentials, with a clear priority order. Here are the available options:
#### 1. Streaming Credentials Provider (Highest Priority) - Experimental feature
The streaming credentials provider allows for dynamic credential updates during the connection lifetime. This is particularly useful for managed identity services and token-based authentication.
```go ```go
rdb := redis.NewClient(&redis.Options{ type StreamingCredentialsProvider interface {
Addr: "localhost:6379", Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error)
Password: "", // no password set }
DB: 0, // use default DB
Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3
})
type CredentialsListener interface {
OnNext(credentials Credentials) // Called when credentials are updated
OnError(err error) // Called when an error occurs
}
type Credentials interface {
BasicAuth() (username string, password string)
RawCredentials() string
}
```
Example usage:
```go
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
StreamingCredentialsProvider: &MyCredentialsProvider{},
})
```
**Note:** The streaming credentials provider can be used with [go-redis-entraid](https://github.com/redis/go-redis-entraid) to enable Entra ID (formerly Azure AD) authentication. This allows for seamless integration with Azure's managed identity services and token-based authentication.
Example with Entra ID:
```go
import (
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis-entraid"
)
// Create an Entra ID credentials provider
provider := entraid.NewDefaultAzureIdentityProvider()
// Configure Redis client with Entra ID authentication
rdb := redis.NewClient(&redis.Options{
Addr: "your-redis-server.redis.cache.windows.net:6380",
StreamingCredentialsProvider: provider,
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
})
```
#### 2. Context-based Credentials Provider
The context-based provider allows credentials to be determined at the time of each operation, using the context.
```go
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
// Return username, password, and any error
return "user", "pass", nil
},
})
```
#### 3. Regular Credentials Provider
A simple function-based provider that returns static credentials.
```go
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
CredentialsProvider: func() (string, string) {
// Return username and password
return "user", "pass"
},
})
```
#### 4. Username/Password Fields (Lowest Priority)
The most basic way to provide credentials is through the `Username` and `Password` fields in the options.
```go
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Username: "user",
Password: "pass",
})
```
#### Priority Order
The client will use credentials in the following priority order:
1. Streaming Credentials Provider (if set)
2. Context-based Credentials Provider (if set)
3. Regular Credentials Provider (if set)
4. Username/Password fields (if set)
If none of these are set, the client will attempt to connect without authentication.
### Protocol Version
The client supports both RESP2 and RESP3 protocols. You can specify the protocol version in the options:
```go
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "", // no password set
DB: 0, // use default DB
Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3
})
``` ```
### Connecting via a redis url ### Connecting via a redis url

61
auth/auth.go Normal file
View File

@ -0,0 +1,61 @@
// Package auth package provides authentication-related interfaces and types.
// It also includes a basic implementation of credentials using username and password.
package auth
// StreamingCredentialsProvider is an interface that defines the methods for a streaming credentials provider.
// It is used to provide credentials for authentication.
// The CredentialsListener is used to receive updates when the credentials change.
type StreamingCredentialsProvider interface {
// Subscribe subscribes to the credentials provider for updates.
// It returns the current credentials, a cancel function to unsubscribe from the provider,
// and an error if any.
// TODO(ndyakov): Should we add context to the Subscribe method?
Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error)
}
// UnsubscribeFunc is a function that is used to cancel the subscription to the credentials provider.
// It is used to unsubscribe from the provider when the credentials are no longer needed.
type UnsubscribeFunc func() error
// CredentialsListener is an interface that defines the methods for a credentials listener.
// It is used to receive updates when the credentials change.
// The OnNext method is called when the credentials change.
// The OnError method is called when an error occurs while requesting the credentials.
type CredentialsListener interface {
OnNext(credentials Credentials)
OnError(err error)
}
// Credentials is an interface that defines the methods for credentials.
// It is used to provide the credentials for authentication.
type Credentials interface {
// BasicAuth returns the username and password for basic authentication.
BasicAuth() (username string, password string)
// RawCredentials returns the raw credentials as a string.
// This can be used to extract the username and password from the raw credentials or
// additional information if present in the token.
RawCredentials() string
}
type basicAuth struct {
username string
password string
}
// RawCredentials returns the raw credentials as a string.
func (b *basicAuth) RawCredentials() string {
return b.username + ":" + b.password
}
// BasicAuth returns the username and password for basic authentication.
func (b *basicAuth) BasicAuth() (username string, password string) {
return b.username, b.password
}
// NewBasicCredentials creates a new Credentials object from the given username and password.
func NewBasicCredentials(username, password string) Credentials {
return &basicAuth{
username: username,
password: password,
}
}

308
auth/auth_test.go Normal file
View File

@ -0,0 +1,308 @@
package auth
import (
"errors"
"sync"
"testing"
"time"
)
type mockStreamingProvider struct {
credentials Credentials
err error
updates chan Credentials
}
func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
return &mockStreamingProvider{
credentials: initialCreds,
updates: make(chan Credentials, 10),
}
}
func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
if m.err != nil {
return nil, nil, m.err
}
// Send initial credentials
listener.OnNext(m.credentials)
// Start goroutine to handle updates
go func() {
for creds := range m.updates {
listener.OnNext(creds)
}
}()
return m.credentials, func() error {
close(m.updates)
return nil
}, nil
}
func TestStreamingCredentialsProvider(t *testing.T) {
t.Run("successful subscription", func(t *testing.T) {
initialCreds := NewBasicCredentials("user1", "pass1")
provider := newMockStreamingProvider(initialCreds)
var receivedCreds []Credentials
var receivedErrors []error
var mu sync.Mutex
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
mu.Lock()
receivedCreds = append(receivedCreds, creds)
mu.Unlock()
return nil
},
func(err error) {
receivedErrors = append(receivedErrors, err)
},
)
creds, cancel, err := provider.Subscribe(listener)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cancel == nil {
t.Fatal("expected cancel function to be non-nil")
}
if creds != initialCreds {
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
}
if len(receivedCreds) != 1 {
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
}
if receivedCreds[0] != initialCreds {
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
}
if len(receivedErrors) != 0 {
t.Fatalf("expected no errors, got %d", len(receivedErrors))
}
// Send an update
newCreds := NewBasicCredentials("user2", "pass2")
provider.updates <- newCreds
// Wait for update to be processed
time.Sleep(100 * time.Millisecond)
mu.Lock()
if len(receivedCreds) != 2 {
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
}
if receivedCreds[1] != newCreds {
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
}
mu.Unlock()
// Cancel subscription
if err := cancel(); err != nil {
t.Fatalf("unexpected error cancelling subscription: %v", err)
}
})
t.Run("subscription error", func(t *testing.T) {
provider := &mockStreamingProvider{
err: errors.New("subscription failed"),
}
var receivedCreds []Credentials
var receivedErrors []error
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
receivedCreds = append(receivedCreds, creds)
return nil
},
func(err error) {
receivedErrors = append(receivedErrors, err)
},
)
creds, cancel, err := provider.Subscribe(listener)
if err == nil {
t.Fatal("expected error, got nil")
}
if cancel != nil {
t.Fatal("expected cancel function to be nil")
}
if creds != nil {
t.Fatalf("expected nil credentials, got %v", creds)
}
if len(receivedCreds) != 0 {
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
}
if len(receivedErrors) != 0 {
t.Fatalf("expected no errors, got %d", len(receivedErrors))
}
})
t.Run("re-auth error", func(t *testing.T) {
initialCreds := NewBasicCredentials("user1", "pass1")
provider := newMockStreamingProvider(initialCreds)
reauthErr := errors.New("re-auth failed")
var receivedErrors []error
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
return reauthErr
},
func(err error) {
receivedErrors = append(receivedErrors, err)
},
)
creds, cancel, err := provider.Subscribe(listener)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cancel == nil {
t.Fatal("expected cancel function to be non-nil")
}
if creds != initialCreds {
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
}
if len(receivedErrors) != 1 {
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
}
if receivedErrors[0] != reauthErr {
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
}
if err := cancel(); err != nil {
t.Fatalf("unexpected error cancelling subscription: %v", err)
}
})
}
func TestBasicCredentials(t *testing.T) {
t.Run("basic auth", func(t *testing.T) {
creds := NewBasicCredentials("user1", "pass1")
username, password := creds.BasicAuth()
if username != "user1" {
t.Fatalf("expected username 'user1', got '%s'", username)
}
if password != "pass1" {
t.Fatalf("expected password 'pass1', got '%s'", password)
}
})
t.Run("raw credentials", func(t *testing.T) {
creds := NewBasicCredentials("user1", "pass1")
raw := creds.RawCredentials()
expected := "user1:pass1"
if raw != expected {
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
}
})
t.Run("empty username", func(t *testing.T) {
creds := NewBasicCredentials("", "pass1")
username, password := creds.BasicAuth()
if username != "" {
t.Fatalf("expected empty username, got '%s'", username)
}
if password != "pass1" {
t.Fatalf("expected password 'pass1', got '%s'", password)
}
})
}
func TestReAuthCredentialsListener(t *testing.T) {
t.Run("successful re-auth", func(t *testing.T) {
var reAuthCalled bool
var onErrCalled bool
var receivedCreds Credentials
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
reAuthCalled = true
receivedCreds = creds
return nil
},
func(err error) {
onErrCalled = true
},
)
creds := NewBasicCredentials("user1", "pass1")
listener.OnNext(creds)
if !reAuthCalled {
t.Fatal("expected reAuth to be called")
}
if onErrCalled {
t.Fatal("expected onErr not to be called")
}
if receivedCreds != creds {
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
}
})
t.Run("re-auth error", func(t *testing.T) {
var reAuthCalled bool
var onErrCalled bool
var receivedErr error
expectedErr := errors.New("re-auth failed")
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
reAuthCalled = true
return expectedErr
},
func(err error) {
onErrCalled = true
receivedErr = err
},
)
creds := NewBasicCredentials("user1", "pass1")
listener.OnNext(creds)
if !reAuthCalled {
t.Fatal("expected reAuth to be called")
}
if !onErrCalled {
t.Fatal("expected onErr to be called")
}
if receivedErr != expectedErr {
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
}
})
t.Run("on error", func(t *testing.T) {
var onErrCalled bool
var receivedErr error
expectedErr := errors.New("provider error")
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
return nil
},
func(err error) {
onErrCalled = true
receivedErr = err
},
)
listener.OnError(expectedErr)
if !onErrCalled {
t.Fatal("expected onErr to be called")
}
if receivedErr != expectedErr {
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
}
})
t.Run("nil callbacks", func(t *testing.T) {
listener := NewReAuthCredentialsListener(nil, nil)
// Should not panic
listener.OnNext(NewBasicCredentials("user1", "pass1"))
listener.OnError(errors.New("test error"))
})
}

View File

@ -0,0 +1,47 @@
package auth
// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
// It is used to re-authenticate the credentials when they are updated.
// It contains:
// - reAuth: a function that takes the new credentials and returns an error if any.
// - onErr: a function that takes an error and handles it.
type ReAuthCredentialsListener struct {
reAuth func(credentials Credentials) error
onErr func(err error)
}
// OnNext is called when the credentials are updated.
// It calls the reAuth function with the new credentials.
// If the reAuth function returns an error, it calls the onErr function with the error.
func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) {
if c.reAuth == nil {
return
}
err := c.reAuth(credentials)
if err != nil {
c.OnError(err)
}
}
// OnError is called when an error occurs.
// It can be called from both the credentials provider and the reAuth function.
func (c *ReAuthCredentialsListener) OnError(err error) {
if c.onErr == nil {
return
}
c.onErr(err)
}
// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener.
// Implements the auth.CredentialsListener interface.
func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener {
return &ReAuthCredentialsListener{
reAuth: reAuth,
onErr: onErr,
}
}
// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)

86
command_recorder_test.go Normal file
View File

@ -0,0 +1,86 @@
package redis_test
import (
"context"
"strings"
"sync"
"github.com/redis/go-redis/v9"
)
// commandRecorder records the last N commands executed by a Redis client.
type commandRecorder struct {
mu sync.Mutex
commands []string
maxSize int
}
// newCommandRecorder creates a new command recorder with the specified maximum size.
func newCommandRecorder(maxSize int) *commandRecorder {
return &commandRecorder{
commands: make([]string, 0, maxSize),
maxSize: maxSize,
}
}
// Record adds a command to the recorder.
func (r *commandRecorder) Record(cmd string) {
cmd = strings.ToLower(cmd)
r.mu.Lock()
defer r.mu.Unlock()
r.commands = append(r.commands, cmd)
if len(r.commands) > r.maxSize {
r.commands = r.commands[1:]
}
}
// LastCommands returns a copy of the recorded commands.
func (r *commandRecorder) LastCommands() []string {
r.mu.Lock()
defer r.mu.Unlock()
return append([]string(nil), r.commands...)
}
// Contains checks if the recorder contains a specific command.
func (r *commandRecorder) Contains(cmd string) bool {
cmd = strings.ToLower(cmd)
r.mu.Lock()
defer r.mu.Unlock()
for _, c := range r.commands {
if strings.Contains(c, cmd) {
return true
}
}
return false
}
// Hook returns a Redis hook that records commands.
func (r *commandRecorder) Hook() redis.Hook {
return &commandHook{recorder: r}
}
// commandHook implements the redis.Hook interface to record commands.
type commandHook struct {
recorder *commandRecorder
}
func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook {
return next
}
func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
h.recorder.Record(cmd.String())
return next(ctx, cmd)
}
}
func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {
for _, cmd := range cmds {
h.recorder.Record(cmd.String())
}
return next(ctx, cmds)
}
}

View File

@ -5,6 +5,7 @@ package example_commands_test
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@ -33,6 +34,7 @@ func ExampleClient_LPush_and_lrange() {
} }
fmt.Println(listSize) fmt.Println(listSize)
time.Sleep(10 * time.Millisecond) // Simulate some delay
value, err := rdb.LRange(ctx, "my_bikes", 0, -1).Result() value, err := rdb.LRange(ctx, "my_bikes", 0, -1).Result()
if err != nil { if err != nil {

View File

@ -23,38 +23,47 @@ func (redisHook) DialHook(hook redis.DialHook) redis.DialHook {
func (redisHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { func (redisHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error { return func(ctx context.Context, cmd redis.Cmder) error {
fmt.Printf("starting processing: <%s>\n", cmd) fmt.Printf("starting processing: <%v>\n", cmd.Args())
err := hook(ctx, cmd) err := hook(ctx, cmd)
fmt.Printf("finished processing: <%s>\n", cmd) fmt.Printf("finished processing: <%v>\n", cmd.Args())
return err return err
} }
} }
func (redisHook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { func (redisHook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error { return func(ctx context.Context, cmds []redis.Cmder) error {
fmt.Printf("pipeline starting processing: %v\n", cmds) names := make([]string, 0, len(cmds))
for _, cmd := range cmds {
names = append(names, fmt.Sprintf("%v", cmd.Args()))
}
fmt.Printf("pipeline starting processing: %v\n", names)
err := hook(ctx, cmds) err := hook(ctx, cmds)
fmt.Printf("pipeline finished processing: %v\n", cmds) fmt.Printf("pipeline finished processing: %v\n", names)
return err return err
} }
} }
func Example_instrumentation() { func Example_instrumentation() {
rdb := redis.NewClient(&redis.Options{ rdb := redis.NewClient(&redis.Options{
Addr: ":6379", Addr: ":6379",
DisableIdentity: true,
}) })
rdb.AddHook(redisHook{}) rdb.AddHook(redisHook{})
rdb.Ping(ctx) rdb.Ping(ctx)
// Output: starting processing: <ping: > // Output:
// starting processing: <[ping]>
// dialing tcp :6379 // dialing tcp :6379
// finished dialing tcp :6379 // finished dialing tcp :6379
// finished processing: <ping: PONG> // starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// finished processing: <[ping]>
} }
func ExamplePipeline_instrumentation() { func ExamplePipeline_instrumentation() {
rdb := redis.NewClient(&redis.Options{ rdb := redis.NewClient(&redis.Options{
Addr: ":6379", Addr: ":6379",
DisableIdentity: true,
}) })
rdb.AddHook(redisHook{}) rdb.AddHook(redisHook{})
@ -63,15 +72,19 @@ func ExamplePipeline_instrumentation() {
pipe.Ping(ctx) pipe.Ping(ctx)
return nil return nil
}) })
// Output: pipeline starting processing: [ping: ping: ] // Output:
// pipeline starting processing: [[ping] [ping]]
// dialing tcp :6379 // dialing tcp :6379
// finished dialing tcp :6379 // finished dialing tcp :6379
// pipeline finished processing: [ping: PONG ping: PONG] // starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// pipeline finished processing: [[ping] [ping]]
} }
func ExampleClient_Watch_instrumentation() { func ExampleClient_Watch_instrumentation() {
rdb := redis.NewClient(&redis.Options{ rdb := redis.NewClient(&redis.Options{
Addr: ":6379", Addr: ":6379",
DisableIdentity: true,
}) })
rdb.AddHook(redisHook{}) rdb.AddHook(redisHook{})
@ -81,14 +94,16 @@ func ExampleClient_Watch_instrumentation() {
return nil return nil
}, "foo") }, "foo")
// Output: // Output:
// starting processing: <watch foo: > // starting processing: <[watch foo]>
// dialing tcp :6379 // dialing tcp :6379
// finished dialing tcp :6379 // finished dialing tcp :6379
// finished processing: <watch foo: OK> // starting processing: <[hello 3]>
// starting processing: <ping: > // finished processing: <[hello 3]>
// finished processing: <ping: PONG> // finished processing: <[watch foo]>
// starting processing: <ping: > // starting processing: <[ping]>
// finished processing: <ping: PONG> // finished processing: <[ping]>
// starting processing: <unwatch: > // starting processing: <[ping]>
// finished processing: <unwatch: OK> // finished processing: <[ping]>
// starting processing: <[unwatch]>
// finished processing: <[unwatch]>
} }

View File

@ -23,6 +23,8 @@ type Conn struct {
Inited bool Inited bool
pooled bool pooled bool
createdAt time.Time createdAt time.Time
onClose func() error
} }
func NewConn(netConn net.Conn) *Conn { func NewConn(netConn net.Conn) *Conn {
@ -46,6 +48,10 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix()) atomic.StoreInt64(&cn.usedAt, tm.Unix())
} }
func (cn *Conn) SetOnClose(fn func() error) {
cn.onClose = fn
}
func (cn *Conn) SetNetConn(netConn net.Conn) { func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn cn.netConn = netConn
cn.rd.Reset(netConn) cn.rd.Reset(netConn)
@ -95,6 +101,10 @@ func (cn *Conn) WithWriter(
} }
func (cn *Conn) Close() error { func (cn *Conn) Close() error {
if cn.onClose != nil {
// ignore error
_ = cn.onClose()
}
return cn.netConn.Close() return cn.netConn.Close()
} }

View File

@ -212,10 +212,10 @@ func TestRingShardsCleanup(t *testing.T) {
}, },
NewClient: func(opt *Options) *Client { NewClient: func(opt *Options) *Client {
c := NewClient(opt) c := NewClient(opt)
c.baseClient.onClose = func() error { c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
closeCounter.increment(opt.Addr) closeCounter.increment(opt.Addr)
return nil return nil
} })
return c return c
}, },
}) })
@ -261,10 +261,10 @@ func TestRingShardsCleanup(t *testing.T) {
} }
createCounter.increment(opt.Addr) createCounter.increment(opt.Addr)
c := NewClient(opt) c := NewClient(opt)
c.baseClient.onClose = func() error { c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
closeCounter.increment(opt.Addr) closeCounter.increment(opt.Addr)
return nil return nil
} })
return c return c
}, },
}) })

View File

@ -13,6 +13,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/pool"
) )
@ -29,10 +30,13 @@ type Limiter interface {
// Options keeps the settings to set up redis connection. // Options keeps the settings to set up redis connection.
type Options struct { type Options struct {
// The network type, either tcp or unix.
// Default is tcp. // Network type, either tcp or unix.
//
// default: is tcp.
Network string Network string
// host:port address.
// Addr is the address formated as host:port
Addr string Addr string
// ClientName will execute the `CLIENT SETNAME ClientName` command for each conn. // ClientName will execute the `CLIENT SETNAME ClientName` command for each conn.
@ -46,17 +50,21 @@ type Options struct {
OnConnect func(ctx context.Context, cn *Conn) error OnConnect func(ctx context.Context, cn *Conn) error
// Protocol 2 or 3. Use the version to negotiate RESP version with redis-server. // Protocol 2 or 3. Use the version to negotiate RESP version with redis-server.
// Default is 3. //
// default: 3.
Protocol int Protocol int
// Use the specified Username to authenticate the current connection
// Username is used to authenticate the current connection
// with one of the connections defined in the ACL list when connecting // with one of the connections defined in the ACL list when connecting
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system. // to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
Username string Username string
// Optional password. Must match the password specified in the
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower), // Password is an optional password. Must match the password specified in the
// `requirepass` server configuration option (if connecting to a Redis 5.0 instance, or lower),
// or the User Password when connecting to a Redis 6.0 instance, or greater, // or the User Password when connecting to a Redis 6.0 instance, or greater,
// that is using the Redis ACL system. // that is using the Redis ACL system.
Password string Password string
// CredentialsProvider allows the username and password to be updated // CredentialsProvider allows the username and password to be updated
// before reconnecting. It should return the current username and password. // before reconnecting. It should return the current username and password.
CredentialsProvider func() (username string, password string) CredentialsProvider func() (username string, password string)
@ -67,85 +75,126 @@ type Options struct {
// There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider. // There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider.
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
// Database to be selected after connecting to the server. // StreamingCredentialsProvider is used to retrieve the credentials
// for the connection from an external source. Those credentials may change
// during the connection lifetime. This is useful for managed identity
// scenarios where the credentials are retrieved from an external source.
//
// Currently, this is a placeholder for the future implementation.
StreamingCredentialsProvider auth.StreamingCredentialsProvider
// DB is the database to be selected after connecting to the server.
DB int DB int
// Maximum number of retries before giving up. // MaxRetries is the maximum number of retries before giving up.
// Default is 3 retries; -1 (not 0) disables retries. // -1 (not 0) disables retries.
//
// default: 3 retries
MaxRetries int MaxRetries int
// Minimum backoff between each retry.
// Default is 8 milliseconds; -1 disables backoff. // MinRetryBackoff is the minimum backoff between each retry.
// -1 disables backoff.
//
// default: 8 milliseconds
MinRetryBackoff time.Duration MinRetryBackoff time.Duration
// Maximum backoff between each retry.
// Default is 512 milliseconds; -1 disables backoff. // MaxRetryBackoff is the maximum backoff between each retry.
// -1 disables backoff.
// default: 512 milliseconds;
MaxRetryBackoff time.Duration MaxRetryBackoff time.Duration
// Dial timeout for establishing new connections. // DialTimeout for establishing new connections.
// Default is 5 seconds. //
// default: 5 seconds
DialTimeout time.Duration DialTimeout time.Duration
// Timeout for socket reads. If reached, commands will fail
// ReadTimeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking. Supported values: // with a timeout instead of blocking. Supported values:
// - `0` - default timeout (3 seconds). //
// - `-1` - no timeout (block indefinitely). // - `-1` - no timeout (block indefinitely).
// - `-2` - disables SetReadDeadline calls completely. // - `-2` - disables SetReadDeadline calls completely.
//
// default: 3 seconds
ReadTimeout time.Duration ReadTimeout time.Duration
// Timeout for socket writes. If reached, commands will fail
// WriteTimeout for socket writes. If reached, commands will fail
// with a timeout instead of blocking. Supported values: // with a timeout instead of blocking. Supported values:
// - `0` - default timeout (3 seconds). //
// - `-1` - no timeout (block indefinitely). // - `-1` - no timeout (block indefinitely).
// - `-2` - disables SetWriteDeadline calls completely. // - `-2` - disables SetWriteDeadline calls completely.
//
// default: 3 seconds
WriteTimeout time.Duration WriteTimeout time.Duration
// ContextTimeoutEnabled controls whether the client respects context timeouts and deadlines. // ContextTimeoutEnabled controls whether the client respects context timeouts and deadlines.
// See https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts // See https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts
ContextTimeoutEnabled bool ContextTimeoutEnabled bool
// Type of connection pool. // PoolFIFO type of connection pool.
// true for FIFO pool, false for LIFO pool. //
// - true for FIFO pool
// - false for LIFO pool.
//
// Note that FIFO has slightly higher overhead compared to LIFO, // Note that FIFO has slightly higher overhead compared to LIFO,
// but it helps closing idle connections faster reducing the pool size. // but it helps closing idle connections faster reducing the pool size.
PoolFIFO bool PoolFIFO bool
// Base number of socket connections.
// PoolSize is the base number of socket connections.
// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS. // Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS.
// If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize, // If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize,
// you can limit it through MaxActiveConns // you can limit it through MaxActiveConns
//
// default: 10 * runtime.GOMAXPROCS(0)
PoolSize int PoolSize int
// Amount of time client waits for connection if all connections
// PoolTimeout is the amount of time client waits for connection if all connections
// are busy before returning an error. // are busy before returning an error.
// Default is ReadTimeout + 1 second. //
// default: ReadTimeout + 1 second
PoolTimeout time.Duration PoolTimeout time.Duration
// Minimum number of idle connections which is useful when establishing
// new connection is slow. // MinIdleConns is the minimum number of idle connections which is useful when establishing
// Default is 0. the idle connections are not closed by default. // new connection is slow. The idle connections are not closed by default.
//
// default: 0
MinIdleConns int MinIdleConns int
// Maximum number of idle connections.
// Default is 0. the idle connections are not closed by default. // MaxIdleConns is the maximum number of idle connections.
// The idle connections are not closed by default.
//
// default: 0
MaxIdleConns int MaxIdleConns int
// Maximum number of connections allocated by the pool at a given time.
// MaxActiveConns is the maximum number of connections allocated by the pool at a given time.
// When zero, there is no limit on the number of connections in the pool. // When zero, there is no limit on the number of connections in the pool.
// If the pool is full, the next call to Get() will block until a connection is released.
MaxActiveConns int MaxActiveConns int
// ConnMaxIdleTime is the maximum amount of time a connection may be idle. // ConnMaxIdleTime is the maximum amount of time a connection may be idle.
// Should be less than server's timeout. // Should be less than server's timeout.
// //
// Expired connections may be closed lazily before reuse. // Expired connections may be closed lazily before reuse.
// If d <= 0, connections are not closed due to a connection's idle time. // If d <= 0, connections are not closed due to a connection's idle time.
// -1 disables idle timeout check.
// //
// Default is 30 minutes. -1 disables idle timeout check. // default: 30 minutes
ConnMaxIdleTime time.Duration ConnMaxIdleTime time.Duration
// ConnMaxLifetime is the maximum amount of time a connection may be reused. // ConnMaxLifetime is the maximum amount of time a connection may be reused.
// //
// Expired connections may be closed lazily before reuse. // Expired connections may be closed lazily before reuse.
// If <= 0, connections are not closed due to a connection's age. // If <= 0, connections are not closed due to a connection's age.
// //
// Default is to not close idle connections. // default: 0
ConnMaxLifetime time.Duration ConnMaxLifetime time.Duration
// TLS Config to use. When set, TLS will be negotiated. // TLSConfig to use. When set, TLS will be negotiated.
TLSConfig *tls.Config TLSConfig *tls.Config
// Limiter interface used to implement circuit breaker or rate limiter. // Limiter interface used to implement circuit breaker or rate limiter.
Limiter Limiter Limiter Limiter
// Enables read only queries on slave/follower nodes. // readOnly enables read only queries on slave/follower nodes.
readOnly bool readOnly bool
// DisableIndentity - Disable set-lib on connect. // DisableIndentity - Disable set-lib on connect.
@ -161,9 +210,11 @@ type Options struct {
DisableIdentity bool DisableIdentity bool
// Add suffix to client name. Default is empty. // Add suffix to client name. Default is empty.
// IdentitySuffix - add suffix to client name.
IdentitySuffix string IdentitySuffix string
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3. // UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
// When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult
UnstableResp3 bool UnstableResp3 bool
} }

View File

@ -14,6 +14,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hashtag" "github.com/redis/go-redis/v9/internal/hashtag"
"github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/pool"
@ -66,11 +67,12 @@ type ClusterOptions struct {
OnConnect func(ctx context.Context, cn *Conn) error OnConnect func(ctx context.Context, cn *Conn) error
Protocol int Protocol int
Username string Username string
Password string Password string
CredentialsProvider func() (username string, password string) CredentialsProvider func() (username string, password string)
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
StreamingCredentialsProvider auth.StreamingCredentialsProvider
MaxRetries int MaxRetries int
MinRetryBackoff time.Duration MinRetryBackoff time.Duration
@ -292,11 +294,12 @@ func (opt *ClusterOptions) clientOptions() *Options {
Dialer: opt.Dialer, Dialer: opt.Dialer,
OnConnect: opt.OnConnect, OnConnect: opt.OnConnect,
Protocol: opt.Protocol, Protocol: opt.Protocol,
Username: opt.Username, Username: opt.Username,
Password: opt.Password, Password: opt.Password,
CredentialsProvider: opt.CredentialsProvider, CredentialsProvider: opt.CredentialsProvider,
CredentialsProviderContext: opt.CredentialsProviderContext, CredentialsProviderContext: opt.CredentialsProviderContext,
StreamingCredentialsProvider: opt.StreamingCredentialsProvider,
MaxRetries: opt.MaxRetries, MaxRetries: opt.MaxRetries,
MinRetryBackoff: opt.MinRetryBackoff, MinRetryBackoff: opt.MinRetryBackoff,

View File

@ -89,6 +89,9 @@ func (s *clusterScenario) newClusterClient(
func (s *clusterScenario) Close() error { func (s *clusterScenario) Close() error {
ctx := context.TODO() ctx := context.TODO()
for _, master := range s.masters() { for _, master := range s.masters() {
if master == nil {
continue
}
err := master.FlushAll(ctx).Err() err := master.FlushAll(ctx).Err()
if err != nil { if err != nil {
return err return err

View File

@ -298,7 +298,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
}) })
It("should CFCount", Label("cuckoo", "cfcount"), func() { It("should CFCount", Label("cuckoo", "cfcount"), func() {
err := client.CFAdd(ctx, "testcf1", "item1").Err() client.CFAdd(ctx, "testcf1", "item1")
cnt, err := client.CFCount(ctx, "testcf1", "item1").Result() cnt, err := client.CFCount(ctx, "testcf1", "item1").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cnt).To(BeEquivalentTo(int64(1))) Expect(cnt).To(BeEquivalentTo(int64(1)))
@ -394,7 +394,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
NoCreate: true, NoCreate: true,
} }
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result() _, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
args = &redis.CFInsertOptions{ args = &redis.CFInsertOptions{
@ -402,7 +402,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
NoCreate: false, NoCreate: false,
} }
result, err = client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result() result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(len(result)).To(BeEquivalentTo(3)) Expect(len(result)).To(BeEquivalentTo(3))
}) })

152
redis.go
View File

@ -9,6 +9,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/pool"
@ -203,6 +204,7 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
type baseClient struct { type baseClient struct {
opt *Options opt *Options
connPool pool.Pooler connPool pool.Pooler
hooksMixin
onClose func() error // hook called when client is closed onClose func() error // hook called when client is closed
} }
@ -282,30 +284,107 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil return cn, nil
} }
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
return auth.NewReAuthCredentialsListener(
c.reAuthConnection(poolCn),
c.onAuthenticationErr(poolCn),
)
}
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
return func(credentials auth.Credentials) error {
var err error
username, password := credentials.BasicAuth()
ctx := context.Background()
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
// hooksMixin are intentionally empty here
cn := newConn(c.opt, connPool, nil)
if username != "" {
err = cn.AuthACL(ctx, username, password).Err()
} else {
err = cn.Auth(ctx, password).Err()
}
return err
}
}
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
return func(err error) {
if err != nil {
if isBadConn(err, false, c.opt.Addr) {
// Close the connection to force a reconnection.
err := c.connPool.CloseConn(poolCn)
if err != nil {
internal.Logger.Printf(context.Background(), "redis: failed to close connection: %v", err)
// try to close the network connection directly
// so that no resource is leaked
err := poolCn.Close()
if err != nil {
internal.Logger.Printf(context.Background(), "redis: failed to close network connection: %v", err)
}
}
}
internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err)
}
}
}
func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
onClose := c.onClose
return func() error {
var firstErr error
err := newOnClose()
// Even if we have an error we would like to execute the onClose hook
// if it exists. We will return the first error that occurred.
// This is to keep error handling consistent with the rest of the code.
if err != nil {
firstErr = err
}
if onClose != nil {
err = onClose()
if err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
}
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited { if cn.Inited {
return nil return nil
} }
cn.Inited = true
var err error var err error
username, password := c.opt.Username, c.opt.Password cn.Inited = true
if c.opt.CredentialsProviderContext != nil { connPool := pool.NewSingleConnPool(c.connPool, cn)
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil { conn := newConn(c.opt, connPool, &c.hooksMixin)
return err
username, password := "", ""
if c.opt.StreamingCredentialsProvider != nil {
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
Subscribe(c.newReAuthCredentialsListener(cn))
if err != nil {
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
}
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
cn.SetOnClose(unsubscribeFromCredentialsProvider)
username, password = credentials.BasicAuth()
} else if c.opt.CredentialsProviderContext != nil {
username, password, err = c.opt.CredentialsProviderContext(ctx)
if err != nil {
return fmt.Errorf("failed to get credentials from context provider: %w", err)
} }
} else if c.opt.CredentialsProvider != nil { } else if c.opt.CredentialsProvider != nil {
username, password = c.opt.CredentialsProvider() username, password = c.opt.CredentialsProvider()
} else if c.opt.Username != "" || c.opt.Password != "" {
username, password = c.opt.Username, c.opt.Password
} }
connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool)
var auth bool
// for redis-server versions that do not support the HELLO command, // for redis-server versions that do not support the HELLO command,
// RESP2 will continue to be used. // RESP2 will continue to be used.
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil { if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
auth = true // Authentication successful with HELLO command
} else if !isRedisError(err) { } else if !isRedisError(err) {
// When the server responds with the RESP protocol and the result is not a normal // When the server responds with the RESP protocol and the result is not a normal
// execution result of the HELLO command, we consider it to be an indication that // execution result of the HELLO command, we consider it to be an indication that
@ -315,17 +394,19 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
// with different error string results for unsupported commands, making it // with different error string results for unsupported commands, making it
// difficult to rely on error strings to determine all results. // difficult to rely on error strings to determine all results.
return err return err
} else if password != "" {
// Try legacy AUTH command if HELLO failed
if username != "" {
err = conn.AuthACL(ctx, username, password).Err()
} else {
err = conn.Auth(ctx, password).Err()
}
if err != nil {
return fmt.Errorf("failed to authenticate: %w", err)
}
} }
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error { _, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
if !auth && password != "" {
if username != "" {
pipe.AuthACL(ctx, username, password)
} else {
pipe.Auth(ctx, password)
}
}
if c.opt.DB > 0 { if c.opt.DB > 0 {
pipe.Select(ctx, c.opt.DB) pipe.Select(ctx, c.opt.DB)
} }
@ -341,7 +422,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return nil return nil
}) })
if err != nil { if err != nil {
return err return fmt.Errorf("failed to initialize connection options: %w", err)
} }
if !c.opt.DisableIdentity && !c.opt.DisableIndentity { if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
@ -363,6 +444,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if c.opt.OnConnect != nil { if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn) return c.opt.OnConnect(ctx, conn)
} }
return nil return nil
} }
@ -481,6 +563,16 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
return c.opt.ReadTimeout return c.opt.ReadTimeout
} }
// context returns the context for the current connection.
// If the context timeout is enabled, it returns the original context.
// Otherwise, it returns a new background context.
func (c *baseClient) context(ctx context.Context) context.Context {
if c.opt.ContextTimeoutEnabled {
return ctx
}
return context.Background()
}
// Close closes the client, releasing any open resources. // Close closes the client, releasing any open resources.
// //
// It is rare to Close a Client, as the Client is meant to be // It is rare to Close a Client, as the Client is meant to be
@ -633,13 +725,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
return nil return nil
} }
func (c *baseClient) context(ctx context.Context) context.Context {
if c.opt.ContextTimeoutEnabled {
return ctx
}
return context.Background()
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
// Client is a Redis client representing a pool of zero or more underlying connections. // Client is a Redis client representing a pool of zero or more underlying connections.
@ -650,7 +735,6 @@ func (c *baseClient) context(ctx context.Context) context.Context {
type Client struct { type Client struct {
*baseClient *baseClient
cmdable cmdable
hooksMixin
} }
// NewClient returns a client to the Redis Server specified by Options. // NewClient returns a client to the Redis Server specified by Options.
@ -689,7 +773,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
} }
func (c *Client) Conn() *Conn { func (c *Client) Conn() *Conn {
return newConn(c.opt, pool.NewStickyConnPool(c.connPool)) return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
} }
// Do create a Cmd from the args and processes the cmd. // Do create a Cmd from the args and processes the cmd.
@ -822,10 +906,12 @@ type Conn struct {
baseClient baseClient
cmdable cmdable
statefulCmdable statefulCmdable
hooksMixin
} }
func newConn(opt *Options, connPool pool.Pooler) *Conn { // newConn is a helper func to create a new Conn instance.
// the Conn instance is not thread-safe and should not be shared between goroutines.
// the parentHooks will be cloned, no need to clone before passing it.
func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn {
c := Conn{ c := Conn{
baseClient: baseClient{ baseClient: baseClient{
opt: opt, opt: opt,
@ -833,6 +919,10 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
}, },
} }
if parentHooks != nil {
c.hooksMixin = parentHooks.clone()
}
c.cmdable = c.Process c.cmdable = c.Process
c.statefulCmdable = c.Process c.statefulCmdable = c.Process
c.initHooks(hooks{ c.initHooks(hooks{

View File

@ -14,6 +14,7 @@ import (
. "github.com/bsm/gomega" . "github.com/bsm/gomega"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/auth"
) )
type redisHookError struct{} type redisHookError struct{}
@ -727,6 +728,174 @@ var _ = Describe("Dialer connection timeouts", func() {
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay)) Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
}) })
}) })
var _ = Describe("Credentials Provider Priority", func() {
var client *redis.Client
var opt *redis.Options
var recorder *commandRecorder
BeforeEach(func() {
recorder = newCommandRecorder(10)
})
AfterEach(func() {
if client != nil {
Expect(client.Close()).NotTo(HaveOccurred())
}
})
It("should use streaming provider when available", func() {
streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
username, password := ctxCreds.BasicAuth()
return username, password, nil
},
StreamingCredentialsProvider: &mockStreamingProvider{
credentials: streamingCreds,
updates: make(chan auth.Credentials, 1),
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
})
It("should use context provider when streaming provider is not available", func() {
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
username, password := ctxCreds.BasicAuth()
return username, password, nil
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
})
It("should use regular provider when streaming and context providers are not available", func() {
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
})
It("should use username/password fields when no providers are set", func() {
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
})
It("should use empty credentials when nothing is set", func() {
opt = &redis.Options{}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// no pass, ok
Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
Expect(recorder.Contains("AUTH")).To(BeFalse())
})
It("should handle credential updates from streaming provider", func() {
initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
updatesChan := make(chan auth.Credentials, 1)
opt = &redis.Options{
StreamingCredentialsProvider: &mockStreamingProvider{
credentials: initialCreds,
updates: updatesChan,
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
// Update credentials
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
close(updatesChan)
})
})
type mockStreamingProvider struct {
credentials auth.Credentials
err error
updates chan auth.Credentials
}
func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
if m.err != nil {
return nil, nil, m.err
}
// Start goroutine to handle updates
go func() {
for creds := range m.updates {
m.credentials = creds
listener.OnNext(creds)
}
}()
return m.credentials, func() (err error) {
defer func() {
if r := recover(); r != nil {
// this is just a mock:
// allow multiple closes from multiple listeners
}
}()
return
}, nil
}
var _ = Describe("Client creation", func() { var _ = Describe("Client creation", func() {
Context("simple client with nil options", func() { Context("simple client with nil options", func() {
It("panics", func() { It("panics", func() {

View File

@ -357,13 +357,17 @@ var _ = Describe("Redis Ring", func() {
ring.AddHook(&hook{ ring.AddHook(&hook{
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error { return func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1)) // skip the connection initialization
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
return nil
}
Expect(len(cmds)).To(BeNumerically(">", 0))
Expect(cmds[0].String()).To(Equal("ping: ")) Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "ring.BeforeProcessPipeline") stack = append(stack, "ring.BeforeProcessPipeline")
err := hook(ctx, cmds) err := hook(ctx, cmds)
Expect(cmds).To(HaveLen(1)) Expect(len(cmds)).To(BeNumerically(">", 0))
Expect(cmds[0].String()).To(Equal("ping: PONG")) Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "ring.AfterProcessPipeline") stack = append(stack, "ring.AfterProcessPipeline")
@ -376,13 +380,17 @@ var _ = Describe("Redis Ring", func() {
shard.AddHook(&hook{ shard.AddHook(&hook{
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error { return func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1)) // skip the connection initialization
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
return nil
}
Expect(len(cmds)).To(BeNumerically(">", 0))
Expect(cmds[0].String()).To(Equal("ping: ")) Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcessPipeline") stack = append(stack, "shard.BeforeProcessPipeline")
err := hook(ctx, cmds) err := hook(ctx, cmds)
Expect(cmds).To(HaveLen(1)) Expect(len(cmds)).To(BeNumerically(">", 0))
Expect(cmds[0].String()).To(Equal("ping: PONG")) Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcessPipeline") stack = append(stack, "shard.AfterProcessPipeline")
@ -416,14 +424,18 @@ var _ = Describe("Redis Ring", func() {
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error { return func(ctx context.Context, cmds []redis.Cmder) error {
defer GinkgoRecover() defer GinkgoRecover()
// skip the connection initialization
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
return nil
}
Expect(cmds).To(HaveLen(3)) Expect(len(cmds)).To(BeNumerically(">=", 3))
Expect(cmds[1].String()).To(Equal("ping: ")) Expect(cmds[1].String()).To(Equal("ping: "))
stack = append(stack, "ring.BeforeProcessPipeline") stack = append(stack, "ring.BeforeProcessPipeline")
err := hook(ctx, cmds) err := hook(ctx, cmds)
Expect(cmds).To(HaveLen(3)) Expect(len(cmds)).To(BeNumerically(">=", 3))
Expect(cmds[1].String()).To(Equal("ping: PONG")) Expect(cmds[1].String()).To(Equal("ping: PONG"))
stack = append(stack, "ring.AfterProcessPipeline") stack = append(stack, "ring.AfterProcessPipeline")
@ -437,14 +449,18 @@ var _ = Describe("Redis Ring", func() {
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error { return func(ctx context.Context, cmds []redis.Cmder) error {
defer GinkgoRecover() defer GinkgoRecover()
// skip the connection initialization
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
return nil
}
Expect(cmds).To(HaveLen(3)) Expect(len(cmds)).To(BeNumerically(">=", 3))
Expect(cmds[1].String()).To(Equal("ping: ")) Expect(cmds[1].String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcessPipeline") stack = append(stack, "shard.BeforeProcessPipeline")
err := hook(ctx, cmds) err := hook(ctx, cmds)
Expect(cmds).To(HaveLen(3)) Expect(len(cmds)).To(BeNumerically(">=", 3))
Expect(cmds[1].String()).To(Equal("ping: PONG")) Expect(cmds[1].String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcessPipeline") stack = append(stack, "shard.AfterProcessPipeline")

View File

@ -404,7 +404,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
connPool = newConnPool(opt, rdb.dialHook) connPool = newConnPool(opt, rdb.dialHook)
rdb.connPool = connPool rdb.connPool = connPool
rdb.onClose = failover.Close rdb.onClose = rdb.wrappedOnClose(failover.Close)
failover.mu.Lock() failover.mu.Lock()
failover.onFailover = func(ctx context.Context, addr string) { failover.onFailover = func(ctx context.Context, addr string) {
@ -455,7 +455,6 @@ func masterReplicaDialer(
// SentinelClient is a client for a Redis Sentinel. // SentinelClient is a client for a Redis Sentinel.
type SentinelClient struct { type SentinelClient struct {
*baseClient *baseClient
hooksMixin
} }
func NewSentinelClient(opt *Options) *SentinelClient { func NewSentinelClient(opt *Options) *SentinelClient {

7
tx.go
View File

@ -19,16 +19,15 @@ type Tx struct {
baseClient baseClient
cmdable cmdable
statefulCmdable statefulCmdable
hooksMixin
} }
func (c *Client) newTx() *Tx { func (c *Client) newTx() *Tx {
tx := Tx{ tx := Tx{
baseClient: baseClient{ baseClient: baseClient{
opt: c.opt, opt: c.opt,
connPool: pool.NewStickyConnPool(c.connPool), connPool: pool.NewStickyConnPool(c.connPool),
hooksMixin: c.hooksMixin.clone(),
}, },
hooksMixin: c.hooksMixin.clone(),
} }
tx.init() tx.init()
return &tx return &tx