mirror of
https://github.com/redis/go-redis.git
synced 2025-07-28 06:42:00 +03:00
feat: Introducing StreamingCredentialsProvider for token based authentication (#3320)
* wip * update documentation * add streamingcredentialsprovider in options * fix: put back option in pool creation * add package level comment * Initial re authentication implementation Introduces the StreamingCredentialsProvider as the CredentialsProvider with the highest priority. TODO: needs to be tested * Change function type name Change CancelProviderFunc to UnsubscribeFunc * add tests * fix race in tests * fix example tests * wip, hooks refactor * fix build * update README.md * update wordlist * update README.md * refactor(auth): early returns in cred listener * fix(doctest): simulate some delay * feat(conn): add close hook on conn * fix(tests): simulate start/stop in mock credentials provider * fix(auth): don't double close the conn * docs(README): mark streaming credentials provider as experimental * fix(auth): streamline auth err proccess * fix(auth): check err on close conn * chore(entraid): use the repo under redis org
This commit is contained in:
8
.github/wordlist.txt
vendored
8
.github/wordlist.txt
vendored
@ -66,3 +66,11 @@ RedisTimeseries
|
|||||||
RediSearch
|
RediSearch
|
||||||
RawResult
|
RawResult
|
||||||
RawVal
|
RawVal
|
||||||
|
entra
|
||||||
|
EntraID
|
||||||
|
Entra
|
||||||
|
OAuth
|
||||||
|
Azure
|
||||||
|
StreamingCredentialsProvider
|
||||||
|
oauth
|
||||||
|
entraid
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -8,3 +8,4 @@ redis8tests.sh
|
|||||||
coverage.txt
|
coverage.txt
|
||||||
**/coverage.txt
|
**/coverage.txt
|
||||||
.vscode
|
.vscode
|
||||||
|
tmp/*
|
||||||
|
121
README.md
121
README.md
@ -68,6 +68,7 @@ key value NoSQL database that uses RocksDB as storage engine and is compatible w
|
|||||||
|
|
||||||
- Redis commands except QUIT and SYNC.
|
- Redis commands except QUIT and SYNC.
|
||||||
- Automatic connection pooling.
|
- Automatic connection pooling.
|
||||||
|
- [StreamingCredentialsProvider (e.g. entra id, oauth)](#1-streaming-credentials-provider-highest-priority) (experimental)
|
||||||
- [Pub/Sub](https://redis.uptrace.dev/guide/go-redis-pubsub.html).
|
- [Pub/Sub](https://redis.uptrace.dev/guide/go-redis-pubsub.html).
|
||||||
- [Pipelines and transactions](https://redis.uptrace.dev/guide/go-redis-pipelines.html).
|
- [Pipelines and transactions](https://redis.uptrace.dev/guide/go-redis-pipelines.html).
|
||||||
- [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html).
|
- [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html).
|
||||||
@ -136,17 +137,121 @@ func ExampleClient() {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
The above can be modified to specify the version of the RESP protocol by adding the `protocol`
|
### Authentication
|
||||||
option to the `Options` struct:
|
|
||||||
|
The Redis client supports multiple ways to provide authentication credentials, with a clear priority order. Here are the available options:
|
||||||
|
|
||||||
|
#### 1. Streaming Credentials Provider (Highest Priority) - Experimental feature
|
||||||
|
|
||||||
|
The streaming credentials provider allows for dynamic credential updates during the connection lifetime. This is particularly useful for managed identity services and token-based authentication.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
rdb := redis.NewClient(&redis.Options{
|
type StreamingCredentialsProvider interface {
|
||||||
Addr: "localhost:6379",
|
Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error)
|
||||||
Password: "", // no password set
|
}
|
||||||
DB: 0, // use default DB
|
|
||||||
Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3
|
|
||||||
})
|
|
||||||
|
|
||||||
|
type CredentialsListener interface {
|
||||||
|
OnNext(credentials Credentials) // Called when credentials are updated
|
||||||
|
OnError(err error) // Called when an error occurs
|
||||||
|
}
|
||||||
|
|
||||||
|
type Credentials interface {
|
||||||
|
BasicAuth() (username string, password string)
|
||||||
|
RawCredentials() string
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```go
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "localhost:6379",
|
||||||
|
StreamingCredentialsProvider: &MyCredentialsProvider{},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note:** The streaming credentials provider can be used with [go-redis-entraid](https://github.com/redis/go-redis-entraid) to enable Entra ID (formerly Azure AD) authentication. This allows for seamless integration with Azure's managed identity services and token-based authentication.
|
||||||
|
|
||||||
|
Example with Entra ID:
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/redis/go-redis-entraid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create an Entra ID credentials provider
|
||||||
|
provider := entraid.NewDefaultAzureIdentityProvider()
|
||||||
|
|
||||||
|
// Configure Redis client with Entra ID authentication
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "your-redis-server.redis.cache.windows.net:6380",
|
||||||
|
StreamingCredentialsProvider: provider,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Context-based Credentials Provider
|
||||||
|
|
||||||
|
The context-based provider allows credentials to be determined at the time of each operation, using the context.
|
||||||
|
|
||||||
|
```go
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "localhost:6379",
|
||||||
|
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
|
||||||
|
// Return username, password, and any error
|
||||||
|
return "user", "pass", nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. Regular Credentials Provider
|
||||||
|
|
||||||
|
A simple function-based provider that returns static credentials.
|
||||||
|
|
||||||
|
```go
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "localhost:6379",
|
||||||
|
CredentialsProvider: func() (string, string) {
|
||||||
|
// Return username and password
|
||||||
|
return "user", "pass"
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4. Username/Password Fields (Lowest Priority)
|
||||||
|
|
||||||
|
The most basic way to provide credentials is through the `Username` and `Password` fields in the options.
|
||||||
|
|
||||||
|
```go
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "localhost:6379",
|
||||||
|
Username: "user",
|
||||||
|
Password: "pass",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Priority Order
|
||||||
|
|
||||||
|
The client will use credentials in the following priority order:
|
||||||
|
1. Streaming Credentials Provider (if set)
|
||||||
|
2. Context-based Credentials Provider (if set)
|
||||||
|
3. Regular Credentials Provider (if set)
|
||||||
|
4. Username/Password fields (if set)
|
||||||
|
|
||||||
|
If none of these are set, the client will attempt to connect without authentication.
|
||||||
|
|
||||||
|
### Protocol Version
|
||||||
|
|
||||||
|
The client supports both RESP2 and RESP3 protocols. You can specify the protocol version in the options:
|
||||||
|
|
||||||
|
```go
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "localhost:6379",
|
||||||
|
Password: "", // no password set
|
||||||
|
DB: 0, // use default DB
|
||||||
|
Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3
|
||||||
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
### Connecting via a redis url
|
### Connecting via a redis url
|
||||||
|
61
auth/auth.go
Normal file
61
auth/auth.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
// Package auth package provides authentication-related interfaces and types.
|
||||||
|
// It also includes a basic implementation of credentials using username and password.
|
||||||
|
package auth
|
||||||
|
|
||||||
|
// StreamingCredentialsProvider is an interface that defines the methods for a streaming credentials provider.
|
||||||
|
// It is used to provide credentials for authentication.
|
||||||
|
// The CredentialsListener is used to receive updates when the credentials change.
|
||||||
|
type StreamingCredentialsProvider interface {
|
||||||
|
// Subscribe subscribes to the credentials provider for updates.
|
||||||
|
// It returns the current credentials, a cancel function to unsubscribe from the provider,
|
||||||
|
// and an error if any.
|
||||||
|
// TODO(ndyakov): Should we add context to the Subscribe method?
|
||||||
|
Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnsubscribeFunc is a function that is used to cancel the subscription to the credentials provider.
|
||||||
|
// It is used to unsubscribe from the provider when the credentials are no longer needed.
|
||||||
|
type UnsubscribeFunc func() error
|
||||||
|
|
||||||
|
// CredentialsListener is an interface that defines the methods for a credentials listener.
|
||||||
|
// It is used to receive updates when the credentials change.
|
||||||
|
// The OnNext method is called when the credentials change.
|
||||||
|
// The OnError method is called when an error occurs while requesting the credentials.
|
||||||
|
type CredentialsListener interface {
|
||||||
|
OnNext(credentials Credentials)
|
||||||
|
OnError(err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Credentials is an interface that defines the methods for credentials.
|
||||||
|
// It is used to provide the credentials for authentication.
|
||||||
|
type Credentials interface {
|
||||||
|
// BasicAuth returns the username and password for basic authentication.
|
||||||
|
BasicAuth() (username string, password string)
|
||||||
|
// RawCredentials returns the raw credentials as a string.
|
||||||
|
// This can be used to extract the username and password from the raw credentials or
|
||||||
|
// additional information if present in the token.
|
||||||
|
RawCredentials() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type basicAuth struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RawCredentials returns the raw credentials as a string.
|
||||||
|
func (b *basicAuth) RawCredentials() string {
|
||||||
|
return b.username + ":" + b.password
|
||||||
|
}
|
||||||
|
|
||||||
|
// BasicAuth returns the username and password for basic authentication.
|
||||||
|
func (b *basicAuth) BasicAuth() (username string, password string) {
|
||||||
|
return b.username, b.password
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBasicCredentials creates a new Credentials object from the given username and password.
|
||||||
|
func NewBasicCredentials(username, password string) Credentials {
|
||||||
|
return &basicAuth{
|
||||||
|
username: username,
|
||||||
|
password: password,
|
||||||
|
}
|
||||||
|
}
|
308
auth/auth_test.go
Normal file
308
auth/auth_test.go
Normal file
@ -0,0 +1,308 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockStreamingProvider struct {
|
||||||
|
credentials Credentials
|
||||||
|
err error
|
||||||
|
updates chan Credentials
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
|
||||||
|
return &mockStreamingProvider{
|
||||||
|
credentials: initialCreds,
|
||||||
|
updates: make(chan Credentials, 10),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, nil, m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send initial credentials
|
||||||
|
listener.OnNext(m.credentials)
|
||||||
|
|
||||||
|
// Start goroutine to handle updates
|
||||||
|
go func() {
|
||||||
|
for creds := range m.updates {
|
||||||
|
listener.OnNext(creds)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return m.credentials, func() error {
|
||||||
|
close(m.updates)
|
||||||
|
return nil
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamingCredentialsProvider(t *testing.T) {
|
||||||
|
t.Run("successful subscription", func(t *testing.T) {
|
||||||
|
initialCreds := NewBasicCredentials("user1", "pass1")
|
||||||
|
provider := newMockStreamingProvider(initialCreds)
|
||||||
|
|
||||||
|
var receivedCreds []Credentials
|
||||||
|
var receivedErrors []error
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
mu.Lock()
|
||||||
|
receivedCreds = append(receivedCreds, creds)
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
receivedErrors = append(receivedErrors, err)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
creds, cancel, err := provider.Subscribe(listener)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cancel == nil {
|
||||||
|
t.Fatal("expected cancel function to be non-nil")
|
||||||
|
}
|
||||||
|
if creds != initialCreds {
|
||||||
|
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
||||||
|
}
|
||||||
|
if len(receivedCreds) != 1 {
|
||||||
|
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
|
||||||
|
}
|
||||||
|
if receivedCreds[0] != initialCreds {
|
||||||
|
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
|
||||||
|
}
|
||||||
|
if len(receivedErrors) != 0 {
|
||||||
|
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send an update
|
||||||
|
newCreds := NewBasicCredentials("user2", "pass2")
|
||||||
|
provider.updates <- newCreds
|
||||||
|
|
||||||
|
// Wait for update to be processed
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
mu.Lock()
|
||||||
|
if len(receivedCreds) != 2 {
|
||||||
|
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
|
||||||
|
}
|
||||||
|
if receivedCreds[1] != newCreds {
|
||||||
|
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
// Cancel subscription
|
||||||
|
if err := cancel(); err != nil {
|
||||||
|
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("subscription error", func(t *testing.T) {
|
||||||
|
provider := &mockStreamingProvider{
|
||||||
|
err: errors.New("subscription failed"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var receivedCreds []Credentials
|
||||||
|
var receivedErrors []error
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
receivedCreds = append(receivedCreds, creds)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
receivedErrors = append(receivedErrors, err)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
creds, cancel, err := provider.Subscribe(listener)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if cancel != nil {
|
||||||
|
t.Fatal("expected cancel function to be nil")
|
||||||
|
}
|
||||||
|
if creds != nil {
|
||||||
|
t.Fatalf("expected nil credentials, got %v", creds)
|
||||||
|
}
|
||||||
|
if len(receivedCreds) != 0 {
|
||||||
|
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
|
||||||
|
}
|
||||||
|
if len(receivedErrors) != 0 {
|
||||||
|
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("re-auth error", func(t *testing.T) {
|
||||||
|
initialCreds := NewBasicCredentials("user1", "pass1")
|
||||||
|
provider := newMockStreamingProvider(initialCreds)
|
||||||
|
|
||||||
|
reauthErr := errors.New("re-auth failed")
|
||||||
|
var receivedErrors []error
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
return reauthErr
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
receivedErrors = append(receivedErrors, err)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
creds, cancel, err := provider.Subscribe(listener)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cancel == nil {
|
||||||
|
t.Fatal("expected cancel function to be non-nil")
|
||||||
|
}
|
||||||
|
if creds != initialCreds {
|
||||||
|
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
||||||
|
}
|
||||||
|
if len(receivedErrors) != 1 {
|
||||||
|
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
|
||||||
|
}
|
||||||
|
if receivedErrors[0] != reauthErr {
|
||||||
|
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cancel(); err != nil {
|
||||||
|
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicCredentials(t *testing.T) {
|
||||||
|
t.Run("basic auth", func(t *testing.T) {
|
||||||
|
creds := NewBasicCredentials("user1", "pass1")
|
||||||
|
username, password := creds.BasicAuth()
|
||||||
|
if username != "user1" {
|
||||||
|
t.Fatalf("expected username 'user1', got '%s'", username)
|
||||||
|
}
|
||||||
|
if password != "pass1" {
|
||||||
|
t.Fatalf("expected password 'pass1', got '%s'", password)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("raw credentials", func(t *testing.T) {
|
||||||
|
creds := NewBasicCredentials("user1", "pass1")
|
||||||
|
raw := creds.RawCredentials()
|
||||||
|
expected := "user1:pass1"
|
||||||
|
if raw != expected {
|
||||||
|
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty username", func(t *testing.T) {
|
||||||
|
creds := NewBasicCredentials("", "pass1")
|
||||||
|
username, password := creds.BasicAuth()
|
||||||
|
if username != "" {
|
||||||
|
t.Fatalf("expected empty username, got '%s'", username)
|
||||||
|
}
|
||||||
|
if password != "pass1" {
|
||||||
|
t.Fatalf("expected password 'pass1', got '%s'", password)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReAuthCredentialsListener(t *testing.T) {
|
||||||
|
t.Run("successful re-auth", func(t *testing.T) {
|
||||||
|
var reAuthCalled bool
|
||||||
|
var onErrCalled bool
|
||||||
|
var receivedCreds Credentials
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
reAuthCalled = true
|
||||||
|
receivedCreds = creds
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
onErrCalled = true
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
creds := NewBasicCredentials("user1", "pass1")
|
||||||
|
listener.OnNext(creds)
|
||||||
|
|
||||||
|
if !reAuthCalled {
|
||||||
|
t.Fatal("expected reAuth to be called")
|
||||||
|
}
|
||||||
|
if onErrCalled {
|
||||||
|
t.Fatal("expected onErr not to be called")
|
||||||
|
}
|
||||||
|
if receivedCreds != creds {
|
||||||
|
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("re-auth error", func(t *testing.T) {
|
||||||
|
var reAuthCalled bool
|
||||||
|
var onErrCalled bool
|
||||||
|
var receivedErr error
|
||||||
|
expectedErr := errors.New("re-auth failed")
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
reAuthCalled = true
|
||||||
|
return expectedErr
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
onErrCalled = true
|
||||||
|
receivedErr = err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
creds := NewBasicCredentials("user1", "pass1")
|
||||||
|
listener.OnNext(creds)
|
||||||
|
|
||||||
|
if !reAuthCalled {
|
||||||
|
t.Fatal("expected reAuth to be called")
|
||||||
|
}
|
||||||
|
if !onErrCalled {
|
||||||
|
t.Fatal("expected onErr to be called")
|
||||||
|
}
|
||||||
|
if receivedErr != expectedErr {
|
||||||
|
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("on error", func(t *testing.T) {
|
||||||
|
var onErrCalled bool
|
||||||
|
var receivedErr error
|
||||||
|
expectedErr := errors.New("provider error")
|
||||||
|
|
||||||
|
listener := NewReAuthCredentialsListener(
|
||||||
|
func(creds Credentials) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(err error) {
|
||||||
|
onErrCalled = true
|
||||||
|
receivedErr = err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
listener.OnError(expectedErr)
|
||||||
|
|
||||||
|
if !onErrCalled {
|
||||||
|
t.Fatal("expected onErr to be called")
|
||||||
|
}
|
||||||
|
if receivedErr != expectedErr {
|
||||||
|
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil callbacks", func(t *testing.T) {
|
||||||
|
listener := NewReAuthCredentialsListener(nil, nil)
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
listener.OnNext(NewBasicCredentials("user1", "pass1"))
|
||||||
|
listener.OnError(errors.New("test error"))
|
||||||
|
})
|
||||||
|
}
|
47
auth/reauth_credentials_listener.go
Normal file
47
auth/reauth_credentials_listener.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
|
||||||
|
// It is used to re-authenticate the credentials when they are updated.
|
||||||
|
// It contains:
|
||||||
|
// - reAuth: a function that takes the new credentials and returns an error if any.
|
||||||
|
// - onErr: a function that takes an error and handles it.
|
||||||
|
type ReAuthCredentialsListener struct {
|
||||||
|
reAuth func(credentials Credentials) error
|
||||||
|
onErr func(err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnNext is called when the credentials are updated.
|
||||||
|
// It calls the reAuth function with the new credentials.
|
||||||
|
// If the reAuth function returns an error, it calls the onErr function with the error.
|
||||||
|
func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) {
|
||||||
|
if c.reAuth == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.reAuth(credentials)
|
||||||
|
if err != nil {
|
||||||
|
c.OnError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnError is called when an error occurs.
|
||||||
|
// It can be called from both the credentials provider and the reAuth function.
|
||||||
|
func (c *ReAuthCredentialsListener) OnError(err error) {
|
||||||
|
if c.onErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.onErr(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener.
|
||||||
|
// Implements the auth.CredentialsListener interface.
|
||||||
|
func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener {
|
||||||
|
return &ReAuthCredentialsListener{
|
||||||
|
reAuth: reAuth,
|
||||||
|
onErr: onErr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
|
||||||
|
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
|
86
command_recorder_test.go
Normal file
86
command_recorder_test.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package redis_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// commandRecorder records the last N commands executed by a Redis client.
|
||||||
|
type commandRecorder struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
commands []string
|
||||||
|
maxSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCommandRecorder creates a new command recorder with the specified maximum size.
|
||||||
|
func newCommandRecorder(maxSize int) *commandRecorder {
|
||||||
|
return &commandRecorder{
|
||||||
|
commands: make([]string, 0, maxSize),
|
||||||
|
maxSize: maxSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record adds a command to the recorder.
|
||||||
|
func (r *commandRecorder) Record(cmd string) {
|
||||||
|
cmd = strings.ToLower(cmd)
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
r.commands = append(r.commands, cmd)
|
||||||
|
if len(r.commands) > r.maxSize {
|
||||||
|
r.commands = r.commands[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastCommands returns a copy of the recorded commands.
|
||||||
|
func (r *commandRecorder) LastCommands() []string {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
return append([]string(nil), r.commands...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains checks if the recorder contains a specific command.
|
||||||
|
func (r *commandRecorder) Contains(cmd string) bool {
|
||||||
|
cmd = strings.ToLower(cmd)
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
for _, c := range r.commands {
|
||||||
|
if strings.Contains(c, cmd) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hook returns a Redis hook that records commands.
|
||||||
|
func (r *commandRecorder) Hook() redis.Hook {
|
||||||
|
return &commandHook{recorder: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
// commandHook implements the redis.Hook interface to record commands.
|
||||||
|
type commandHook struct {
|
||||||
|
recorder *commandRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
|
||||||
|
return func(ctx context.Context, cmd redis.Cmder) error {
|
||||||
|
h.recorder.Record(cmd.String())
|
||||||
|
return next(ctx, cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||||
|
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||||
|
for _, cmd := range cmds {
|
||||||
|
h.recorder.Record(cmd.String())
|
||||||
|
}
|
||||||
|
return next(ctx, cmds)
|
||||||
|
}
|
||||||
|
}
|
@ -5,6 +5,7 @@ package example_commands_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
@ -33,6 +34,7 @@ func ExampleClient_LPush_and_lrange() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(listSize)
|
fmt.Println(listSize)
|
||||||
|
time.Sleep(10 * time.Millisecond) // Simulate some delay
|
||||||
|
|
||||||
value, err := rdb.LRange(ctx, "my_bikes", 0, -1).Result()
|
value, err := rdb.LRange(ctx, "my_bikes", 0, -1).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -23,38 +23,47 @@ func (redisHook) DialHook(hook redis.DialHook) redis.DialHook {
|
|||||||
|
|
||||||
func (redisHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
|
func (redisHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
|
||||||
return func(ctx context.Context, cmd redis.Cmder) error {
|
return func(ctx context.Context, cmd redis.Cmder) error {
|
||||||
fmt.Printf("starting processing: <%s>\n", cmd)
|
fmt.Printf("starting processing: <%v>\n", cmd.Args())
|
||||||
err := hook(ctx, cmd)
|
err := hook(ctx, cmd)
|
||||||
fmt.Printf("finished processing: <%s>\n", cmd)
|
fmt.Printf("finished processing: <%v>\n", cmd.Args())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (redisHook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
func (redisHook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||||
fmt.Printf("pipeline starting processing: %v\n", cmds)
|
names := make([]string, 0, len(cmds))
|
||||||
|
for _, cmd := range cmds {
|
||||||
|
names = append(names, fmt.Sprintf("%v", cmd.Args()))
|
||||||
|
}
|
||||||
|
fmt.Printf("pipeline starting processing: %v\n", names)
|
||||||
err := hook(ctx, cmds)
|
err := hook(ctx, cmds)
|
||||||
fmt.Printf("pipeline finished processing: %v\n", cmds)
|
fmt.Printf("pipeline finished processing: %v\n", names)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Example_instrumentation() {
|
func Example_instrumentation() {
|
||||||
rdb := redis.NewClient(&redis.Options{
|
rdb := redis.NewClient(&redis.Options{
|
||||||
Addr: ":6379",
|
Addr: ":6379",
|
||||||
|
DisableIdentity: true,
|
||||||
})
|
})
|
||||||
rdb.AddHook(redisHook{})
|
rdb.AddHook(redisHook{})
|
||||||
|
|
||||||
rdb.Ping(ctx)
|
rdb.Ping(ctx)
|
||||||
// Output: starting processing: <ping: >
|
// Output:
|
||||||
|
// starting processing: <[ping]>
|
||||||
// dialing tcp :6379
|
// dialing tcp :6379
|
||||||
// finished dialing tcp :6379
|
// finished dialing tcp :6379
|
||||||
// finished processing: <ping: PONG>
|
// starting processing: <[hello 3]>
|
||||||
|
// finished processing: <[hello 3]>
|
||||||
|
// finished processing: <[ping]>
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExamplePipeline_instrumentation() {
|
func ExamplePipeline_instrumentation() {
|
||||||
rdb := redis.NewClient(&redis.Options{
|
rdb := redis.NewClient(&redis.Options{
|
||||||
Addr: ":6379",
|
Addr: ":6379",
|
||||||
|
DisableIdentity: true,
|
||||||
})
|
})
|
||||||
rdb.AddHook(redisHook{})
|
rdb.AddHook(redisHook{})
|
||||||
|
|
||||||
@ -63,15 +72,19 @@ func ExamplePipeline_instrumentation() {
|
|||||||
pipe.Ping(ctx)
|
pipe.Ping(ctx)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
// Output: pipeline starting processing: [ping: ping: ]
|
// Output:
|
||||||
|
// pipeline starting processing: [[ping] [ping]]
|
||||||
// dialing tcp :6379
|
// dialing tcp :6379
|
||||||
// finished dialing tcp :6379
|
// finished dialing tcp :6379
|
||||||
// pipeline finished processing: [ping: PONG ping: PONG]
|
// starting processing: <[hello 3]>
|
||||||
|
// finished processing: <[hello 3]>
|
||||||
|
// pipeline finished processing: [[ping] [ping]]
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleClient_Watch_instrumentation() {
|
func ExampleClient_Watch_instrumentation() {
|
||||||
rdb := redis.NewClient(&redis.Options{
|
rdb := redis.NewClient(&redis.Options{
|
||||||
Addr: ":6379",
|
Addr: ":6379",
|
||||||
|
DisableIdentity: true,
|
||||||
})
|
})
|
||||||
rdb.AddHook(redisHook{})
|
rdb.AddHook(redisHook{})
|
||||||
|
|
||||||
@ -81,14 +94,16 @@ func ExampleClient_Watch_instrumentation() {
|
|||||||
return nil
|
return nil
|
||||||
}, "foo")
|
}, "foo")
|
||||||
// Output:
|
// Output:
|
||||||
// starting processing: <watch foo: >
|
// starting processing: <[watch foo]>
|
||||||
// dialing tcp :6379
|
// dialing tcp :6379
|
||||||
// finished dialing tcp :6379
|
// finished dialing tcp :6379
|
||||||
// finished processing: <watch foo: OK>
|
// starting processing: <[hello 3]>
|
||||||
// starting processing: <ping: >
|
// finished processing: <[hello 3]>
|
||||||
// finished processing: <ping: PONG>
|
// finished processing: <[watch foo]>
|
||||||
// starting processing: <ping: >
|
// starting processing: <[ping]>
|
||||||
// finished processing: <ping: PONG>
|
// finished processing: <[ping]>
|
||||||
// starting processing: <unwatch: >
|
// starting processing: <[ping]>
|
||||||
// finished processing: <unwatch: OK>
|
// finished processing: <[ping]>
|
||||||
|
// starting processing: <[unwatch]>
|
||||||
|
// finished processing: <[unwatch]>
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,8 @@ type Conn struct {
|
|||||||
Inited bool
|
Inited bool
|
||||||
pooled bool
|
pooled bool
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
|
|
||||||
|
onClose func() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(netConn net.Conn) *Conn {
|
func NewConn(netConn net.Conn) *Conn {
|
||||||
@ -46,6 +48,10 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
|
|||||||
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cn *Conn) SetOnClose(fn func() error) {
|
||||||
|
cn.onClose = fn
|
||||||
|
}
|
||||||
|
|
||||||
func (cn *Conn) SetNetConn(netConn net.Conn) {
|
func (cn *Conn) SetNetConn(netConn net.Conn) {
|
||||||
cn.netConn = netConn
|
cn.netConn = netConn
|
||||||
cn.rd.Reset(netConn)
|
cn.rd.Reset(netConn)
|
||||||
@ -95,6 +101,10 @@ func (cn *Conn) WithWriter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cn *Conn) Close() error {
|
func (cn *Conn) Close() error {
|
||||||
|
if cn.onClose != nil {
|
||||||
|
// ignore error
|
||||||
|
_ = cn.onClose()
|
||||||
|
}
|
||||||
return cn.netConn.Close()
|
return cn.netConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -212,10 +212,10 @@ func TestRingShardsCleanup(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewClient: func(opt *Options) *Client {
|
NewClient: func(opt *Options) *Client {
|
||||||
c := NewClient(opt)
|
c := NewClient(opt)
|
||||||
c.baseClient.onClose = func() error {
|
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
|
||||||
closeCounter.increment(opt.Addr)
|
closeCounter.increment(opt.Addr)
|
||||||
return nil
|
return nil
|
||||||
}
|
})
|
||||||
return c
|
return c
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@ -261,10 +261,10 @@ func TestRingShardsCleanup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
createCounter.increment(opt.Addr)
|
createCounter.increment(opt.Addr)
|
||||||
c := NewClient(opt)
|
c := NewClient(opt)
|
||||||
c.baseClient.onClose = func() error {
|
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
|
||||||
closeCounter.increment(opt.Addr)
|
closeCounter.increment(opt.Addr)
|
||||||
return nil
|
return nil
|
||||||
}
|
})
|
||||||
return c
|
return c
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
129
options.go
129
options.go
@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9/auth"
|
||||||
"github.com/redis/go-redis/v9/internal/pool"
|
"github.com/redis/go-redis/v9/internal/pool"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,10 +30,13 @@ type Limiter interface {
|
|||||||
|
|
||||||
// Options keeps the settings to set up redis connection.
|
// Options keeps the settings to set up redis connection.
|
||||||
type Options struct {
|
type Options struct {
|
||||||
// The network type, either tcp or unix.
|
|
||||||
// Default is tcp.
|
// Network type, either tcp or unix.
|
||||||
|
//
|
||||||
|
// default: is tcp.
|
||||||
Network string
|
Network string
|
||||||
// host:port address.
|
|
||||||
|
// Addr is the address formated as host:port
|
||||||
Addr string
|
Addr string
|
||||||
|
|
||||||
// ClientName will execute the `CLIENT SETNAME ClientName` command for each conn.
|
// ClientName will execute the `CLIENT SETNAME ClientName` command for each conn.
|
||||||
@ -46,17 +50,21 @@ type Options struct {
|
|||||||
OnConnect func(ctx context.Context, cn *Conn) error
|
OnConnect func(ctx context.Context, cn *Conn) error
|
||||||
|
|
||||||
// Protocol 2 or 3. Use the version to negotiate RESP version with redis-server.
|
// Protocol 2 or 3. Use the version to negotiate RESP version with redis-server.
|
||||||
// Default is 3.
|
//
|
||||||
|
// default: 3.
|
||||||
Protocol int
|
Protocol int
|
||||||
// Use the specified Username to authenticate the current connection
|
|
||||||
|
// Username is used to authenticate the current connection
|
||||||
// with one of the connections defined in the ACL list when connecting
|
// with one of the connections defined in the ACL list when connecting
|
||||||
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
|
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
|
||||||
Username string
|
Username string
|
||||||
// Optional password. Must match the password specified in the
|
|
||||||
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
|
// Password is an optional password. Must match the password specified in the
|
||||||
|
// `requirepass` server configuration option (if connecting to a Redis 5.0 instance, or lower),
|
||||||
// or the User Password when connecting to a Redis 6.0 instance, or greater,
|
// or the User Password when connecting to a Redis 6.0 instance, or greater,
|
||||||
// that is using the Redis ACL system.
|
// that is using the Redis ACL system.
|
||||||
Password string
|
Password string
|
||||||
|
|
||||||
// CredentialsProvider allows the username and password to be updated
|
// CredentialsProvider allows the username and password to be updated
|
||||||
// before reconnecting. It should return the current username and password.
|
// before reconnecting. It should return the current username and password.
|
||||||
CredentialsProvider func() (username string, password string)
|
CredentialsProvider func() (username string, password string)
|
||||||
@ -67,85 +75,126 @@ type Options struct {
|
|||||||
// There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider.
|
// There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider.
|
||||||
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
|
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
|
||||||
|
|
||||||
// Database to be selected after connecting to the server.
|
// StreamingCredentialsProvider is used to retrieve the credentials
|
||||||
|
// for the connection from an external source. Those credentials may change
|
||||||
|
// during the connection lifetime. This is useful for managed identity
|
||||||
|
// scenarios where the credentials are retrieved from an external source.
|
||||||
|
//
|
||||||
|
// Currently, this is a placeholder for the future implementation.
|
||||||
|
StreamingCredentialsProvider auth.StreamingCredentialsProvider
|
||||||
|
|
||||||
|
// DB is the database to be selected after connecting to the server.
|
||||||
DB int
|
DB int
|
||||||
|
|
||||||
// Maximum number of retries before giving up.
|
// MaxRetries is the maximum number of retries before giving up.
|
||||||
// Default is 3 retries; -1 (not 0) disables retries.
|
// -1 (not 0) disables retries.
|
||||||
|
//
|
||||||
|
// default: 3 retries
|
||||||
MaxRetries int
|
MaxRetries int
|
||||||
// Minimum backoff between each retry.
|
|
||||||
// Default is 8 milliseconds; -1 disables backoff.
|
// MinRetryBackoff is the minimum backoff between each retry.
|
||||||
|
// -1 disables backoff.
|
||||||
|
//
|
||||||
|
// default: 8 milliseconds
|
||||||
MinRetryBackoff time.Duration
|
MinRetryBackoff time.Duration
|
||||||
// Maximum backoff between each retry.
|
|
||||||
// Default is 512 milliseconds; -1 disables backoff.
|
// MaxRetryBackoff is the maximum backoff between each retry.
|
||||||
|
// -1 disables backoff.
|
||||||
|
// default: 512 milliseconds;
|
||||||
MaxRetryBackoff time.Duration
|
MaxRetryBackoff time.Duration
|
||||||
|
|
||||||
// Dial timeout for establishing new connections.
|
// DialTimeout for establishing new connections.
|
||||||
// Default is 5 seconds.
|
//
|
||||||
|
// default: 5 seconds
|
||||||
DialTimeout time.Duration
|
DialTimeout time.Duration
|
||||||
// Timeout for socket reads. If reached, commands will fail
|
|
||||||
|
// ReadTimeout for socket reads. If reached, commands will fail
|
||||||
// with a timeout instead of blocking. Supported values:
|
// with a timeout instead of blocking. Supported values:
|
||||||
// - `0` - default timeout (3 seconds).
|
//
|
||||||
// - `-1` - no timeout (block indefinitely).
|
// - `-1` - no timeout (block indefinitely).
|
||||||
// - `-2` - disables SetReadDeadline calls completely.
|
// - `-2` - disables SetReadDeadline calls completely.
|
||||||
|
//
|
||||||
|
// default: 3 seconds
|
||||||
ReadTimeout time.Duration
|
ReadTimeout time.Duration
|
||||||
// Timeout for socket writes. If reached, commands will fail
|
|
||||||
|
// WriteTimeout for socket writes. If reached, commands will fail
|
||||||
// with a timeout instead of blocking. Supported values:
|
// with a timeout instead of blocking. Supported values:
|
||||||
// - `0` - default timeout (3 seconds).
|
//
|
||||||
// - `-1` - no timeout (block indefinitely).
|
// - `-1` - no timeout (block indefinitely).
|
||||||
// - `-2` - disables SetWriteDeadline calls completely.
|
// - `-2` - disables SetWriteDeadline calls completely.
|
||||||
|
//
|
||||||
|
// default: 3 seconds
|
||||||
WriteTimeout time.Duration
|
WriteTimeout time.Duration
|
||||||
|
|
||||||
// ContextTimeoutEnabled controls whether the client respects context timeouts and deadlines.
|
// ContextTimeoutEnabled controls whether the client respects context timeouts and deadlines.
|
||||||
// See https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts
|
// See https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts
|
||||||
ContextTimeoutEnabled bool
|
ContextTimeoutEnabled bool
|
||||||
|
|
||||||
// Type of connection pool.
|
// PoolFIFO type of connection pool.
|
||||||
// true for FIFO pool, false for LIFO pool.
|
//
|
||||||
|
// - true for FIFO pool
|
||||||
|
// - false for LIFO pool.
|
||||||
|
//
|
||||||
// Note that FIFO has slightly higher overhead compared to LIFO,
|
// Note that FIFO has slightly higher overhead compared to LIFO,
|
||||||
// but it helps closing idle connections faster reducing the pool size.
|
// but it helps closing idle connections faster reducing the pool size.
|
||||||
PoolFIFO bool
|
PoolFIFO bool
|
||||||
// Base number of socket connections.
|
|
||||||
|
// PoolSize is the base number of socket connections.
|
||||||
// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS.
|
// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS.
|
||||||
// If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize,
|
// If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize,
|
||||||
// you can limit it through MaxActiveConns
|
// you can limit it through MaxActiveConns
|
||||||
|
//
|
||||||
|
// default: 10 * runtime.GOMAXPROCS(0)
|
||||||
PoolSize int
|
PoolSize int
|
||||||
// Amount of time client waits for connection if all connections
|
|
||||||
|
// PoolTimeout is the amount of time client waits for connection if all connections
|
||||||
// are busy before returning an error.
|
// are busy before returning an error.
|
||||||
// Default is ReadTimeout + 1 second.
|
//
|
||||||
|
// default: ReadTimeout + 1 second
|
||||||
PoolTimeout time.Duration
|
PoolTimeout time.Duration
|
||||||
// Minimum number of idle connections which is useful when establishing
|
|
||||||
// new connection is slow.
|
// MinIdleConns is the minimum number of idle connections which is useful when establishing
|
||||||
// Default is 0. the idle connections are not closed by default.
|
// new connection is slow. The idle connections are not closed by default.
|
||||||
|
//
|
||||||
|
// default: 0
|
||||||
MinIdleConns int
|
MinIdleConns int
|
||||||
// Maximum number of idle connections.
|
|
||||||
// Default is 0. the idle connections are not closed by default.
|
// MaxIdleConns is the maximum number of idle connections.
|
||||||
|
// The idle connections are not closed by default.
|
||||||
|
//
|
||||||
|
// default: 0
|
||||||
MaxIdleConns int
|
MaxIdleConns int
|
||||||
// Maximum number of connections allocated by the pool at a given time.
|
|
||||||
|
// MaxActiveConns is the maximum number of connections allocated by the pool at a given time.
|
||||||
// When zero, there is no limit on the number of connections in the pool.
|
// When zero, there is no limit on the number of connections in the pool.
|
||||||
|
// If the pool is full, the next call to Get() will block until a connection is released.
|
||||||
MaxActiveConns int
|
MaxActiveConns int
|
||||||
|
|
||||||
// ConnMaxIdleTime is the maximum amount of time a connection may be idle.
|
// ConnMaxIdleTime is the maximum amount of time a connection may be idle.
|
||||||
// Should be less than server's timeout.
|
// Should be less than server's timeout.
|
||||||
//
|
//
|
||||||
// Expired connections may be closed lazily before reuse.
|
// Expired connections may be closed lazily before reuse.
|
||||||
// If d <= 0, connections are not closed due to a connection's idle time.
|
// If d <= 0, connections are not closed due to a connection's idle time.
|
||||||
|
// -1 disables idle timeout check.
|
||||||
//
|
//
|
||||||
// Default is 30 minutes. -1 disables idle timeout check.
|
// default: 30 minutes
|
||||||
ConnMaxIdleTime time.Duration
|
ConnMaxIdleTime time.Duration
|
||||||
|
|
||||||
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
|
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
|
||||||
//
|
//
|
||||||
// Expired connections may be closed lazily before reuse.
|
// Expired connections may be closed lazily before reuse.
|
||||||
// If <= 0, connections are not closed due to a connection's age.
|
// If <= 0, connections are not closed due to a connection's age.
|
||||||
//
|
//
|
||||||
// Default is to not close idle connections.
|
// default: 0
|
||||||
ConnMaxLifetime time.Duration
|
ConnMaxLifetime time.Duration
|
||||||
|
|
||||||
// TLS Config to use. When set, TLS will be negotiated.
|
// TLSConfig to use. When set, TLS will be negotiated.
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
|
|
||||||
// Limiter interface used to implement circuit breaker or rate limiter.
|
// Limiter interface used to implement circuit breaker or rate limiter.
|
||||||
Limiter Limiter
|
Limiter Limiter
|
||||||
|
|
||||||
// Enables read only queries on slave/follower nodes.
|
// readOnly enables read only queries on slave/follower nodes.
|
||||||
readOnly bool
|
readOnly bool
|
||||||
|
|
||||||
// DisableIndentity - Disable set-lib on connect.
|
// DisableIndentity - Disable set-lib on connect.
|
||||||
@ -161,9 +210,11 @@ type Options struct {
|
|||||||
DisableIdentity bool
|
DisableIdentity bool
|
||||||
|
|
||||||
// Add suffix to client name. Default is empty.
|
// Add suffix to client name. Default is empty.
|
||||||
|
// IdentitySuffix - add suffix to client name.
|
||||||
IdentitySuffix string
|
IdentitySuffix string
|
||||||
|
|
||||||
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
|
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
|
||||||
|
// When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult
|
||||||
UnstableResp3 bool
|
UnstableResp3 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9/auth"
|
||||||
"github.com/redis/go-redis/v9/internal"
|
"github.com/redis/go-redis/v9/internal"
|
||||||
"github.com/redis/go-redis/v9/internal/hashtag"
|
"github.com/redis/go-redis/v9/internal/hashtag"
|
||||||
"github.com/redis/go-redis/v9/internal/pool"
|
"github.com/redis/go-redis/v9/internal/pool"
|
||||||
@ -66,11 +67,12 @@ type ClusterOptions struct {
|
|||||||
|
|
||||||
OnConnect func(ctx context.Context, cn *Conn) error
|
OnConnect func(ctx context.Context, cn *Conn) error
|
||||||
|
|
||||||
Protocol int
|
Protocol int
|
||||||
Username string
|
Username string
|
||||||
Password string
|
Password string
|
||||||
CredentialsProvider func() (username string, password string)
|
CredentialsProvider func() (username string, password string)
|
||||||
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
|
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
|
||||||
|
StreamingCredentialsProvider auth.StreamingCredentialsProvider
|
||||||
|
|
||||||
MaxRetries int
|
MaxRetries int
|
||||||
MinRetryBackoff time.Duration
|
MinRetryBackoff time.Duration
|
||||||
@ -292,11 +294,12 @@ func (opt *ClusterOptions) clientOptions() *Options {
|
|||||||
Dialer: opt.Dialer,
|
Dialer: opt.Dialer,
|
||||||
OnConnect: opt.OnConnect,
|
OnConnect: opt.OnConnect,
|
||||||
|
|
||||||
Protocol: opt.Protocol,
|
Protocol: opt.Protocol,
|
||||||
Username: opt.Username,
|
Username: opt.Username,
|
||||||
Password: opt.Password,
|
Password: opt.Password,
|
||||||
CredentialsProvider: opt.CredentialsProvider,
|
CredentialsProvider: opt.CredentialsProvider,
|
||||||
CredentialsProviderContext: opt.CredentialsProviderContext,
|
CredentialsProviderContext: opt.CredentialsProviderContext,
|
||||||
|
StreamingCredentialsProvider: opt.StreamingCredentialsProvider,
|
||||||
|
|
||||||
MaxRetries: opt.MaxRetries,
|
MaxRetries: opt.MaxRetries,
|
||||||
MinRetryBackoff: opt.MinRetryBackoff,
|
MinRetryBackoff: opt.MinRetryBackoff,
|
||||||
|
@ -89,6 +89,9 @@ func (s *clusterScenario) newClusterClient(
|
|||||||
func (s *clusterScenario) Close() error {
|
func (s *clusterScenario) Close() error {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
for _, master := range s.masters() {
|
for _, master := range s.masters() {
|
||||||
|
if master == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
err := master.FlushAll(ctx).Err()
|
err := master.FlushAll(ctx).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -298,7 +298,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("should CFCount", Label("cuckoo", "cfcount"), func() {
|
It("should CFCount", Label("cuckoo", "cfcount"), func() {
|
||||||
err := client.CFAdd(ctx, "testcf1", "item1").Err()
|
client.CFAdd(ctx, "testcf1", "item1")
|
||||||
cnt, err := client.CFCount(ctx, "testcf1", "item1").Result()
|
cnt, err := client.CFCount(ctx, "testcf1", "item1").Result()
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(cnt).To(BeEquivalentTo(int64(1)))
|
Expect(cnt).To(BeEquivalentTo(int64(1)))
|
||||||
@ -394,7 +394,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
|
|||||||
NoCreate: true,
|
NoCreate: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
_, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
|
|
||||||
args = &redis.CFInsertOptions{
|
args = &redis.CFInsertOptions{
|
||||||
@ -402,7 +402,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
|
|||||||
NoCreate: false,
|
NoCreate: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err = client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(len(result)).To(BeEquivalentTo(3))
|
Expect(len(result)).To(BeEquivalentTo(3))
|
||||||
})
|
})
|
||||||
|
152
redis.go
152
redis.go
@ -9,6 +9,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9/auth"
|
||||||
"github.com/redis/go-redis/v9/internal"
|
"github.com/redis/go-redis/v9/internal"
|
||||||
"github.com/redis/go-redis/v9/internal/hscan"
|
"github.com/redis/go-redis/v9/internal/hscan"
|
||||||
"github.com/redis/go-redis/v9/internal/pool"
|
"github.com/redis/go-redis/v9/internal/pool"
|
||||||
@ -203,6 +204,7 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
|
|||||||
type baseClient struct {
|
type baseClient struct {
|
||||||
opt *Options
|
opt *Options
|
||||||
connPool pool.Pooler
|
connPool pool.Pooler
|
||||||
|
hooksMixin
|
||||||
|
|
||||||
onClose func() error // hook called when client is closed
|
onClose func() error // hook called when client is closed
|
||||||
}
|
}
|
||||||
@ -282,30 +284,107 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
|||||||
return cn, nil
|
return cn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
|
||||||
|
return auth.NewReAuthCredentialsListener(
|
||||||
|
c.reAuthConnection(poolCn),
|
||||||
|
c.onAuthenticationErr(poolCn),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
|
||||||
|
return func(credentials auth.Credentials) error {
|
||||||
|
var err error
|
||||||
|
username, password := credentials.BasicAuth()
|
||||||
|
ctx := context.Background()
|
||||||
|
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
|
||||||
|
// hooksMixin are intentionally empty here
|
||||||
|
cn := newConn(c.opt, connPool, nil)
|
||||||
|
|
||||||
|
if username != "" {
|
||||||
|
err = cn.AuthACL(ctx, username, password).Err()
|
||||||
|
} else {
|
||||||
|
err = cn.Auth(ctx, password).Err()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
|
||||||
|
return func(err error) {
|
||||||
|
if err != nil {
|
||||||
|
if isBadConn(err, false, c.opt.Addr) {
|
||||||
|
// Close the connection to force a reconnection.
|
||||||
|
err := c.connPool.CloseConn(poolCn)
|
||||||
|
if err != nil {
|
||||||
|
internal.Logger.Printf(context.Background(), "redis: failed to close connection: %v", err)
|
||||||
|
// try to close the network connection directly
|
||||||
|
// so that no resource is leaked
|
||||||
|
err := poolCn.Close()
|
||||||
|
if err != nil {
|
||||||
|
internal.Logger.Printf(context.Background(), "redis: failed to close network connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
|
||||||
|
onClose := c.onClose
|
||||||
|
return func() error {
|
||||||
|
var firstErr error
|
||||||
|
err := newOnClose()
|
||||||
|
// Even if we have an error we would like to execute the onClose hook
|
||||||
|
// if it exists. We will return the first error that occurred.
|
||||||
|
// This is to keep error handling consistent with the rest of the code.
|
||||||
|
if err != nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
if onClose != nil {
|
||||||
|
err = onClose()
|
||||||
|
if err != nil && firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||||
if cn.Inited {
|
if cn.Inited {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
cn.Inited = true
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
username, password := c.opt.Username, c.opt.Password
|
cn.Inited = true
|
||||||
if c.opt.CredentialsProviderContext != nil {
|
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||||
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil {
|
conn := newConn(c.opt, connPool, &c.hooksMixin)
|
||||||
return err
|
|
||||||
|
username, password := "", ""
|
||||||
|
if c.opt.StreamingCredentialsProvider != nil {
|
||||||
|
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||||
|
Subscribe(c.newReAuthCredentialsListener(cn))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
||||||
|
}
|
||||||
|
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
|
||||||
|
cn.SetOnClose(unsubscribeFromCredentialsProvider)
|
||||||
|
username, password = credentials.BasicAuth()
|
||||||
|
} else if c.opt.CredentialsProviderContext != nil {
|
||||||
|
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get credentials from context provider: %w", err)
|
||||||
}
|
}
|
||||||
} else if c.opt.CredentialsProvider != nil {
|
} else if c.opt.CredentialsProvider != nil {
|
||||||
username, password = c.opt.CredentialsProvider()
|
username, password = c.opt.CredentialsProvider()
|
||||||
|
} else if c.opt.Username != "" || c.opt.Password != "" {
|
||||||
|
username, password = c.opt.Username, c.opt.Password
|
||||||
}
|
}
|
||||||
|
|
||||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
|
||||||
conn := newConn(c.opt, connPool)
|
|
||||||
|
|
||||||
var auth bool
|
|
||||||
// for redis-server versions that do not support the HELLO command,
|
// for redis-server versions that do not support the HELLO command,
|
||||||
// RESP2 will continue to be used.
|
// RESP2 will continue to be used.
|
||||||
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
||||||
auth = true
|
// Authentication successful with HELLO command
|
||||||
} else if !isRedisError(err) {
|
} else if !isRedisError(err) {
|
||||||
// When the server responds with the RESP protocol and the result is not a normal
|
// When the server responds with the RESP protocol and the result is not a normal
|
||||||
// execution result of the HELLO command, we consider it to be an indication that
|
// execution result of the HELLO command, we consider it to be an indication that
|
||||||
@ -315,17 +394,19 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
// with different error string results for unsupported commands, making it
|
// with different error string results for unsupported commands, making it
|
||||||
// difficult to rely on error strings to determine all results.
|
// difficult to rely on error strings to determine all results.
|
||||||
return err
|
return err
|
||||||
|
} else if password != "" {
|
||||||
|
// Try legacy AUTH command if HELLO failed
|
||||||
|
if username != "" {
|
||||||
|
err = conn.AuthACL(ctx, username, password).Err()
|
||||||
|
} else {
|
||||||
|
err = conn.Auth(ctx, password).Err()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to authenticate: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
||||||
if !auth && password != "" {
|
|
||||||
if username != "" {
|
|
||||||
pipe.AuthACL(ctx, username, password)
|
|
||||||
} else {
|
|
||||||
pipe.Auth(ctx, password)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.opt.DB > 0 {
|
if c.opt.DB > 0 {
|
||||||
pipe.Select(ctx, c.opt.DB)
|
pipe.Select(ctx, c.opt.DB)
|
||||||
}
|
}
|
||||||
@ -341,7 +422,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to initialize connection options: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
|
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
|
||||||
@ -363,6 +444,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
if c.opt.OnConnect != nil {
|
if c.opt.OnConnect != nil {
|
||||||
return c.opt.OnConnect(ctx, conn)
|
return c.opt.OnConnect(ctx, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -481,6 +563,16 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
|
|||||||
return c.opt.ReadTimeout
|
return c.opt.ReadTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// context returns the context for the current connection.
|
||||||
|
// If the context timeout is enabled, it returns the original context.
|
||||||
|
// Otherwise, it returns a new background context.
|
||||||
|
func (c *baseClient) context(ctx context.Context) context.Context {
|
||||||
|
if c.opt.ContextTimeoutEnabled {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
// Close closes the client, releasing any open resources.
|
// Close closes the client, releasing any open resources.
|
||||||
//
|
//
|
||||||
// It is rare to Close a Client, as the Client is meant to be
|
// It is rare to Close a Client, as the Client is meant to be
|
||||||
@ -633,13 +725,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseClient) context(ctx context.Context) context.Context {
|
|
||||||
if c.opt.ContextTimeoutEnabled {
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
return context.Background()
|
|
||||||
}
|
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
// Client is a Redis client representing a pool of zero or more underlying connections.
|
// Client is a Redis client representing a pool of zero or more underlying connections.
|
||||||
@ -650,7 +735,6 @@ func (c *baseClient) context(ctx context.Context) context.Context {
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
*baseClient
|
*baseClient
|
||||||
cmdable
|
cmdable
|
||||||
hooksMixin
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient returns a client to the Redis Server specified by Options.
|
// NewClient returns a client to the Redis Server specified by Options.
|
||||||
@ -689,7 +773,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Conn() *Conn {
|
func (c *Client) Conn() *Conn {
|
||||||
return newConn(c.opt, pool.NewStickyConnPool(c.connPool))
|
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do create a Cmd from the args and processes the cmd.
|
// Do create a Cmd from the args and processes the cmd.
|
||||||
@ -822,10 +906,12 @@ type Conn struct {
|
|||||||
baseClient
|
baseClient
|
||||||
cmdable
|
cmdable
|
||||||
statefulCmdable
|
statefulCmdable
|
||||||
hooksMixin
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(opt *Options, connPool pool.Pooler) *Conn {
|
// newConn is a helper func to create a new Conn instance.
|
||||||
|
// the Conn instance is not thread-safe and should not be shared between goroutines.
|
||||||
|
// the parentHooks will be cloned, no need to clone before passing it.
|
||||||
|
func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn {
|
||||||
c := Conn{
|
c := Conn{
|
||||||
baseClient: baseClient{
|
baseClient: baseClient{
|
||||||
opt: opt,
|
opt: opt,
|
||||||
@ -833,6 +919,10 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if parentHooks != nil {
|
||||||
|
c.hooksMixin = parentHooks.clone()
|
||||||
|
}
|
||||||
|
|
||||||
c.cmdable = c.Process
|
c.cmdable = c.Process
|
||||||
c.statefulCmdable = c.Process
|
c.statefulCmdable = c.Process
|
||||||
c.initHooks(hooks{
|
c.initHooks(hooks{
|
||||||
|
169
redis_test.go
169
redis_test.go
@ -14,6 +14,7 @@ import (
|
|||||||
. "github.com/bsm/gomega"
|
. "github.com/bsm/gomega"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/redis/go-redis/v9/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
type redisHookError struct{}
|
type redisHookError struct{}
|
||||||
@ -727,6 +728,174 @@ var _ = Describe("Dialer connection timeouts", func() {
|
|||||||
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
|
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
var _ = Describe("Credentials Provider Priority", func() {
|
||||||
|
var client *redis.Client
|
||||||
|
var opt *redis.Options
|
||||||
|
var recorder *commandRecorder
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
recorder = newCommandRecorder(10)
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
if client != nil {
|
||||||
|
Expect(client.Close()).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use streaming provider when available", func() {
|
||||||
|
streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
|
||||||
|
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
|
||||||
|
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||||
|
|
||||||
|
opt = &redis.Options{
|
||||||
|
Username: "field_user",
|
||||||
|
Password: "field_pass",
|
||||||
|
CredentialsProvider: func() (string, string) {
|
||||||
|
username, password := providerCreds.BasicAuth()
|
||||||
|
return username, password
|
||||||
|
},
|
||||||
|
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
|
||||||
|
username, password := ctxCreds.BasicAuth()
|
||||||
|
return username, password, nil
|
||||||
|
},
|
||||||
|
StreamingCredentialsProvider: &mockStreamingProvider{
|
||||||
|
credentials: streamingCreds,
|
||||||
|
updates: make(chan auth.Credentials, 1),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use context provider when streaming provider is not available", func() {
|
||||||
|
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
|
||||||
|
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||||
|
|
||||||
|
opt = &redis.Options{
|
||||||
|
Username: "field_user",
|
||||||
|
Password: "field_pass",
|
||||||
|
CredentialsProvider: func() (string, string) {
|
||||||
|
username, password := providerCreds.BasicAuth()
|
||||||
|
return username, password
|
||||||
|
},
|
||||||
|
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
|
||||||
|
username, password := ctxCreds.BasicAuth()
|
||||||
|
return username, password, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use regular provider when streaming and context providers are not available", func() {
|
||||||
|
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||||
|
|
||||||
|
opt = &redis.Options{
|
||||||
|
Username: "field_user",
|
||||||
|
Password: "field_pass",
|
||||||
|
CredentialsProvider: func() (string, string) {
|
||||||
|
username, password := providerCreds.BasicAuth()
|
||||||
|
return username, password
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use username/password fields when no providers are set", func() {
|
||||||
|
opt = &redis.Options{
|
||||||
|
Username: "field_user",
|
||||||
|
Password: "field_pass",
|
||||||
|
}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use empty credentials when nothing is set", func() {
|
||||||
|
opt = &redis.Options{}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// no pass, ok
|
||||||
|
Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH")).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should handle credential updates from streaming provider", func() {
|
||||||
|
initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
|
||||||
|
updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
|
||||||
|
updatesChan := make(chan auth.Credentials, 1)
|
||||||
|
|
||||||
|
opt = &redis.Options{
|
||||||
|
StreamingCredentialsProvider: &mockStreamingProvider{
|
||||||
|
credentials: initialCreds,
|
||||||
|
updates: updatesChan,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client = redis.NewClient(opt)
|
||||||
|
client.AddHook(recorder.Hook())
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
|
||||||
|
|
||||||
|
// Update credentials
|
||||||
|
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
|
||||||
|
// wrongpass
|
||||||
|
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||||
|
Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
|
||||||
|
close(updatesChan)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
type mockStreamingProvider struct {
|
||||||
|
credentials auth.Credentials
|
||||||
|
err error
|
||||||
|
updates chan auth.Credentials
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, nil, m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start goroutine to handle updates
|
||||||
|
go func() {
|
||||||
|
for creds := range m.updates {
|
||||||
|
m.credentials = creds
|
||||||
|
listener.OnNext(creds)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return m.credentials, func() (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
// this is just a mock:
|
||||||
|
// allow multiple closes from multiple listeners
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
var _ = Describe("Client creation", func() {
|
var _ = Describe("Client creation", func() {
|
||||||
Context("simple client with nil options", func() {
|
Context("simple client with nil options", func() {
|
||||||
It("panics", func() {
|
It("panics", func() {
|
||||||
|
32
ring_test.go
32
ring_test.go
@ -357,13 +357,17 @@ var _ = Describe("Redis Ring", func() {
|
|||||||
ring.AddHook(&hook{
|
ring.AddHook(&hook{
|
||||||
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||||
Expect(cmds).To(HaveLen(1))
|
// skip the connection initialization
|
||||||
|
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Expect(len(cmds)).To(BeNumerically(">", 0))
|
||||||
Expect(cmds[0].String()).To(Equal("ping: "))
|
Expect(cmds[0].String()).To(Equal("ping: "))
|
||||||
stack = append(stack, "ring.BeforeProcessPipeline")
|
stack = append(stack, "ring.BeforeProcessPipeline")
|
||||||
|
|
||||||
err := hook(ctx, cmds)
|
err := hook(ctx, cmds)
|
||||||
|
|
||||||
Expect(cmds).To(HaveLen(1))
|
Expect(len(cmds)).To(BeNumerically(">", 0))
|
||||||
Expect(cmds[0].String()).To(Equal("ping: PONG"))
|
Expect(cmds[0].String()).To(Equal("ping: PONG"))
|
||||||
stack = append(stack, "ring.AfterProcessPipeline")
|
stack = append(stack, "ring.AfterProcessPipeline")
|
||||||
|
|
||||||
@ -376,13 +380,17 @@ var _ = Describe("Redis Ring", func() {
|
|||||||
shard.AddHook(&hook{
|
shard.AddHook(&hook{
|
||||||
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||||
Expect(cmds).To(HaveLen(1))
|
// skip the connection initialization
|
||||||
|
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Expect(len(cmds)).To(BeNumerically(">", 0))
|
||||||
Expect(cmds[0].String()).To(Equal("ping: "))
|
Expect(cmds[0].String()).To(Equal("ping: "))
|
||||||
stack = append(stack, "shard.BeforeProcessPipeline")
|
stack = append(stack, "shard.BeforeProcessPipeline")
|
||||||
|
|
||||||
err := hook(ctx, cmds)
|
err := hook(ctx, cmds)
|
||||||
|
|
||||||
Expect(cmds).To(HaveLen(1))
|
Expect(len(cmds)).To(BeNumerically(">", 0))
|
||||||
Expect(cmds[0].String()).To(Equal("ping: PONG"))
|
Expect(cmds[0].String()).To(Equal("ping: PONG"))
|
||||||
stack = append(stack, "shard.AfterProcessPipeline")
|
stack = append(stack, "shard.AfterProcessPipeline")
|
||||||
|
|
||||||
@ -416,14 +424,18 @@ var _ = Describe("Redis Ring", func() {
|
|||||||
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
// skip the connection initialization
|
||||||
|
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
Expect(cmds).To(HaveLen(3))
|
Expect(len(cmds)).To(BeNumerically(">=", 3))
|
||||||
Expect(cmds[1].String()).To(Equal("ping: "))
|
Expect(cmds[1].String()).To(Equal("ping: "))
|
||||||
stack = append(stack, "ring.BeforeProcessPipeline")
|
stack = append(stack, "ring.BeforeProcessPipeline")
|
||||||
|
|
||||||
err := hook(ctx, cmds)
|
err := hook(ctx, cmds)
|
||||||
|
|
||||||
Expect(cmds).To(HaveLen(3))
|
Expect(len(cmds)).To(BeNumerically(">=", 3))
|
||||||
Expect(cmds[1].String()).To(Equal("ping: PONG"))
|
Expect(cmds[1].String()).To(Equal("ping: PONG"))
|
||||||
stack = append(stack, "ring.AfterProcessPipeline")
|
stack = append(stack, "ring.AfterProcessPipeline")
|
||||||
|
|
||||||
@ -437,14 +449,18 @@ var _ = Describe("Redis Ring", func() {
|
|||||||
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
// skip the connection initialization
|
||||||
|
if cmds[0].Name() == "hello" || cmds[0].Name() == "client" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
Expect(cmds).To(HaveLen(3))
|
Expect(len(cmds)).To(BeNumerically(">=", 3))
|
||||||
Expect(cmds[1].String()).To(Equal("ping: "))
|
Expect(cmds[1].String()).To(Equal("ping: "))
|
||||||
stack = append(stack, "shard.BeforeProcessPipeline")
|
stack = append(stack, "shard.BeforeProcessPipeline")
|
||||||
|
|
||||||
err := hook(ctx, cmds)
|
err := hook(ctx, cmds)
|
||||||
|
|
||||||
Expect(cmds).To(HaveLen(3))
|
Expect(len(cmds)).To(BeNumerically(">=", 3))
|
||||||
Expect(cmds[1].String()).To(Equal("ping: PONG"))
|
Expect(cmds[1].String()).To(Equal("ping: PONG"))
|
||||||
stack = append(stack, "shard.AfterProcessPipeline")
|
stack = append(stack, "shard.AfterProcessPipeline")
|
||||||
|
|
||||||
|
@ -404,7 +404,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
|
|||||||
|
|
||||||
connPool = newConnPool(opt, rdb.dialHook)
|
connPool = newConnPool(opt, rdb.dialHook)
|
||||||
rdb.connPool = connPool
|
rdb.connPool = connPool
|
||||||
rdb.onClose = failover.Close
|
rdb.onClose = rdb.wrappedOnClose(failover.Close)
|
||||||
|
|
||||||
failover.mu.Lock()
|
failover.mu.Lock()
|
||||||
failover.onFailover = func(ctx context.Context, addr string) {
|
failover.onFailover = func(ctx context.Context, addr string) {
|
||||||
@ -455,7 +455,6 @@ func masterReplicaDialer(
|
|||||||
// SentinelClient is a client for a Redis Sentinel.
|
// SentinelClient is a client for a Redis Sentinel.
|
||||||
type SentinelClient struct {
|
type SentinelClient struct {
|
||||||
*baseClient
|
*baseClient
|
||||||
hooksMixin
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSentinelClient(opt *Options) *SentinelClient {
|
func NewSentinelClient(opt *Options) *SentinelClient {
|
||||||
|
7
tx.go
7
tx.go
@ -19,16 +19,15 @@ type Tx struct {
|
|||||||
baseClient
|
baseClient
|
||||||
cmdable
|
cmdable
|
||||||
statefulCmdable
|
statefulCmdable
|
||||||
hooksMixin
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) newTx() *Tx {
|
func (c *Client) newTx() *Tx {
|
||||||
tx := Tx{
|
tx := Tx{
|
||||||
baseClient: baseClient{
|
baseClient: baseClient{
|
||||||
opt: c.opt,
|
opt: c.opt,
|
||||||
connPool: pool.NewStickyConnPool(c.connPool),
|
connPool: pool.NewStickyConnPool(c.connPool),
|
||||||
|
hooksMixin: c.hooksMixin.clone(),
|
||||||
},
|
},
|
||||||
hooksMixin: c.hooksMixin.clone(),
|
|
||||||
}
|
}
|
||||||
tx.init()
|
tx.init()
|
||||||
return &tx
|
return &tx
|
||||||
|
Reference in New Issue
Block a user