mirror of
https://github.com/redis/go-redis.git
synced 2025-07-28 06:42:00 +03:00
feat: Introducing StreamingCredentialsProvider for token based authentication (#3320)
* wip * update documentation * add streamingcredentialsprovider in options * fix: put back option in pool creation * add package level comment * Initial re authentication implementation Introduces the StreamingCredentialsProvider as the CredentialsProvider with the highest priority. TODO: needs to be tested * Change function type name Change CancelProviderFunc to UnsubscribeFunc * add tests * fix race in tests * fix example tests * wip, hooks refactor * fix build * update README.md * update wordlist * update README.md * refactor(auth): early returns in cred listener * fix(doctest): simulate some delay * feat(conn): add close hook on conn * fix(tests): simulate start/stop in mock credentials provider * fix(auth): don't double close the conn * docs(README): mark streaming credentials provider as experimental * fix(auth): streamline auth err proccess * fix(auth): check err on close conn * chore(entraid): use the repo under redis org
This commit is contained in:
10
.github/wordlist.txt
vendored
10
.github/wordlist.txt
vendored
@ -65,4 +65,12 @@ RedisGears
|
||||
RedisTimeseries
|
||||
RediSearch
|
||||
RawResult
|
||||
RawVal
|
||||
RawVal
|
||||
entra
|
||||
EntraID
|
||||
Entra
|
||||
OAuth
|
||||
Azure
|
||||
StreamingCredentialsProvider
|
||||
oauth
|
||||
entraid
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -7,4 +7,5 @@ testdata/*
|
||||
redis8tests.sh
|
||||
coverage.txt
|
||||
**/coverage.txt
|
||||
.vscode
|
||||
.vscode
|
||||
tmp/*
|
||||
|
121
README.md
121
README.md
@ -68,6 +68,7 @@ key value NoSQL database that uses RocksDB as storage engine and is compatible w
|
||||
|
||||
- Redis commands except QUIT and SYNC.
|
||||
- Automatic connection pooling.
|
||||
- [StreamingCredentialsProvider (e.g. entra id, oauth)](#1-streaming-credentials-provider-highest-priority) (experimental)
|
||||
- [Pub/Sub](https://redis.uptrace.dev/guide/go-redis-pubsub.html).
|
||||
- [Pipelines and transactions](https://redis.uptrace.dev/guide/go-redis-pipelines.html).
|
||||
- [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html).
|
||||
@ -136,17 +137,121 @@ func ExampleClient() {
|
||||
}
|
||||
```
|
||||
|
||||
The above can be modified to specify the version of the RESP protocol by adding the `protocol`
|
||||
option to the `Options` struct:
|
||||
### Authentication
|
||||
|
||||
The Redis client supports multiple ways to provide authentication credentials, with a clear priority order. Here are the available options:
|
||||
|
||||
#### 1. Streaming Credentials Provider (Highest Priority) - Experimental feature
|
||||
|
||||
The streaming credentials provider allows for dynamic credential updates during the connection lifetime. This is particularly useful for managed identity services and token-based authentication.
|
||||
|
||||
```go
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3
|
||||
})
|
||||
type StreamingCredentialsProvider interface {
|
||||
Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error)
|
||||
}
|
||||
|
||||
type CredentialsListener interface {
|
||||
OnNext(credentials Credentials) // Called when credentials are updated
|
||||
OnError(err error) // Called when an error occurs
|
||||
}
|
||||
|
||||
type Credentials interface {
|
||||
BasicAuth() (username string, password string)
|
||||
RawCredentials() string
|
||||
}
|
||||
```
|
||||
|
||||
Example usage:
|
||||
```go
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
StreamingCredentialsProvider: &MyCredentialsProvider{},
|
||||
})
|
||||
```
|
||||
|
||||
**Note:** The streaming credentials provider can be used with [go-redis-entraid](https://github.com/redis/go-redis-entraid) to enable Entra ID (formerly Azure AD) authentication. This allows for seamless integration with Azure's managed identity services and token-based authentication.
|
||||
|
||||
Example with Entra ID:
|
||||
```go
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis-entraid"
|
||||
)
|
||||
|
||||
// Create an Entra ID credentials provider
|
||||
provider := entraid.NewDefaultAzureIdentityProvider()
|
||||
|
||||
// Configure Redis client with Entra ID authentication
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "your-redis-server.redis.cache.windows.net:6380",
|
||||
StreamingCredentialsProvider: provider,
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### 2. Context-based Credentials Provider
|
||||
|
||||
The context-based provider allows credentials to be determined at the time of each operation, using the context.
|
||||
|
||||
```go
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
|
||||
// Return username, password, and any error
|
||||
return "user", "pass", nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### 3. Regular Credentials Provider
|
||||
|
||||
A simple function-based provider that returns static credentials.
|
||||
|
||||
```go
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
CredentialsProvider: func() (string, string) {
|
||||
// Return username and password
|
||||
return "user", "pass"
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### 4. Username/Password Fields (Lowest Priority)
|
||||
|
||||
The most basic way to provide credentials is through the `Username` and `Password` fields in the options.
|
||||
|
||||
```go
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
})
|
||||
```
|
||||
|
||||
#### Priority Order
|
||||
|
||||
The client will use credentials in the following priority order:
|
||||
1. Streaming Credentials Provider (if set)
|
||||
2. Context-based Credentials Provider (if set)
|
||||
3. Regular Credentials Provider (if set)
|
||||
4. Username/Password fields (if set)
|
||||
|
||||
If none of these are set, the client will attempt to connect without authentication.
|
||||
|
||||
### Protocol Version
|
||||
|
||||
The client supports both RESP2 and RESP3 protocols. You can specify the protocol version in the options:
|
||||
|
||||
```go
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3
|
||||
})
|
||||
```
|
||||
|
||||
### Connecting via a redis url
|
||||
|
61
auth/auth.go
Normal file
61
auth/auth.go
Normal file
@ -0,0 +1,61 @@
|
||||
// Package auth package provides authentication-related interfaces and types.
|
||||
// It also includes a basic implementation of credentials using username and password.
|
||||
package auth
|
||||
|
||||
// StreamingCredentialsProvider is an interface that defines the methods for a streaming credentials provider.
|
||||
// It is used to provide credentials for authentication.
|
||||
// The CredentialsListener is used to receive updates when the credentials change.
|
||||
type StreamingCredentialsProvider interface {
|
||||
// Subscribe subscribes to the credentials provider for updates.
|
||||
// It returns the current credentials, a cancel function to unsubscribe from the provider,
|
||||
// and an error if any.
|
||||
// TODO(ndyakov): Should we add context to the Subscribe method?
|
||||
Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error)
|
||||
}
|
||||
|
||||
// UnsubscribeFunc is a function that is used to cancel the subscription to the credentials provider.
|
||||
// It is used to unsubscribe from the provider when the credentials are no longer needed.
|
||||
type UnsubscribeFunc func() error
|
||||
|
||||
// CredentialsListener is an interface that defines the methods for a credentials listener.
|
||||
// It is used to receive updates when the credentials change.
|
||||
// The OnNext method is called when the credentials change.
|
||||
// The OnError method is called when an error occurs while requesting the credentials.
|
||||
type CredentialsListener interface {
|
||||
OnNext(credentials Credentials)
|
||||
OnError(err error)
|
||||
}
|
||||
|
||||
// Credentials is an interface that defines the methods for credentials.
|
||||
// It is used to provide the credentials for authentication.
|
||||
type Credentials interface {
|
||||
// BasicAuth returns the username and password for basic authentication.
|
||||
BasicAuth() (username string, password string)
|
||||
// RawCredentials returns the raw credentials as a string.
|
||||
// This can be used to extract the username and password from the raw credentials or
|
||||
// additional information if present in the token.
|
||||
RawCredentials() string
|
||||
}
|
||||
|
||||
type basicAuth struct {
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
// RawCredentials returns the raw credentials as a string.
|
||||
func (b *basicAuth) RawCredentials() string {
|
||||
return b.username + ":" + b.password
|
||||
}
|
||||
|
||||
// BasicAuth returns the username and password for basic authentication.
|
||||
func (b *basicAuth) BasicAuth() (username string, password string) {
|
||||
return b.username, b.password
|
||||
}
|
||||
|
||||
// NewBasicCredentials creates a new Credentials object from the given username and password.
|
||||
func NewBasicCredentials(username, password string) Credentials {
|
||||
return &basicAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
}
|
||||
}
|
308
auth/auth_test.go
Normal file
308
auth/auth_test.go
Normal file
@ -0,0 +1,308 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockStreamingProvider struct {
|
||||
credentials Credentials
|
||||
err error
|
||||
updates chan Credentials
|
||||
}
|
||||
|
||||
func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
|
||||
return &mockStreamingProvider{
|
||||
credentials: initialCreds,
|
||||
updates: make(chan Credentials, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
|
||||
if m.err != nil {
|
||||
return nil, nil, m.err
|
||||
}
|
||||
|
||||
// Send initial credentials
|
||||
listener.OnNext(m.credentials)
|
||||
|
||||
// Start goroutine to handle updates
|
||||
go func() {
|
||||
for creds := range m.updates {
|
||||
listener.OnNext(creds)
|
||||
}
|
||||
}()
|
||||
|
||||
return m.credentials, func() error {
|
||||
close(m.updates)
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestStreamingCredentialsProvider(t *testing.T) {
|
||||
t.Run("successful subscription", func(t *testing.T) {
|
||||
initialCreds := NewBasicCredentials("user1", "pass1")
|
||||
provider := newMockStreamingProvider(initialCreds)
|
||||
|
||||
var receivedCreds []Credentials
|
||||
var receivedErrors []error
|
||||
var mu sync.Mutex
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
mu.Lock()
|
||||
receivedCreds = append(receivedCreds, creds)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
receivedErrors = append(receivedErrors, err)
|
||||
},
|
||||
)
|
||||
|
||||
creds, cancel, err := provider.Subscribe(listener)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cancel == nil {
|
||||
t.Fatal("expected cancel function to be non-nil")
|
||||
}
|
||||
if creds != initialCreds {
|
||||
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
||||
}
|
||||
if len(receivedCreds) != 1 {
|
||||
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
|
||||
}
|
||||
if receivedCreds[0] != initialCreds {
|
||||
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
|
||||
}
|
||||
if len(receivedErrors) != 0 {
|
||||
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
||||
}
|
||||
|
||||
// Send an update
|
||||
newCreds := NewBasicCredentials("user2", "pass2")
|
||||
provider.updates <- newCreds
|
||||
|
||||
// Wait for update to be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mu.Lock()
|
||||
if len(receivedCreds) != 2 {
|
||||
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
|
||||
}
|
||||
if receivedCreds[1] != newCreds {
|
||||
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Cancel subscription
|
||||
if err := cancel(); err != nil {
|
||||
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("subscription error", func(t *testing.T) {
|
||||
provider := &mockStreamingProvider{
|
||||
err: errors.New("subscription failed"),
|
||||
}
|
||||
|
||||
var receivedCreds []Credentials
|
||||
var receivedErrors []error
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
receivedCreds = append(receivedCreds, creds)
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
receivedErrors = append(receivedErrors, err)
|
||||
},
|
||||
)
|
||||
|
||||
creds, cancel, err := provider.Subscribe(listener)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if cancel != nil {
|
||||
t.Fatal("expected cancel function to be nil")
|
||||
}
|
||||
if creds != nil {
|
||||
t.Fatalf("expected nil credentials, got %v", creds)
|
||||
}
|
||||
if len(receivedCreds) != 0 {
|
||||
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
|
||||
}
|
||||
if len(receivedErrors) != 0 {
|
||||
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("re-auth error", func(t *testing.T) {
|
||||
initialCreds := NewBasicCredentials("user1", "pass1")
|
||||
provider := newMockStreamingProvider(initialCreds)
|
||||
|
||||
reauthErr := errors.New("re-auth failed")
|
||||
var receivedErrors []error
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
return reauthErr
|
||||
},
|
||||
func(err error) {
|
||||
receivedErrors = append(receivedErrors, err)
|
||||
},
|
||||
)
|
||||
|
||||
creds, cancel, err := provider.Subscribe(listener)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cancel == nil {
|
||||
t.Fatal("expected cancel function to be non-nil")
|
||||
}
|
||||
if creds != initialCreds {
|
||||
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
||||
}
|
||||
if len(receivedErrors) != 1 {
|
||||
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
|
||||
}
|
||||
if receivedErrors[0] != reauthErr {
|
||||
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
|
||||
}
|
||||
|
||||
if err := cancel(); err != nil {
|
||||
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBasicCredentials(t *testing.T) {
|
||||
t.Run("basic auth", func(t *testing.T) {
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
username, password := creds.BasicAuth()
|
||||
if username != "user1" {
|
||||
t.Fatalf("expected username 'user1', got '%s'", username)
|
||||
}
|
||||
if password != "pass1" {
|
||||
t.Fatalf("expected password 'pass1', got '%s'", password)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("raw credentials", func(t *testing.T) {
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
raw := creds.RawCredentials()
|
||||
expected := "user1:pass1"
|
||||
if raw != expected {
|
||||
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty username", func(t *testing.T) {
|
||||
creds := NewBasicCredentials("", "pass1")
|
||||
username, password := creds.BasicAuth()
|
||||
if username != "" {
|
||||
t.Fatalf("expected empty username, got '%s'", username)
|
||||
}
|
||||
if password != "pass1" {
|
||||
t.Fatalf("expected password 'pass1', got '%s'", password)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReAuthCredentialsListener(t *testing.T) {
|
||||
t.Run("successful re-auth", func(t *testing.T) {
|
||||
var reAuthCalled bool
|
||||
var onErrCalled bool
|
||||
var receivedCreds Credentials
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
reAuthCalled = true
|
||||
receivedCreds = creds
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
onErrCalled = true
|
||||
},
|
||||
)
|
||||
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
listener.OnNext(creds)
|
||||
|
||||
if !reAuthCalled {
|
||||
t.Fatal("expected reAuth to be called")
|
||||
}
|
||||
if onErrCalled {
|
||||
t.Fatal("expected onErr not to be called")
|
||||
}
|
||||
if receivedCreds != creds {
|
||||
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("re-auth error", func(t *testing.T) {
|
||||
var reAuthCalled bool
|
||||
var onErrCalled bool
|
||||
var receivedErr error
|
||||
expectedErr := errors.New("re-auth failed")
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
reAuthCalled = true
|
||||
return expectedErr
|
||||
},
|
||||
func(err error) {
|
||||
onErrCalled = true
|
||||
receivedErr = err
|
||||
},
|
||||
)
|
||||
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
listener.OnNext(creds)
|
||||
|
||||
if !reAuthCalled {
|
||||
t.Fatal("expected reAuth to be called")
|
||||
}
|
||||
if !onErrCalled {
|
||||
t.Fatal("expected onErr to be called")
|
||||
}
|
||||
if receivedErr != expectedErr {
|
||||
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on error", func(t *testing.T) {
|
||||
var onErrCalled bool
|
||||
var receivedErr error
|
||||
expectedErr := errors.New("provider error")
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
onErrCalled = true
|
||||
receivedErr = err
|
||||
},
|
||||
)
|
||||
|
||||
listener.OnError(expectedErr)
|
||||
|
||||
if !onErrCalled {
|
||||
t.Fatal("expected onErr to be called")
|
||||
}
|
||||
if receivedErr != expectedErr {
|
||||
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil callbacks", func(t *testing.T) {
|
||||
listener := NewReAuthCredentialsListener(nil, nil)
|
||||
|
||||
// Should not panic
|
||||
listener.OnNext(NewBasicCredentials("user1", "pass1"))
|
||||
listener.OnError(errors.New("test error"))
|
||||
})
|
||||
}
|
47
auth/reauth_credentials_listener.go
Normal file
47
auth/reauth_credentials_listener.go
Normal file
@ -0,0 +1,47 @@
|
||||
package auth
|
||||
|
||||
// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
|
||||
// It is used to re-authenticate the credentials when they are updated.
|
||||
// It contains:
|
||||
// - reAuth: a function that takes the new credentials and returns an error if any.
|
||||
// - onErr: a function that takes an error and handles it.
|
||||
type ReAuthCredentialsListener struct {
|
||||
reAuth func(credentials Credentials) error
|
||||
onErr func(err error)
|
||||
}
|
||||
|
||||
// OnNext is called when the credentials are updated.
|
||||
// It calls the reAuth function with the new credentials.
|
||||
// If the reAuth function returns an error, it calls the onErr function with the error.
|
||||
func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) {
|
||||
if c.reAuth == nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.reAuth(credentials)
|
||||
if err != nil {
|
||||
c.OnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// OnError is called when an error occurs.
|
||||
// It can be called from both the credentials provider and the reAuth function.
|
||||
func (c *ReAuthCredentialsListener) OnError(err error) {
|
||||
if c.onErr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.onErr(err)
|
||||
}
|
||||
|
||||
// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener.
|
||||
// Implements the auth.CredentialsListener interface.
|
||||
func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener {
|
||||
return &ReAuthCredentialsListener{
|
||||
reAuth: reAuth,
|
||||
onErr: onErr,
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
|
||||
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
|
86
command_recorder_test.go
Normal file
86
command_recorder_test.go
Normal file
@ -0,0 +1,86 @@
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// commandRecorder records the last N commands executed by a Redis client.
|
||||
type commandRecorder struct {
|
||||
mu sync.Mutex
|
||||
commands []string
|
||||
maxSize int
|
||||
}
|
||||
|
||||
// newCommandRecorder creates a new command recorder with the specified maximum size.
|
||||
func newCommandRecorder(maxSize int) *commandRecorder {
|
||||
return &commandRecorder{
|
||||
commands: make([]string, 0, maxSize),
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Record adds a command to the recorder.
|
||||
func (r *commandRecorder) Record(cmd string) {
|
||||
cmd = strings.ToLower(cmd)
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.commands = append(r.commands, cmd)
|
||||
if len(r.commands) > r.maxSize {
|
||||
r.commands = r.commands[1:]
|
||||
}
|
||||
}
|
||||
|
||||
// LastCommands returns a copy of the recorded commands.
|
||||
func (r *commandRecorder) LastCommands() []string {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return append([]string(nil), r.commands...)
|
||||
}
|
||||
|
||||
// Contains checks if the recorder contains a specific command.
|
||||
func (r *commandRecorder) Contains(cmd string) bool {
|
||||
cmd = strings.ToLower(cmd)
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, c := range r.commands {
|
||||
if strings.Contains(c, cmd) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Hook returns a Redis hook that records commands.
|
||||
func (r *commandRecorder) Hook() redis.Hook {
|
||||
return &commandHook{recorder: r}
|
||||
}
|
||||
|
||||
// commandHook implements the redis.Hook interface to record commands.
|
||||
type commandHook struct {
|
||||
recorder *commandRecorder
|
||||
}
|
||||
|
||||
func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook {
|
||||
return next
|
||||
}
|
||||
|
||||
func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
|
||||
return func(ctx context.Context, cmd redis.Cmder) error {
|
||||
h.recorder.Record(cmd.String())
|
||||
return next(ctx, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||
for _, cmd := range cmds {
|
||||
h.recorder.Record(cmd.String())
|
||||
}
|
||||
return next(ctx, cmds)
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ package example_commands_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
@ -33,6 +34,7 @@ func ExampleClient_LPush_and_lrange() {
|
||||
}
|
||||
|
||||
fmt.Println(listSize)
|
||||
time.Sleep(10 * time.Millisecond) // Simulate some delay
|
||||
|
||||
value, err := rdb.LRange(ctx, "my_bikes", 0, -1).Result()
|
||||
if err != nil {
|
||||
|
@ -23,38 +23,47 @@ func (redisHook) DialHook(hook redis.DialHook) redis.DialHook {
|
||||
|
||||
func (redisHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
|
||||
return func(ctx context.Context, cmd redis.Cmder) error {
|
||||
fmt.Printf("starting processing: <%s>\n", cmd)
|
||||
fmt.Printf("starting processing: <%v>\n", cmd.Args())
|
||||
err := hook(ctx, cmd)
|
||||
fmt.Printf("finished processing: <%s>\n", cmd)
|
||||
fmt.Printf("finished processing: <%v>\n", cmd.Args())
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (redisHook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||
fmt.Printf("pipeline starting processing: %v\n", cmds)
|
||||
names := make([]string, 0, len(cmds))
|
||||
for _, cmd := range cmds {
|
||||
names = append(names, fmt.Sprintf("%v", cmd.Args()))
|
||||
}
|
||||
fmt.Printf("pipeline starting processing: %v\n", names)
|
||||
err := hook(ctx, cmds)
|
||||
fmt.Printf("pipeline finished processing: %v\n", cmds)
|
||||
fmt.Printf("pipeline finished processing: %v\n", names)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func Example_instrumentation() {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: ":6379",
|
||||
Addr: ":6379",
|
||||
DisableIdentity: true,
|
||||
})
|
||||
rdb.AddHook(redisHook{})
|
||||
|
||||
rdb.Ping(ctx)
|
||||
// Output: starting processing: <ping: >
|
||||
// Output:
|
||||
// starting processing: <[ping]>
|
||||
// dialing tcp :6379
|
||||
// finished dialing tcp :6379
|
||||
// finished processing: <ping: PONG>
|
||||
// starting processing: <[hello 3]>
|
||||
// finished processing: <[hello 3]>
|
||||
// finished processing: <[ping]>
|
||||
}
|
||||
|
||||
func ExamplePipeline_instrumentation() {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: ":6379",
|
||||
Addr: ":6379",
|
||||
DisableIdentity: true,
|
||||
})
|
||||
rdb.AddHook(redisHook{})
|
||||
|
||||
@ -63,15 +72,19 @@ func ExamplePipeline_instrumentation() {
|
||||
pipe.Ping(ctx)
|
||||
return nil
|
||||
})
|
||||
// Output: pipeline starting processing: [ping: ping: ]
|
||||
// Output:
|
||||
// pipeline starting processing: [[ping] [ping]]
|
||||
// dialing tcp :6379
|
||||
// finished dialing tcp :6379
|
||||
// pipeline finished processing: [ping: PONG ping: PONG]
|
||||
// starting processing: <[hello 3]>
|
||||
// finished processing: <[hello 3]>
|
||||
// pipeline finished processing: [[ping] [ping]]
|
||||
}
|
||||
|
||||
func ExampleClient_Watch_instrumentation() {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: ":6379",
|
||||
Addr: ":6379",
|
||||
DisableIdentity: true,
|
||||
})
|
||||
rdb.AddHook(redisHook{})
|
||||
|
||||
@ -81,14 +94,16 @@ func ExampleClient_Watch_instrumentation() {
|
||||
return nil
|
||||
}, "foo")
|
||||
// Output:
|
||||
// starting processing: <watch foo: >
|
||||
// starting processing: <[watch foo]>
|
||||
// dialing tcp :6379
|
||||
// finished dialing tcp :6379
|
||||
// finished processing: <watch foo: OK>
|
||||
// starting processing: <ping: >
|
||||
// finished processing: <ping: PONG>
|
||||
// starting processing: <ping: >
|
||||
// finished processing: <ping: PONG>
|
||||
// starting processing: <unwatch: >
|
||||
// finished processing: <unwatch: OK>
|
||||
// starting processing: <[hello 3]>
|
||||
// finished processing: <[hello 3]>
|
||||
// finished processing: <[watch foo]>
|
||||
// starting processing: <[ping]>
|
||||
// finished processing: <[ping]>
|
||||
// starting processing: <[ping]>
|
||||
// finished processing: <[ping]>
|
||||
// starting processing: <[unwatch]>
|
||||
// finished processing: <[unwatch]>
|
||||
}
|
||||
|
@ -23,6 +23,8 @@ type Conn struct {
|
||||
Inited bool
|
||||
pooled bool
|
||||
createdAt time.Time
|
||||
|
||||
onClose func() error
|
||||
}
|
||||
|
||||
func NewConn(netConn net.Conn) *Conn {
|
||||
@ -46,6 +48,10 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
|
||||
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
||||
}
|
||||
|
||||
func (cn *Conn) SetOnClose(fn func() error) {
|
||||
cn.onClose = fn
|
||||
}
|
||||
|
||||
func (cn *Conn) SetNetConn(netConn net.Conn) {
|
||||
cn.netConn = netConn
|
||||
cn.rd.Reset(netConn)
|
||||
@ -95,6 +101,10 @@ func (cn *Conn) WithWriter(
|
||||
}
|
||||
|
||||
func (cn *Conn) Close() error {
|
||||
if cn.onClose != nil {
|
||||
// ignore error
|
||||
_ = cn.onClose()
|
||||
}
|
||||
return cn.netConn.Close()
|
||||
}
|
||||
|
||||
|
@ -212,10 +212,10 @@ func TestRingShardsCleanup(t *testing.T) {
|
||||
},
|
||||
NewClient: func(opt *Options) *Client {
|
||||
c := NewClient(opt)
|
||||
c.baseClient.onClose = func() error {
|
||||
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
|
||||
closeCounter.increment(opt.Addr)
|
||||
return nil
|
||||
}
|
||||
})
|
||||
return c
|
||||
},
|
||||
})
|
||||
@ -261,10 +261,10 @@ func TestRingShardsCleanup(t *testing.T) {
|
||||
}
|
||||
createCounter.increment(opt.Addr)
|
||||
c := NewClient(opt)
|
||||
c.baseClient.onClose = func() error {
|
||||
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
|
||||
closeCounter.increment(opt.Addr)
|
||||
return nil
|
||||
}
|
||||
})
|
||||
return c
|
||||
},
|
||||
})
|
||||
|
129
options.go
129
options.go
@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
@ -29,10 +30,13 @@ type Limiter interface {
|
||||
|
||||
// Options keeps the settings to set up redis connection.
|
||||
type Options struct {
|
||||
// The network type, either tcp or unix.
|
||||
// Default is tcp.
|
||||
|
||||
// Network type, either tcp or unix.
|
||||
//
|
||||
// default: is tcp.
|
||||
Network string
|
||||
// host:port address.
|
||||
|
||||
// Addr is the address formated as host:port
|
||||
Addr string
|
||||
|
||||
// ClientName will execute the `CLIENT SETNAME ClientName` command for each conn.
|
||||
@ -46,17 +50,21 @@ type Options struct {
|
||||
OnConnect func(ctx context.Context, cn *Conn) error
|
||||
|
||||
// Protocol 2 or 3. Use the version to negotiate RESP version with redis-server.
|
||||
// Default is 3.
|
||||
//
|
||||
// default: 3.
|
||||
Protocol int
|
||||
// Use the specified Username to authenticate the current connection
|
||||
|
||||
// Username is used to authenticate the current connection
|
||||
// with one of the connections defined in the ACL list when connecting
|
||||
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
|
||||
Username string
|
||||
// Optional password. Must match the password specified in the
|
||||
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
|
||||
|
||||
// Password is an optional password. Must match the password specified in the
|
||||
// `requirepass` server configuration option (if connecting to a Redis 5.0 instance, or lower),
|
||||
// or the User Password when connecting to a Redis 6.0 instance, or greater,
|
||||
// that is using the Redis ACL system.
|
||||
Password string
|
||||
|
||||
// CredentialsProvider allows the username and password to be updated
|
||||
// before reconnecting. It should return the current username and password.
|
||||
CredentialsProvider func() (username string, password string)
|
||||
@ -67,85 +75,126 @@ type Options struct {
|
||||
// There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider.
|
||||
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
|
||||
|
||||
// Database to be selected after connecting to the server.
|
||||
// StreamingCredentialsProvider is used to retrieve the credentials
|
||||
// for the connection from an external source. Those credentials may change
|
||||
// during the connection lifetime. This is useful for managed identity
|
||||
// scenarios where the credentials are retrieved from an external source.
|
||||
//
|
||||
// Currently, this is a placeholder for the future implementation.
|
||||
StreamingCredentialsProvider auth.StreamingCredentialsProvider
|
||||
|
||||
// DB is the database to be selected after connecting to the server.
|
||||
DB int
|
||||
|
||||
// Maximum number of retries before giving up.
|
||||
// Default is 3 retries; -1 (not 0) disables retries.
|
||||
// MaxRetries is the maximum number of retries before giving up.
|
||||
// -1 (not 0) disables retries.
|
||||
//
|
||||
// default: 3 retries
|
||||
MaxRetries int
|
||||
// Minimum backoff between each retry.
|
||||
// Default is 8 milliseconds; -1 disables backoff.
|
||||
|
||||
// MinRetryBackoff is the minimum backoff between each retry.
|
||||
// -1 disables backoff.
|
||||
//
|
||||
// default: 8 milliseconds
|
||||
MinRetryBackoff time.Duration
|
||||
// Maximum backoff between each retry.
|
||||
// Default is 512 milliseconds; -1 disables backoff.
|
||||
|
||||
// MaxRetryBackoff is the maximum backoff between each retry.
|
||||
// -1 disables backoff.
|
||||
// default: 512 milliseconds;
|
||||
MaxRetryBackoff time.Duration
|
||||
|
||||
// Dial timeout for establishing new connections.
|
||||
// Default is 5 seconds.
|
||||
// DialTimeout for establishing new connections.
|
||||
//
|
||||
// default: 5 seconds
|
||||
DialTimeout time.Duration
|
||||
// Timeout for socket reads. If reached, commands will fail
|
||||
|
||||
// ReadTimeout for socket reads. If reached, commands will fail
|
||||
// with a timeout instead of blocking. Supported values:
|
||||
// - `0` - default timeout (3 seconds).
|
||||
// - `-1` - no timeout (block indefinitely).
|
||||
// - `-2` - disables SetReadDeadline calls completely.
|
||||
//
|
||||
// - `-1` - no timeout (block indefinitely).
|
||||
// - `-2` - disables SetReadDeadline calls completely.
|
||||
//
|
||||
// default: 3 seconds
|
||||
ReadTimeout time.Duration
|
||||
// Timeout for socket writes. If reached, commands will fail
|
||||
|
||||
// WriteTimeout for socket writes. If reached, commands will fail
|
||||
// with a timeout instead of blocking. Supported values:
|
||||
// - `0` - default timeout (3 seconds).
|
||||
// - `-1` - no timeout (block indefinitely).
|
||||
// - `-2` - disables SetWriteDeadline calls completely.
|
||||
//
|
||||
// - `-1` - no timeout (block indefinitely).
|
||||
// - `-2` - disables SetWriteDeadline calls completely.
|
||||
//
|
||||
// default: 3 seconds
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// ContextTimeoutEnabled controls whether the client respects context timeouts and deadlines.
|
||||
// See https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts
|
||||
ContextTimeoutEnabled bool
|
||||
|
||||
// Type of connection pool.
|
||||
// true for FIFO pool, false for LIFO pool.
|
||||
// PoolFIFO type of connection pool.
|
||||
//
|
||||
// - true for FIFO pool
|
||||
// - false for LIFO pool.
|
||||
//
|
||||
// Note that FIFO has slightly higher overhead compared to LIFO,
|
||||
// but it helps closing idle connections faster reducing the pool size.
|
||||
PoolFIFO bool
|
||||
// Base number of socket connections.
|
||||
|
||||
// PoolSize is the base number of socket connections.
|
||||
// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS.
|
||||
// If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize,
|
||||
// you can limit it through MaxActiveConns
|
||||
//
|
||||
// default: 10 * runtime.GOMAXPROCS(0)
|
||||
PoolSize int
|
||||
// Amount of time client waits for connection if all connections
|
||||
|
||||
// PoolTimeout is the amount of time client waits for connection if all connections
|
||||
// are busy before returning an error.
|
||||
// Default is ReadTimeout + 1 second.
|
||||
//
|
||||
// default: ReadTimeout + 1 second
|
||||
PoolTimeout time.Duration
|
||||
// Minimum number of idle connections which is useful when establishing
|
||||
// new connection is slow.
|
||||
// Default is 0. the idle connections are not closed by default.
|
||||
|
||||
// MinIdleConns is the minimum number of idle connections which is useful when establishing
|
||||
// new connection is slow. The idle connections are not closed by default.
|
||||
//
|
||||
// default: 0
|
||||
MinIdleConns int
|
||||
// Maximum number of idle connections.
|
||||
// Default is 0. the idle connections are not closed by default.
|
||||
|
||||
// MaxIdleConns is the maximum number of idle connections.
|
||||
// The idle connections are not closed by default.
|
||||
//
|
||||
// default: 0
|
||||
MaxIdleConns int
|
||||
// Maximum number of connections allocated by the pool at a given time.
|
||||
|
||||
// MaxActiveConns is the maximum number of connections allocated by the pool at a given time.
|
||||
// When zero, there is no limit on the number of connections in the pool.
|
||||
// If the pool is full, the next call to Get() will block until a connection is released.
|
||||
MaxActiveConns int
|
||||
|
||||
// ConnMaxIdleTime is the maximum amount of time a connection may be idle.
|
||||
// Should be less than server's timeout.
|
||||
//
|
||||
// Expired connections may be closed lazily before reuse.
|
||||
// If d <= 0, connections are not closed due to a connection's idle time.
|
||||
// -1 disables idle timeout check.
|
||||
//
|
||||
// Default is 30 minutes. -1 disables idle timeout check.
|
||||
// default: 30 minutes
|
||||
ConnMaxIdleTime time.Duration
|
||||
|
||||
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
|
||||
//
|
||||
// Expired connections may be closed lazily before reuse.
|
||||
// If <= 0, connections are not closed due to a connection's age.
|
||||
//
|
||||
// Default is to not close idle connections.
|
||||
// default: 0
|
||||
ConnMaxLifetime time.Duration
|
||||
|
||||
// TLS Config to use. When set, TLS will be negotiated.
|
||||
// TLSConfig to use. When set, TLS will be negotiated.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// Limiter interface used to implement circuit breaker or rate limiter.
|
||||
Limiter Limiter
|
||||
|
||||
// Enables read only queries on slave/follower nodes.
|
||||
// readOnly enables read only queries on slave/follower nodes.
|
||||
readOnly bool
|
||||
|
||||
// DisableIndentity - Disable set-lib on connect.
|
||||
@ -161,9 +210,11 @@ type Options struct {
|
||||
DisableIdentity bool
|
||||
|
||||
// Add suffix to client name. Default is empty.
|
||||
// IdentitySuffix - add suffix to client name.
|
||||
IdentitySuffix string
|
||||
|
||||
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
|
||||
// When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult
|
||||
UnstableResp3 bool
|
||||
}
|
||||
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/hashtag"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
@ -66,11 +67,12 @@ type ClusterOptions struct {
|
||||
|
||||
OnConnect func(ctx context.Context, cn *Conn) error
|
||||
|
||||
Protocol int
|
||||
Username string
|
||||
Password string
|
||||
CredentialsProvider func() (username string, password string)
|
||||
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
|
||||
Protocol int
|
||||
Username string
|
||||
Password string
|
||||
CredentialsProvider func() (username string, password string)
|
||||
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
|
||||
StreamingCredentialsProvider auth.StreamingCredentialsProvider
|
||||
|
||||
MaxRetries int
|
||||
MinRetryBackoff time.Duration
|
||||
@ -292,11 +294,12 @@ func (opt *ClusterOptions) clientOptions() *Options {
|
||||
Dialer: opt.Dialer,
|
||||
OnConnect: opt.OnConnect,
|
||||
|
||||
Protocol: opt.Protocol,
|
||||
Username: opt.Username,
|
||||
Password: opt.Password,
|
||||
CredentialsProvider: opt.CredentialsProvider,
|
||||
CredentialsProviderContext: opt.CredentialsProviderContext,
|
||||
Protocol: opt.Protocol,
|
||||
Username: opt.Username,
|
||||
Password: opt.Password,
|
||||
CredentialsProvider: opt.CredentialsProvider,
|
||||
CredentialsProviderContext: opt.CredentialsProviderContext,
|
||||
StreamingCredentialsProvider: opt.StreamingCredentialsProvider,
|
||||
|
||||
MaxRetries: opt.MaxRetries,
|
||||
MinRetryBackoff: opt.MinRetryBackoff,
|
||||
|
@ -89,6 +89,9 @@ func (s *clusterScenario) newClusterClient(
|
||||
func (s *clusterScenario) Close() error {
|
||||
ctx := context.TODO()
|
||||
for _, master := range s.masters() {
|
||||
if master == nil {
|
||||
continue
|
||||
}
|
||||
err := master.FlushAll(ctx).Err()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -298,7 +298,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
|
||||
})
|
||||
|
||||
It("should CFCount", Label("cuckoo", "cfcount"), func() {
|
||||
err := client.CFAdd(ctx, "testcf1", "item1").Err()
|
||||
client.CFAdd(ctx, "testcf1", "item1")
|
||||
cnt, err := client.CFCount(ctx, "testcf1", "item1").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cnt).To(BeEquivalentTo(int64(1)))
|
||||
@ -394,7 +394,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
|
||||
NoCreate: true,
|
||||
}
|
||||
|
||||
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
||||
_, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
||||
Expect(err).To(HaveOccurred())
|
||||
|
||||
args = &redis.CFInsertOptions{
|
||||
@ -402,7 +402,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
|
||||
NoCreate: false,
|
||||
}
|
||||
|
||||
result, err = client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
||||
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(len(result)).To(BeEquivalentTo(3))
|
||||
})
|
||||
|
152
redis.go
152
redis.go
@ -9,6 +9,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/hscan"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
@ -203,6 +204,7 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
|
||||
type baseClient struct {
|
||||
opt *Options
|
||||
connPool pool.Pooler
|
||||
hooksMixin
|
||||
|
||||
onClose func() error // hook called when client is closed
|
||||
}
|
||||
@ -282,30 +284,107 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
|
||||
return auth.NewReAuthCredentialsListener(
|
||||
c.reAuthConnection(poolCn),
|
||||
c.onAuthenticationErr(poolCn),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
|
||||
return func(credentials auth.Credentials) error {
|
||||
var err error
|
||||
username, password := credentials.BasicAuth()
|
||||
ctx := context.Background()
|
||||
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
|
||||
// hooksMixin are intentionally empty here
|
||||
cn := newConn(c.opt, connPool, nil)
|
||||
|
||||
if username != "" {
|
||||
err = cn.AuthACL(ctx, username, password).Err()
|
||||
} else {
|
||||
err = cn.Auth(ctx, password).Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
|
||||
return func(err error) {
|
||||
if err != nil {
|
||||
if isBadConn(err, false, c.opt.Addr) {
|
||||
// Close the connection to force a reconnection.
|
||||
err := c.connPool.CloseConn(poolCn)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(context.Background(), "redis: failed to close connection: %v", err)
|
||||
// try to close the network connection directly
|
||||
// so that no resource is leaked
|
||||
err := poolCn.Close()
|
||||
if err != nil {
|
||||
internal.Logger.Printf(context.Background(), "redis: failed to close network connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
|
||||
onClose := c.onClose
|
||||
return func() error {
|
||||
var firstErr error
|
||||
err := newOnClose()
|
||||
// Even if we have an error we would like to execute the onClose hook
|
||||
// if it exists. We will return the first error that occurred.
|
||||
// This is to keep error handling consistent with the rest of the code.
|
||||
if err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
if onClose != nil {
|
||||
err = onClose()
|
||||
if err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
if cn.Inited {
|
||||
return nil
|
||||
}
|
||||
cn.Inited = true
|
||||
|
||||
var err error
|
||||
username, password := c.opt.Username, c.opt.Password
|
||||
if c.opt.CredentialsProviderContext != nil {
|
||||
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil {
|
||||
return err
|
||||
cn.Inited = true
|
||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||
conn := newConn(c.opt, connPool, &c.hooksMixin)
|
||||
|
||||
username, password := "", ""
|
||||
if c.opt.StreamingCredentialsProvider != nil {
|
||||
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||
Subscribe(c.newReAuthCredentialsListener(cn))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
||||
}
|
||||
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
|
||||
cn.SetOnClose(unsubscribeFromCredentialsProvider)
|
||||
username, password = credentials.BasicAuth()
|
||||
} else if c.opt.CredentialsProviderContext != nil {
|
||||
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get credentials from context provider: %w", err)
|
||||
}
|
||||
} else if c.opt.CredentialsProvider != nil {
|
||||
username, password = c.opt.CredentialsProvider()
|
||||
} else if c.opt.Username != "" || c.opt.Password != "" {
|
||||
username, password = c.opt.Username, c.opt.Password
|
||||
}
|
||||
|
||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||
conn := newConn(c.opt, connPool)
|
||||
|
||||
var auth bool
|
||||
// for redis-server versions that do not support the HELLO command,
|
||||
// RESP2 will continue to be used.
|
||||
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
||||
auth = true
|
||||
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
||||
// Authentication successful with HELLO command
|
||||
} else if !isRedisError(err) {
|
||||
// When the server responds with the RESP protocol and the result is not a normal
|
||||
// execution result of the HELLO command, we consider it to be an indication that
|
||||
@ -315,17 +394,19 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
// with different error string results for unsupported commands, making it
|
||||
// difficult to rely on error strings to determine all results.
|
||||
return err
|
||||
} else if password != "" {
|
||||
// Try legacy AUTH command if HELLO failed
|
||||
if username != "" {
|
||||
err = conn.AuthACL(ctx, username, password).Err()
|
||||
} else {
|
||||
err = conn.Auth(ctx, password).Err()
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
||||
if !auth && password != "" {
|
||||
if username != "" {
|
||||
pipe.AuthACL(ctx, username, password)
|
||||
} else {
|
||||
pipe.Auth(ctx, password)
|
||||
}
|
||||
}
|
||||
|
||||
if c.opt.DB > 0 {
|
||||
pipe.Select(ctx, c.opt.DB)
|
||||
}
|
||||
@ -341,7 +422,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to initialize connection options: %w", err)
|
||||
}
|
||||
|
||||
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
|
||||
@ -363,6 +444,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
if c.opt.OnConnect != nil {
|
||||
return c.opt.OnConnect(ctx, conn)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -481,6 +563,16 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
|
||||
return c.opt.ReadTimeout
|
||||
}
|
||||
|
||||
// context returns the context for the current connection.
|
||||
// If the context timeout is enabled, it returns the original context.
|
||||
// Otherwise, it returns a new background context.
|
||||
func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
if c.opt.ContextTimeoutEnabled {
|
||||
return ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// Close closes the client, releasing any open resources.
|
||||
//
|
||||
// It is rare to Close a Client, as the Client is meant to be
|
||||
@ -633,13 +725,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
if c.opt.ContextTimeoutEnabled {
|
||||
return ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Client is a Redis client representing a pool of zero or more underlying connections.
|
||||
@ -650,7 +735,6 @@ func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
type Client struct {
|
||||
*baseClient
|
||||
cmdable
|
||||
hooksMixin
|
||||
}
|
||||
|
||||
// NewClient returns a client to the Redis Server specified by Options.
|
||||
@ -689,7 +773,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
|
||||
}
|
||||
|
||||
func (c *Client) Conn() *Conn {
|
||||
return newConn(c.opt, pool.NewStickyConnPool(c.connPool))
|
||||
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
|
||||
}
|
||||
|
||||
// Do create a Cmd from the args and processes the cmd.
|
||||
@ -822,10 +906,12 @@ type Conn struct {
|
||||
baseClient
|
||||
cmdable
|
||||
statefulCmdable
|
||||
hooksMixin
|
||||
}
|
||||
|
||||
func newConn(opt *Options, connPool pool.Pooler) *Conn {
|
||||
// newConn is a helper func to create a new Conn instance.
|
||||
// the Conn instance is not thread-safe and should not be shared between goroutines.
|
||||
// the parentHooks will be cloned, no need to clone before passing it.
|
||||
func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn {
|
||||
c := Conn{
|
||||
baseClient: baseClient{
|
||||
opt: opt,
|
||||
@ -833,6 +919,10 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
|
||||
},
|
||||
}
|
||||
|
||||
if parentHooks != nil {
|
||||
c.hooksMixin = parentHooks.clone()
|
||||
}
|
||||
|
||||
c.cmdable = c.Process
|
||||
c.statefulCmdable = c.Process
|
||||
c.initHooks(hooks{
|
||||
|
169
redis_test.go
169
redis_test.go
@ -14,6 +14,7 @@ import (
|
||||
. "github.com/bsm/gomega"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
)
|
||||
|
||||
type redisHookError struct{}
|
||||
@ -727,6 +728,174 @@ var _ = Describe("Dialer connection timeouts", func() {
|
||||
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Credentials Provider Priority", func() {
|
||||
var client *redis.Client
|
||||
var opt *redis.Options
|
||||
var recorder *commandRecorder
|
||||
|
||||
BeforeEach(func() {
|
||||
recorder = newCommandRecorder(10)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
if client != nil {
|
||||
Expect(client.Close()).NotTo(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
It("should use streaming provider when available", func() {
|
||||
streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
|
||||
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
|
||||
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||
|
||||
opt = &redis.Options{
|
||||
Username: "field_user",
|
||||
Password: "field_pass",
|
||||
CredentialsProvider: func() (string, string) {
|
||||
username, password := providerCreds.BasicAuth()
|
||||
return username, password
|
||||
},
|
||||
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
|
||||
username, password := ctxCreds.BasicAuth()
|
||||
return username, password, nil
|
||||
},
|
||||
StreamingCredentialsProvider: &mockStreamingProvider{
|
||||
credentials: streamingCreds,
|
||||
updates: make(chan auth.Credentials, 1),
|
||||
},
|
||||
}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should use context provider when streaming provider is not available", func() {
|
||||
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
|
||||
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||
|
||||
opt = &redis.Options{
|
||||
Username: "field_user",
|
||||
Password: "field_pass",
|
||||
CredentialsProvider: func() (string, string) {
|
||||
username, password := providerCreds.BasicAuth()
|
||||
return username, password
|
||||
},
|
||||
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
|
||||
username, password := ctxCreds.BasicAuth()
|
||||
return username, password, nil
|
||||
},
|
||||
}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should use regular provider when streaming and context providers are not available", func() {
|
||||
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||
|
||||
opt = &redis.Options{
|
||||
Username: "field_user",
|
||||
Password: "field_pass",
|
||||
CredentialsProvider: func() (string, string) {
|
||||
username, password := providerCreds.BasicAuth()
|
||||
return username, password
|
||||
},
|
||||
}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should use username/password fields when no providers are set", func() {
|
||||
opt = &redis.Options{
|
||||
Username: "field_user",
|
||||
Password: "field_pass",
|
||||
}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should use empty credentials when nothing is set", func() {
|
||||
opt = &redis.Options{}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// no pass, ok
|
||||
Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH")).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should handle credential updates from streaming provider", func() {
|
||||
initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
|
||||
updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
|
||||
updatesChan := make(chan auth.Credentials, 1)
|
||||
|
||||
opt = &redis.Options{
|
||||
StreamingCredentialsProvider: &mockStreamingProvider{
|
||||
credentials: initialCreds,
|
||||
updates: updatesChan,
|
||||
},
|
||||
}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
|
||||
|
||||
// Update credentials
|
||||
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
|
||||
close(updatesChan)
|
||||
})
|
||||
})
|
||||
|
||||
type mockStreamingProvider struct {
|
||||
credentials auth.Credentials
|
||||
err error
|
||||
updates chan auth.Credentials
|
||||
}
|
||||
|
||||
func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
|
||||
if m.err != nil {
|
||||
return nil, nil, m.err
|
||||
}
|
||||
|
||||
// Start goroutine to handle updates
|
||||
go func() {
|
||||
for creds := range m.updates {
|
||||
m.credentials = creds
|
||||
listener.OnNext(creds)
|
||||
}
|
||||
}()
|
||||
|
||||
return m.credentials, func() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// this is just a mock:
|
||||
// allow multiple closes from multiple listeners
|
||||
}
|
||||
}()
|
||||
return
|
||||
}, nil
|
||||
}
|
||||
|
||||
var _ = Describe("Client creation", func() {
|
||||
Context("simple client with nil options", func() {
|
||||
It("panics", func() {
|
||||
|
32
ring_test.go
32
ring_test.go
@ -357,13 +357,17 @@ var _ = Describe("Redis Ring", func() {
|
||||
ring.AddHook(&hook{
|
||||
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||
Expect(cmds).To(HaveLen(1))
|
||||
// skip the connection initialization
|
||||
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
|
||||
return nil
|
||||
}
|
||||
Expect(len(cmds)).To(BeNumerically(">", 0))
|
||||
Expect(cmds[0].String()).To(Equal("ping: "))
|
||||
stack = append(stack, "ring.BeforeProcessPipeline")
|
||||
|
||||
err := hook(ctx, cmds)
|
||||
|
||||
Expect(cmds).To(HaveLen(1))
|
||||
Expect(len(cmds)).To(BeNumerically(">", 0))
|
||||
Expect(cmds[0].String()).To(Equal("ping: PONG"))
|
||||
stack = append(stack, "ring.AfterProcessPipeline")
|
||||
|
||||
@ -376,13 +380,17 @@ var _ = Describe("Redis Ring", func() {
|
||||
shard.AddHook(&hook{
|
||||
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||
Expect(cmds).To(HaveLen(1))
|
||||
// skip the connection initialization
|
||||
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
|
||||
return nil
|
||||
}
|
||||
Expect(len(cmds)).To(BeNumerically(">", 0))
|
||||
Expect(cmds[0].String()).To(Equal("ping: "))
|
||||
stack = append(stack, "shard.BeforeProcessPipeline")
|
||||
|
||||
err := hook(ctx, cmds)
|
||||
|
||||
Expect(cmds).To(HaveLen(1))
|
||||
Expect(len(cmds)).To(BeNumerically(">", 0))
|
||||
Expect(cmds[0].String()).To(Equal("ping: PONG"))
|
||||
stack = append(stack, "shard.AfterProcessPipeline")
|
||||
|
||||
@ -416,14 +424,18 @@ var _ = Describe("Redis Ring", func() {
|
||||
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||
defer GinkgoRecover()
|
||||
// skip the connection initialization
|
||||
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
|
||||
return nil
|
||||
}
|
||||
|
||||
Expect(cmds).To(HaveLen(3))
|
||||
Expect(len(cmds)).To(BeNumerically(">=", 3))
|
||||
Expect(cmds[1].String()).To(Equal("ping: "))
|
||||
stack = append(stack, "ring.BeforeProcessPipeline")
|
||||
|
||||
err := hook(ctx, cmds)
|
||||
|
||||
Expect(cmds).To(HaveLen(3))
|
||||
Expect(len(cmds)).To(BeNumerically(">=", 3))
|
||||
Expect(cmds[1].String()).To(Equal("ping: PONG"))
|
||||
stack = append(stack, "ring.AfterProcessPipeline")
|
||||
|
||||
@ -437,14 +449,18 @@ var _ = Describe("Redis Ring", func() {
|
||||
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||
defer GinkgoRecover()
|
||||
// skip the connection initialization
|
||||
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
|
||||
return nil
|
||||
}
|
||||
|
||||
Expect(cmds).To(HaveLen(3))
|
||||
Expect(len(cmds)).To(BeNumerically(">=", 3))
|
||||
Expect(cmds[1].String()).To(Equal("ping: "))
|
||||
stack = append(stack, "shard.BeforeProcessPipeline")
|
||||
|
||||
err := hook(ctx, cmds)
|
||||
|
||||
Expect(cmds).To(HaveLen(3))
|
||||
Expect(len(cmds)).To(BeNumerically(">=", 3))
|
||||
Expect(cmds[1].String()).To(Equal("ping: PONG"))
|
||||
stack = append(stack, "shard.AfterProcessPipeline")
|
||||
|
||||
|
@ -404,7 +404,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
|
||||
|
||||
connPool = newConnPool(opt, rdb.dialHook)
|
||||
rdb.connPool = connPool
|
||||
rdb.onClose = failover.Close
|
||||
rdb.onClose = rdb.wrappedOnClose(failover.Close)
|
||||
|
||||
failover.mu.Lock()
|
||||
failover.onFailover = func(ctx context.Context, addr string) {
|
||||
@ -455,7 +455,6 @@ func masterReplicaDialer(
|
||||
// SentinelClient is a client for a Redis Sentinel.
|
||||
type SentinelClient struct {
|
||||
*baseClient
|
||||
hooksMixin
|
||||
}
|
||||
|
||||
func NewSentinelClient(opt *Options) *SentinelClient {
|
||||
|
7
tx.go
7
tx.go
@ -19,16 +19,15 @@ type Tx struct {
|
||||
baseClient
|
||||
cmdable
|
||||
statefulCmdable
|
||||
hooksMixin
|
||||
}
|
||||
|
||||
func (c *Client) newTx() *Tx {
|
||||
tx := Tx{
|
||||
baseClient: baseClient{
|
||||
opt: c.opt,
|
||||
connPool: pool.NewStickyConnPool(c.connPool),
|
||||
opt: c.opt,
|
||||
connPool: pool.NewStickyConnPool(c.connPool),
|
||||
hooksMixin: c.hooksMixin.clone(),
|
||||
},
|
||||
hooksMixin: c.hooksMixin.clone(),
|
||||
}
|
||||
tx.init()
|
||||
return &tx
|
||||
|
Reference in New Issue
Block a user