1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-28 06:42:00 +03:00

feat: Introducing StreamingCredentialsProvider for token based authentication (#3320)

* wip

* update documentation

* add streamingcredentialsprovider in options

* fix: put back option in pool creation

* add package level comment

* Initial re authentication implementation

Introduces the StreamingCredentialsProvider as the CredentialsProvider
with the highest priority.

TODO: needs to be tested

* Change function type name

Change CancelProviderFunc to UnsubscribeFunc

* add tests

* fix race in tests

* fix example tests

* wip, hooks refactor

* fix build

* update README.md

* update wordlist

* update README.md

* refactor(auth): early returns in cred listener

* fix(doctest): simulate some delay

* feat(conn): add close hook on conn

* fix(tests): simulate start/stop in mock credentials provider

* fix(auth): don't double close the conn

* docs(README): mark streaming credentials provider as experimental

* fix(auth): streamline auth err proccess

* fix(auth): check err on close conn

* chore(entraid): use the repo under redis org
This commit is contained in:
Nedyalko Dyakov
2025-05-27 16:25:20 +03:00
committed by GitHub
parent 28a3c97409
commit 86d418f940
20 changed files with 1103 additions and 130 deletions

View File

@ -66,3 +66,11 @@ RedisTimeseries
RediSearch
RawResult
RawVal
entra
EntraID
Entra
OAuth
Azure
StreamingCredentialsProvider
oauth
entraid

1
.gitignore vendored
View File

@ -8,3 +8,4 @@ redis8tests.sh
coverage.txt
**/coverage.txt
.vscode
tmp/*

111
README.md
View File

@ -68,6 +68,7 @@ key value NoSQL database that uses RocksDB as storage engine and is compatible w
- Redis commands except QUIT and SYNC.
- Automatic connection pooling.
- [StreamingCredentialsProvider (e.g. entra id, oauth)](#1-streaming-credentials-provider-highest-priority) (experimental)
- [Pub/Sub](https://redis.uptrace.dev/guide/go-redis-pubsub.html).
- [Pipelines and transactions](https://redis.uptrace.dev/guide/go-redis-pipelines.html).
- [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html).
@ -136,8 +137,113 @@ func ExampleClient() {
}
```
The above can be modified to specify the version of the RESP protocol by adding the `protocol`
option to the `Options` struct:
### Authentication
The Redis client supports multiple ways to provide authentication credentials, with a clear priority order. Here are the available options:
#### 1. Streaming Credentials Provider (Highest Priority) - Experimental feature
The streaming credentials provider allows for dynamic credential updates during the connection lifetime. This is particularly useful for managed identity services and token-based authentication.
```go
type StreamingCredentialsProvider interface {
Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error)
}
type CredentialsListener interface {
OnNext(credentials Credentials) // Called when credentials are updated
OnError(err error) // Called when an error occurs
}
type Credentials interface {
BasicAuth() (username string, password string)
RawCredentials() string
}
```
Example usage:
```go
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
StreamingCredentialsProvider: &MyCredentialsProvider{},
})
```
**Note:** The streaming credentials provider can be used with [go-redis-entraid](https://github.com/redis/go-redis-entraid) to enable Entra ID (formerly Azure AD) authentication. This allows for seamless integration with Azure's managed identity services and token-based authentication.
Example with Entra ID:
```go
import (
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis-entraid"
)
// Create an Entra ID credentials provider
provider := entraid.NewDefaultAzureIdentityProvider()
// Configure Redis client with Entra ID authentication
rdb := redis.NewClient(&redis.Options{
Addr: "your-redis-server.redis.cache.windows.net:6380",
StreamingCredentialsProvider: provider,
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
})
```
#### 2. Context-based Credentials Provider
The context-based provider allows credentials to be determined at the time of each operation, using the context.
```go
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
// Return username, password, and any error
return "user", "pass", nil
},
})
```
#### 3. Regular Credentials Provider
A simple function-based provider that returns static credentials.
```go
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
CredentialsProvider: func() (string, string) {
// Return username and password
return "user", "pass"
},
})
```
#### 4. Username/Password Fields (Lowest Priority)
The most basic way to provide credentials is through the `Username` and `Password` fields in the options.
```go
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Username: "user",
Password: "pass",
})
```
#### Priority Order
The client will use credentials in the following priority order:
1. Streaming Credentials Provider (if set)
2. Context-based Credentials Provider (if set)
3. Regular Credentials Provider (if set)
4. Username/Password fields (if set)
If none of these are set, the client will attempt to connect without authentication.
### Protocol Version
The client supports both RESP2 and RESP3 protocols. You can specify the protocol version in the options:
```go
rdb := redis.NewClient(&redis.Options{
@ -146,7 +252,6 @@ option to the `Options` struct:
DB: 0, // use default DB
Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3
})
```
### Connecting via a redis url

61
auth/auth.go Normal file
View File

@ -0,0 +1,61 @@
// Package auth package provides authentication-related interfaces and types.
// It also includes a basic implementation of credentials using username and password.
package auth
// StreamingCredentialsProvider is an interface that defines the methods for a streaming credentials provider.
// It is used to provide credentials for authentication.
// The CredentialsListener is used to receive updates when the credentials change.
type StreamingCredentialsProvider interface {
// Subscribe subscribes to the credentials provider for updates.
// It returns the current credentials, a cancel function to unsubscribe from the provider,
// and an error if any.
// TODO(ndyakov): Should we add context to the Subscribe method?
Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error)
}
// UnsubscribeFunc is a function that is used to cancel the subscription to the credentials provider.
// It is used to unsubscribe from the provider when the credentials are no longer needed.
type UnsubscribeFunc func() error
// CredentialsListener is an interface that defines the methods for a credentials listener.
// It is used to receive updates when the credentials change.
// The OnNext method is called when the credentials change.
// The OnError method is called when an error occurs while requesting the credentials.
type CredentialsListener interface {
OnNext(credentials Credentials)
OnError(err error)
}
// Credentials is an interface that defines the methods for credentials.
// It is used to provide the credentials for authentication.
type Credentials interface {
// BasicAuth returns the username and password for basic authentication.
BasicAuth() (username string, password string)
// RawCredentials returns the raw credentials as a string.
// This can be used to extract the username and password from the raw credentials or
// additional information if present in the token.
RawCredentials() string
}
type basicAuth struct {
username string
password string
}
// RawCredentials returns the raw credentials as a string.
func (b *basicAuth) RawCredentials() string {
return b.username + ":" + b.password
}
// BasicAuth returns the username and password for basic authentication.
func (b *basicAuth) BasicAuth() (username string, password string) {
return b.username, b.password
}
// NewBasicCredentials creates a new Credentials object from the given username and password.
func NewBasicCredentials(username, password string) Credentials {
return &basicAuth{
username: username,
password: password,
}
}

308
auth/auth_test.go Normal file
View File

@ -0,0 +1,308 @@
package auth
import (
"errors"
"sync"
"testing"
"time"
)
type mockStreamingProvider struct {
credentials Credentials
err error
updates chan Credentials
}
func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
return &mockStreamingProvider{
credentials: initialCreds,
updates: make(chan Credentials, 10),
}
}
func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
if m.err != nil {
return nil, nil, m.err
}
// Send initial credentials
listener.OnNext(m.credentials)
// Start goroutine to handle updates
go func() {
for creds := range m.updates {
listener.OnNext(creds)
}
}()
return m.credentials, func() error {
close(m.updates)
return nil
}, nil
}
func TestStreamingCredentialsProvider(t *testing.T) {
t.Run("successful subscription", func(t *testing.T) {
initialCreds := NewBasicCredentials("user1", "pass1")
provider := newMockStreamingProvider(initialCreds)
var receivedCreds []Credentials
var receivedErrors []error
var mu sync.Mutex
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
mu.Lock()
receivedCreds = append(receivedCreds, creds)
mu.Unlock()
return nil
},
func(err error) {
receivedErrors = append(receivedErrors, err)
},
)
creds, cancel, err := provider.Subscribe(listener)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cancel == nil {
t.Fatal("expected cancel function to be non-nil")
}
if creds != initialCreds {
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
}
if len(receivedCreds) != 1 {
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
}
if receivedCreds[0] != initialCreds {
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
}
if len(receivedErrors) != 0 {
t.Fatalf("expected no errors, got %d", len(receivedErrors))
}
// Send an update
newCreds := NewBasicCredentials("user2", "pass2")
provider.updates <- newCreds
// Wait for update to be processed
time.Sleep(100 * time.Millisecond)
mu.Lock()
if len(receivedCreds) != 2 {
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
}
if receivedCreds[1] != newCreds {
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
}
mu.Unlock()
// Cancel subscription
if err := cancel(); err != nil {
t.Fatalf("unexpected error cancelling subscription: %v", err)
}
})
t.Run("subscription error", func(t *testing.T) {
provider := &mockStreamingProvider{
err: errors.New("subscription failed"),
}
var receivedCreds []Credentials
var receivedErrors []error
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
receivedCreds = append(receivedCreds, creds)
return nil
},
func(err error) {
receivedErrors = append(receivedErrors, err)
},
)
creds, cancel, err := provider.Subscribe(listener)
if err == nil {
t.Fatal("expected error, got nil")
}
if cancel != nil {
t.Fatal("expected cancel function to be nil")
}
if creds != nil {
t.Fatalf("expected nil credentials, got %v", creds)
}
if len(receivedCreds) != 0 {
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
}
if len(receivedErrors) != 0 {
t.Fatalf("expected no errors, got %d", len(receivedErrors))
}
})
t.Run("re-auth error", func(t *testing.T) {
initialCreds := NewBasicCredentials("user1", "pass1")
provider := newMockStreamingProvider(initialCreds)
reauthErr := errors.New("re-auth failed")
var receivedErrors []error
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
return reauthErr
},
func(err error) {
receivedErrors = append(receivedErrors, err)
},
)
creds, cancel, err := provider.Subscribe(listener)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cancel == nil {
t.Fatal("expected cancel function to be non-nil")
}
if creds != initialCreds {
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
}
if len(receivedErrors) != 1 {
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
}
if receivedErrors[0] != reauthErr {
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
}
if err := cancel(); err != nil {
t.Fatalf("unexpected error cancelling subscription: %v", err)
}
})
}
func TestBasicCredentials(t *testing.T) {
t.Run("basic auth", func(t *testing.T) {
creds := NewBasicCredentials("user1", "pass1")
username, password := creds.BasicAuth()
if username != "user1" {
t.Fatalf("expected username 'user1', got '%s'", username)
}
if password != "pass1" {
t.Fatalf("expected password 'pass1', got '%s'", password)
}
})
t.Run("raw credentials", func(t *testing.T) {
creds := NewBasicCredentials("user1", "pass1")
raw := creds.RawCredentials()
expected := "user1:pass1"
if raw != expected {
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
}
})
t.Run("empty username", func(t *testing.T) {
creds := NewBasicCredentials("", "pass1")
username, password := creds.BasicAuth()
if username != "" {
t.Fatalf("expected empty username, got '%s'", username)
}
if password != "pass1" {
t.Fatalf("expected password 'pass1', got '%s'", password)
}
})
}
func TestReAuthCredentialsListener(t *testing.T) {
t.Run("successful re-auth", func(t *testing.T) {
var reAuthCalled bool
var onErrCalled bool
var receivedCreds Credentials
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
reAuthCalled = true
receivedCreds = creds
return nil
},
func(err error) {
onErrCalled = true
},
)
creds := NewBasicCredentials("user1", "pass1")
listener.OnNext(creds)
if !reAuthCalled {
t.Fatal("expected reAuth to be called")
}
if onErrCalled {
t.Fatal("expected onErr not to be called")
}
if receivedCreds != creds {
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
}
})
t.Run("re-auth error", func(t *testing.T) {
var reAuthCalled bool
var onErrCalled bool
var receivedErr error
expectedErr := errors.New("re-auth failed")
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
reAuthCalled = true
return expectedErr
},
func(err error) {
onErrCalled = true
receivedErr = err
},
)
creds := NewBasicCredentials("user1", "pass1")
listener.OnNext(creds)
if !reAuthCalled {
t.Fatal("expected reAuth to be called")
}
if !onErrCalled {
t.Fatal("expected onErr to be called")
}
if receivedErr != expectedErr {
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
}
})
t.Run("on error", func(t *testing.T) {
var onErrCalled bool
var receivedErr error
expectedErr := errors.New("provider error")
listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
return nil
},
func(err error) {
onErrCalled = true
receivedErr = err
},
)
listener.OnError(expectedErr)
if !onErrCalled {
t.Fatal("expected onErr to be called")
}
if receivedErr != expectedErr {
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
}
})
t.Run("nil callbacks", func(t *testing.T) {
listener := NewReAuthCredentialsListener(nil, nil)
// Should not panic
listener.OnNext(NewBasicCredentials("user1", "pass1"))
listener.OnError(errors.New("test error"))
})
}

View File

@ -0,0 +1,47 @@
package auth
// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
// It is used to re-authenticate the credentials when they are updated.
// It contains:
// - reAuth: a function that takes the new credentials and returns an error if any.
// - onErr: a function that takes an error and handles it.
type ReAuthCredentialsListener struct {
reAuth func(credentials Credentials) error
onErr func(err error)
}
// OnNext is called when the credentials are updated.
// It calls the reAuth function with the new credentials.
// If the reAuth function returns an error, it calls the onErr function with the error.
func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) {
if c.reAuth == nil {
return
}
err := c.reAuth(credentials)
if err != nil {
c.OnError(err)
}
}
// OnError is called when an error occurs.
// It can be called from both the credentials provider and the reAuth function.
func (c *ReAuthCredentialsListener) OnError(err error) {
if c.onErr == nil {
return
}
c.onErr(err)
}
// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener.
// Implements the auth.CredentialsListener interface.
func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener {
return &ReAuthCredentialsListener{
reAuth: reAuth,
onErr: onErr,
}
}
// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)

86
command_recorder_test.go Normal file
View File

@ -0,0 +1,86 @@
package redis_test
import (
"context"
"strings"
"sync"
"github.com/redis/go-redis/v9"
)
// commandRecorder records the last N commands executed by a Redis client.
type commandRecorder struct {
mu sync.Mutex
commands []string
maxSize int
}
// newCommandRecorder creates a new command recorder with the specified maximum size.
func newCommandRecorder(maxSize int) *commandRecorder {
return &commandRecorder{
commands: make([]string, 0, maxSize),
maxSize: maxSize,
}
}
// Record adds a command to the recorder.
func (r *commandRecorder) Record(cmd string) {
cmd = strings.ToLower(cmd)
r.mu.Lock()
defer r.mu.Unlock()
r.commands = append(r.commands, cmd)
if len(r.commands) > r.maxSize {
r.commands = r.commands[1:]
}
}
// LastCommands returns a copy of the recorded commands.
func (r *commandRecorder) LastCommands() []string {
r.mu.Lock()
defer r.mu.Unlock()
return append([]string(nil), r.commands...)
}
// Contains checks if the recorder contains a specific command.
func (r *commandRecorder) Contains(cmd string) bool {
cmd = strings.ToLower(cmd)
r.mu.Lock()
defer r.mu.Unlock()
for _, c := range r.commands {
if strings.Contains(c, cmd) {
return true
}
}
return false
}
// Hook returns a Redis hook that records commands.
func (r *commandRecorder) Hook() redis.Hook {
return &commandHook{recorder: r}
}
// commandHook implements the redis.Hook interface to record commands.
type commandHook struct {
recorder *commandRecorder
}
func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook {
return next
}
func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
h.recorder.Record(cmd.String())
return next(ctx, cmd)
}
}
func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {
for _, cmd := range cmds {
h.recorder.Record(cmd.String())
}
return next(ctx, cmds)
}
}

View File

@ -5,6 +5,7 @@ package example_commands_test
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
@ -33,6 +34,7 @@ func ExampleClient_LPush_and_lrange() {
}
fmt.Println(listSize)
time.Sleep(10 * time.Millisecond) // Simulate some delay
value, err := rdb.LRange(ctx, "my_bikes", 0, -1).Result()
if err != nil {

View File

@ -23,18 +23,22 @@ func (redisHook) DialHook(hook redis.DialHook) redis.DialHook {
func (redisHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
fmt.Printf("starting processing: <%s>\n", cmd)
fmt.Printf("starting processing: <%v>\n", cmd.Args())
err := hook(ctx, cmd)
fmt.Printf("finished processing: <%s>\n", cmd)
fmt.Printf("finished processing: <%v>\n", cmd.Args())
return err
}
}
func (redisHook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {
fmt.Printf("pipeline starting processing: %v\n", cmds)
names := make([]string, 0, len(cmds))
for _, cmd := range cmds {
names = append(names, fmt.Sprintf("%v", cmd.Args()))
}
fmt.Printf("pipeline starting processing: %v\n", names)
err := hook(ctx, cmds)
fmt.Printf("pipeline finished processing: %v\n", cmds)
fmt.Printf("pipeline finished processing: %v\n", names)
return err
}
}
@ -42,19 +46,24 @@ func (redisHook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.Proce
func Example_instrumentation() {
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
DisableIdentity: true,
})
rdb.AddHook(redisHook{})
rdb.Ping(ctx)
// Output: starting processing: <ping: >
// Output:
// starting processing: <[ping]>
// dialing tcp :6379
// finished dialing tcp :6379
// finished processing: <ping: PONG>
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// finished processing: <[ping]>
}
func ExamplePipeline_instrumentation() {
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
DisableIdentity: true,
})
rdb.AddHook(redisHook{})
@ -63,15 +72,19 @@ func ExamplePipeline_instrumentation() {
pipe.Ping(ctx)
return nil
})
// Output: pipeline starting processing: [ping: ping: ]
// Output:
// pipeline starting processing: [[ping] [ping]]
// dialing tcp :6379
// finished dialing tcp :6379
// pipeline finished processing: [ping: PONG ping: PONG]
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// pipeline finished processing: [[ping] [ping]]
}
func ExampleClient_Watch_instrumentation() {
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
DisableIdentity: true,
})
rdb.AddHook(redisHook{})
@ -81,14 +94,16 @@ func ExampleClient_Watch_instrumentation() {
return nil
}, "foo")
// Output:
// starting processing: <watch foo: >
// starting processing: <[watch foo]>
// dialing tcp :6379
// finished dialing tcp :6379
// finished processing: <watch foo: OK>
// starting processing: <ping: >
// finished processing: <ping: PONG>
// starting processing: <ping: >
// finished processing: <ping: PONG>
// starting processing: <unwatch: >
// finished processing: <unwatch: OK>
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// finished processing: <[watch foo]>
// starting processing: <[ping]>
// finished processing: <[ping]>
// starting processing: <[ping]>
// finished processing: <[ping]>
// starting processing: <[unwatch]>
// finished processing: <[unwatch]>
}

View File

@ -23,6 +23,8 @@ type Conn struct {
Inited bool
pooled bool
createdAt time.Time
onClose func() error
}
func NewConn(netConn net.Conn) *Conn {
@ -46,6 +48,10 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
func (cn *Conn) SetOnClose(fn func() error) {
cn.onClose = fn
}
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
cn.rd.Reset(netConn)
@ -95,6 +101,10 @@ func (cn *Conn) WithWriter(
}
func (cn *Conn) Close() error {
if cn.onClose != nil {
// ignore error
_ = cn.onClose()
}
return cn.netConn.Close()
}

View File

@ -212,10 +212,10 @@ func TestRingShardsCleanup(t *testing.T) {
},
NewClient: func(opt *Options) *Client {
c := NewClient(opt)
c.baseClient.onClose = func() error {
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
closeCounter.increment(opt.Addr)
return nil
}
})
return c
},
})
@ -261,10 +261,10 @@ func TestRingShardsCleanup(t *testing.T) {
}
createCounter.increment(opt.Addr)
c := NewClient(opt)
c.baseClient.onClose = func() error {
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
closeCounter.increment(opt.Addr)
return nil
}
})
return c
},
})

View File

@ -13,6 +13,7 @@ import (
"strings"
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
)
@ -29,10 +30,13 @@ type Limiter interface {
// Options keeps the settings to set up redis connection.
type Options struct {
// The network type, either tcp or unix.
// Default is tcp.
// Network type, either tcp or unix.
//
// default: is tcp.
Network string
// host:port address.
// Addr is the address formated as host:port
Addr string
// ClientName will execute the `CLIENT SETNAME ClientName` command for each conn.
@ -46,17 +50,21 @@ type Options struct {
OnConnect func(ctx context.Context, cn *Conn) error
// Protocol 2 or 3. Use the version to negotiate RESP version with redis-server.
// Default is 3.
//
// default: 3.
Protocol int
// Use the specified Username to authenticate the current connection
// Username is used to authenticate the current connection
// with one of the connections defined in the ACL list when connecting
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
Username string
// Optional password. Must match the password specified in the
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
// Password is an optional password. Must match the password specified in the
// `requirepass` server configuration option (if connecting to a Redis 5.0 instance, or lower),
// or the User Password when connecting to a Redis 6.0 instance, or greater,
// that is using the Redis ACL system.
Password string
// CredentialsProvider allows the username and password to be updated
// before reconnecting. It should return the current username and password.
CredentialsProvider func() (username string, password string)
@ -67,76 +75,117 @@ type Options struct {
// There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider.
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
// Database to be selected after connecting to the server.
// StreamingCredentialsProvider is used to retrieve the credentials
// for the connection from an external source. Those credentials may change
// during the connection lifetime. This is useful for managed identity
// scenarios where the credentials are retrieved from an external source.
//
// Currently, this is a placeholder for the future implementation.
StreamingCredentialsProvider auth.StreamingCredentialsProvider
// DB is the database to be selected after connecting to the server.
DB int
// Maximum number of retries before giving up.
// Default is 3 retries; -1 (not 0) disables retries.
// MaxRetries is the maximum number of retries before giving up.
// -1 (not 0) disables retries.
//
// default: 3 retries
MaxRetries int
// Minimum backoff between each retry.
// Default is 8 milliseconds; -1 disables backoff.
// MinRetryBackoff is the minimum backoff between each retry.
// -1 disables backoff.
//
// default: 8 milliseconds
MinRetryBackoff time.Duration
// Maximum backoff between each retry.
// Default is 512 milliseconds; -1 disables backoff.
// MaxRetryBackoff is the maximum backoff between each retry.
// -1 disables backoff.
// default: 512 milliseconds;
MaxRetryBackoff time.Duration
// Dial timeout for establishing new connections.
// Default is 5 seconds.
// DialTimeout for establishing new connections.
//
// default: 5 seconds
DialTimeout time.Duration
// Timeout for socket reads. If reached, commands will fail
// ReadTimeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking. Supported values:
// - `0` - default timeout (3 seconds).
//
// - `-1` - no timeout (block indefinitely).
// - `-2` - disables SetReadDeadline calls completely.
//
// default: 3 seconds
ReadTimeout time.Duration
// Timeout for socket writes. If reached, commands will fail
// WriteTimeout for socket writes. If reached, commands will fail
// with a timeout instead of blocking. Supported values:
// - `0` - default timeout (3 seconds).
//
// - `-1` - no timeout (block indefinitely).
// - `-2` - disables SetWriteDeadline calls completely.
//
// default: 3 seconds
WriteTimeout time.Duration
// ContextTimeoutEnabled controls whether the client respects context timeouts and deadlines.
// See https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts
ContextTimeoutEnabled bool
// Type of connection pool.
// true for FIFO pool, false for LIFO pool.
// PoolFIFO type of connection pool.
//
// - true for FIFO pool
// - false for LIFO pool.
//
// Note that FIFO has slightly higher overhead compared to LIFO,
// but it helps closing idle connections faster reducing the pool size.
PoolFIFO bool
// Base number of socket connections.
// PoolSize is the base number of socket connections.
// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS.
// If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize,
// you can limit it through MaxActiveConns
//
// default: 10 * runtime.GOMAXPROCS(0)
PoolSize int
// Amount of time client waits for connection if all connections
// PoolTimeout is the amount of time client waits for connection if all connections
// are busy before returning an error.
// Default is ReadTimeout + 1 second.
//
// default: ReadTimeout + 1 second
PoolTimeout time.Duration
// Minimum number of idle connections which is useful when establishing
// new connection is slow.
// Default is 0. the idle connections are not closed by default.
// MinIdleConns is the minimum number of idle connections which is useful when establishing
// new connection is slow. The idle connections are not closed by default.
//
// default: 0
MinIdleConns int
// Maximum number of idle connections.
// Default is 0. the idle connections are not closed by default.
// MaxIdleConns is the maximum number of idle connections.
// The idle connections are not closed by default.
//
// default: 0
MaxIdleConns int
// Maximum number of connections allocated by the pool at a given time.
// MaxActiveConns is the maximum number of connections allocated by the pool at a given time.
// When zero, there is no limit on the number of connections in the pool.
// If the pool is full, the next call to Get() will block until a connection is released.
MaxActiveConns int
// ConnMaxIdleTime is the maximum amount of time a connection may be idle.
// Should be less than server's timeout.
//
// Expired connections may be closed lazily before reuse.
// If d <= 0, connections are not closed due to a connection's idle time.
// -1 disables idle timeout check.
//
// Default is 30 minutes. -1 disables idle timeout check.
// default: 30 minutes
ConnMaxIdleTime time.Duration
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
//
// Expired connections may be closed lazily before reuse.
// If <= 0, connections are not closed due to a connection's age.
//
// Default is to not close idle connections.
// default: 0
ConnMaxLifetime time.Duration
// TLSConfig to use. When set, TLS will be negotiated.
@ -145,7 +194,7 @@ type Options struct {
// Limiter interface used to implement circuit breaker or rate limiter.
Limiter Limiter
// Enables read only queries on slave/follower nodes.
// readOnly enables read only queries on slave/follower nodes.
readOnly bool
// DisableIndentity - Disable set-lib on connect.
@ -161,9 +210,11 @@ type Options struct {
DisableIdentity bool
// Add suffix to client name. Default is empty.
// IdentitySuffix - add suffix to client name.
IdentitySuffix string
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
// When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult
UnstableResp3 bool
}

View File

@ -14,6 +14,7 @@ import (
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hashtag"
"github.com/redis/go-redis/v9/internal/pool"
@ -71,6 +72,7 @@ type ClusterOptions struct {
Password string
CredentialsProvider func() (username string, password string)
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
StreamingCredentialsProvider auth.StreamingCredentialsProvider
MaxRetries int
MinRetryBackoff time.Duration
@ -297,6 +299,7 @@ func (opt *ClusterOptions) clientOptions() *Options {
Password: opt.Password,
CredentialsProvider: opt.CredentialsProvider,
CredentialsProviderContext: opt.CredentialsProviderContext,
StreamingCredentialsProvider: opt.StreamingCredentialsProvider,
MaxRetries: opt.MaxRetries,
MinRetryBackoff: opt.MinRetryBackoff,

View File

@ -89,6 +89,9 @@ func (s *clusterScenario) newClusterClient(
func (s *clusterScenario) Close() error {
ctx := context.TODO()
for _, master := range s.masters() {
if master == nil {
continue
}
err := master.FlushAll(ctx).Err()
if err != nil {
return err

View File

@ -298,7 +298,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
})
It("should CFCount", Label("cuckoo", "cfcount"), func() {
err := client.CFAdd(ctx, "testcf1", "item1").Err()
client.CFAdd(ctx, "testcf1", "item1")
cnt, err := client.CFCount(ctx, "testcf1", "item1").Result()
Expect(err).NotTo(HaveOccurred())
Expect(cnt).To(BeEquivalentTo(int64(1)))
@ -394,7 +394,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
NoCreate: true,
}
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
_, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
Expect(err).To(HaveOccurred())
args = &redis.CFInsertOptions{
@ -402,7 +402,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
NoCreate: false,
}
result, err = client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
Expect(err).NotTo(HaveOccurred())
Expect(len(result)).To(BeEquivalentTo(3))
})

150
redis.go
View File

@ -9,6 +9,7 @@ import (
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool"
@ -203,6 +204,7 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
type baseClient struct {
opt *Options
connPool pool.Pooler
hooksMixin
onClose func() error // hook called when client is closed
}
@ -282,30 +284,107 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
return auth.NewReAuthCredentialsListener(
c.reAuthConnection(poolCn),
c.onAuthenticationErr(poolCn),
)
}
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
return func(credentials auth.Credentials) error {
var err error
username, password := credentials.BasicAuth()
ctx := context.Background()
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
// hooksMixin are intentionally empty here
cn := newConn(c.opt, connPool, nil)
if username != "" {
err = cn.AuthACL(ctx, username, password).Err()
} else {
err = cn.Auth(ctx, password).Err()
}
return err
}
}
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
return func(err error) {
if err != nil {
if isBadConn(err, false, c.opt.Addr) {
// Close the connection to force a reconnection.
err := c.connPool.CloseConn(poolCn)
if err != nil {
internal.Logger.Printf(context.Background(), "redis: failed to close connection: %v", err)
// try to close the network connection directly
// so that no resource is leaked
err := poolCn.Close()
if err != nil {
internal.Logger.Printf(context.Background(), "redis: failed to close network connection: %v", err)
}
}
}
internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err)
}
}
}
func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
onClose := c.onClose
return func() error {
var firstErr error
err := newOnClose()
// Even if we have an error we would like to execute the onClose hook
// if it exists. We will return the first error that occurred.
// This is to keep error handling consistent with the rest of the code.
if err != nil {
firstErr = err
}
if onClose != nil {
err = onClose()
if err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
}
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
return nil
}
cn.Inited = true
var err error
username, password := c.opt.Username, c.opt.Password
if c.opt.CredentialsProviderContext != nil {
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil {
return err
cn.Inited = true
connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool, &c.hooksMixin)
username, password := "", ""
if c.opt.StreamingCredentialsProvider != nil {
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
Subscribe(c.newReAuthCredentialsListener(cn))
if err != nil {
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
}
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
cn.SetOnClose(unsubscribeFromCredentialsProvider)
username, password = credentials.BasicAuth()
} else if c.opt.CredentialsProviderContext != nil {
username, password, err = c.opt.CredentialsProviderContext(ctx)
if err != nil {
return fmt.Errorf("failed to get credentials from context provider: %w", err)
}
} else if c.opt.CredentialsProvider != nil {
username, password = c.opt.CredentialsProvider()
} else if c.opt.Username != "" || c.opt.Password != "" {
username, password = c.opt.Username, c.opt.Password
}
connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool)
var auth bool
// for redis-server versions that do not support the HELLO command,
// RESP2 will continue to be used.
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
auth = true
// Authentication successful with HELLO command
} else if !isRedisError(err) {
// When the server responds with the RESP protocol and the result is not a normal
// execution result of the HELLO command, we consider it to be an indication that
@ -315,17 +394,19 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
// with different error string results for unsupported commands, making it
// difficult to rely on error strings to determine all results.
return err
} else if password != "" {
// Try legacy AUTH command if HELLO failed
if username != "" {
err = conn.AuthACL(ctx, username, password).Err()
} else {
err = conn.Auth(ctx, password).Err()
}
if err != nil {
return fmt.Errorf("failed to authenticate: %w", err)
}
}
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
if !auth && password != "" {
if username != "" {
pipe.AuthACL(ctx, username, password)
} else {
pipe.Auth(ctx, password)
}
}
if c.opt.DB > 0 {
pipe.Select(ctx, c.opt.DB)
}
@ -341,7 +422,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return nil
})
if err != nil {
return err
return fmt.Errorf("failed to initialize connection options: %w", err)
}
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
@ -363,6 +444,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn)
}
return nil
}
@ -481,6 +563,16 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
return c.opt.ReadTimeout
}
// context returns the context for the current connection.
// If the context timeout is enabled, it returns the original context.
// Otherwise, it returns a new background context.
func (c *baseClient) context(ctx context.Context) context.Context {
if c.opt.ContextTimeoutEnabled {
return ctx
}
return context.Background()
}
// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
@ -633,13 +725,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
return nil
}
func (c *baseClient) context(ctx context.Context) context.Context {
if c.opt.ContextTimeoutEnabled {
return ctx
}
return context.Background()
}
//------------------------------------------------------------------------------
// Client is a Redis client representing a pool of zero or more underlying connections.
@ -650,7 +735,6 @@ func (c *baseClient) context(ctx context.Context) context.Context {
type Client struct {
*baseClient
cmdable
hooksMixin
}
// NewClient returns a client to the Redis Server specified by Options.
@ -689,7 +773,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
}
func (c *Client) Conn() *Conn {
return newConn(c.opt, pool.NewStickyConnPool(c.connPool))
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
}
// Do create a Cmd from the args and processes the cmd.
@ -822,10 +906,12 @@ type Conn struct {
baseClient
cmdable
statefulCmdable
hooksMixin
}
func newConn(opt *Options, connPool pool.Pooler) *Conn {
// newConn is a helper func to create a new Conn instance.
// the Conn instance is not thread-safe and should not be shared between goroutines.
// the parentHooks will be cloned, no need to clone before passing it.
func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn {
c := Conn{
baseClient: baseClient{
opt: opt,
@ -833,6 +919,10 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
},
}
if parentHooks != nil {
c.hooksMixin = parentHooks.clone()
}
c.cmdable = c.Process
c.statefulCmdable = c.Process
c.initHooks(hooks{

View File

@ -14,6 +14,7 @@ import (
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/auth"
)
type redisHookError struct{}
@ -727,6 +728,174 @@ var _ = Describe("Dialer connection timeouts", func() {
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
})
})
var _ = Describe("Credentials Provider Priority", func() {
var client *redis.Client
var opt *redis.Options
var recorder *commandRecorder
BeforeEach(func() {
recorder = newCommandRecorder(10)
})
AfterEach(func() {
if client != nil {
Expect(client.Close()).NotTo(HaveOccurred())
}
})
It("should use streaming provider when available", func() {
streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
username, password := ctxCreds.BasicAuth()
return username, password, nil
},
StreamingCredentialsProvider: &mockStreamingProvider{
credentials: streamingCreds,
updates: make(chan auth.Credentials, 1),
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
})
It("should use context provider when streaming provider is not available", func() {
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
username, password := ctxCreds.BasicAuth()
return username, password, nil
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
})
It("should use regular provider when streaming and context providers are not available", func() {
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
})
It("should use username/password fields when no providers are set", func() {
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
})
It("should use empty credentials when nothing is set", func() {
opt = &redis.Options{}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// no pass, ok
Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
Expect(recorder.Contains("AUTH")).To(BeFalse())
})
It("should handle credential updates from streaming provider", func() {
initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
updatesChan := make(chan auth.Credentials, 1)
opt = &redis.Options{
StreamingCredentialsProvider: &mockStreamingProvider{
credentials: initialCreds,
updates: updatesChan,
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
// Update credentials
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
close(updatesChan)
})
})
type mockStreamingProvider struct {
credentials auth.Credentials
err error
updates chan auth.Credentials
}
func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
if m.err != nil {
return nil, nil, m.err
}
// Start goroutine to handle updates
go func() {
for creds := range m.updates {
m.credentials = creds
listener.OnNext(creds)
}
}()
return m.credentials, func() (err error) {
defer func() {
if r := recover(); r != nil {
// this is just a mock:
// allow multiple closes from multiple listeners
}
}()
return
}, nil
}
var _ = Describe("Client creation", func() {
Context("simple client with nil options", func() {
It("panics", func() {

View File

@ -357,13 +357,17 @@ var _ = Describe("Redis Ring", func() {
ring.AddHook(&hook{
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
// skip the connection initialization
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
return nil
}
Expect(len(cmds)).To(BeNumerically(">", 0))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "ring.BeforeProcessPipeline")
err := hook(ctx, cmds)
Expect(cmds).To(HaveLen(1))
Expect(len(cmds)).To(BeNumerically(">", 0))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "ring.AfterProcessPipeline")
@ -376,13 +380,17 @@ var _ = Describe("Redis Ring", func() {
shard.AddHook(&hook{
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
// skip the connection initialization
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
return nil
}
Expect(len(cmds)).To(BeNumerically(">", 0))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcessPipeline")
err := hook(ctx, cmds)
Expect(cmds).To(HaveLen(1))
Expect(len(cmds)).To(BeNumerically(">", 0))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcessPipeline")
@ -416,14 +424,18 @@ var _ = Describe("Redis Ring", func() {
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {
defer GinkgoRecover()
// skip the connection initialization
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
return nil
}
Expect(cmds).To(HaveLen(3))
Expect(len(cmds)).To(BeNumerically(">=", 3))
Expect(cmds[1].String()).To(Equal("ping: "))
stack = append(stack, "ring.BeforeProcessPipeline")
err := hook(ctx, cmds)
Expect(cmds).To(HaveLen(3))
Expect(len(cmds)).To(BeNumerically(">=", 3))
Expect(cmds[1].String()).To(Equal("ping: PONG"))
stack = append(stack, "ring.AfterProcessPipeline")
@ -437,14 +449,18 @@ var _ = Describe("Redis Ring", func() {
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {
defer GinkgoRecover()
// skip the connection initialization
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
return nil
}
Expect(cmds).To(HaveLen(3))
Expect(len(cmds)).To(BeNumerically(">=", 3))
Expect(cmds[1].String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcessPipeline")
err := hook(ctx, cmds)
Expect(cmds).To(HaveLen(3))
Expect(len(cmds)).To(BeNumerically(">=", 3))
Expect(cmds[1].String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcessPipeline")

View File

@ -404,7 +404,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
connPool = newConnPool(opt, rdb.dialHook)
rdb.connPool = connPool
rdb.onClose = failover.Close
rdb.onClose = rdb.wrappedOnClose(failover.Close)
failover.mu.Lock()
failover.onFailover = func(ctx context.Context, addr string) {
@ -455,7 +455,6 @@ func masterReplicaDialer(
// SentinelClient is a client for a Redis Sentinel.
type SentinelClient struct {
*baseClient
hooksMixin
}
func NewSentinelClient(opt *Options) *SentinelClient {

3
tx.go
View File

@ -19,7 +19,6 @@ type Tx struct {
baseClient
cmdable
statefulCmdable
hooksMixin
}
func (c *Client) newTx() *Tx {
@ -27,8 +26,8 @@ func (c *Client) newTx() *Tx {
baseClient: baseClient{
opt: c.opt,
connPool: pool.NewStickyConnPool(c.connPool),
},
hooksMixin: c.hooksMixin.clone(),
},
}
tx.init()
return &tx