go-redis/redis_test.go
Nedyalko Dyakov 042610b79d fix(conn): conn to have state machine (#3559)
* wip

* wip, used and unusable states

* polish state machine

* correct handling of OnPut

* better errors for tests, hook should work now

* fix linter

* improve reauth state management. fix tests

* Update internal/pool/conn.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update internal/pool/conn.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* better timeouts

* empty endpoint handoff case

* fix handoff state when queued for handoff

* try to detect the deadlock

* try to detect the deadlock x2

* delete should be called

* improve tests

* fix mark on uninitialized connection

* Update internal/pool/conn_state_test.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update internal/pool/conn_state_test.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update internal/pool/pool.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update internal/pool/conn_state.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update internal/pool/conn.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix error from copilot

* address copilot comment

* fix(pool): pool performance (#3565)

* perf(pool): replace hookManager RWMutex with atomic.Pointer and add predefined state slices

- Replace hookManager RWMutex with atomic.Pointer for lock-free reads in hot paths
- Add predefined state slices to avoid allocations (validFromInUse, validFromCreatedOrIdle, etc.)
- Add Clone() method to PoolHookManager for atomic updates
- Update AddPoolHook/RemovePoolHook to use copy-on-write pattern
- Update all hookManager access points to use atomic Load()

Performance improvements:
- Eliminates RWMutex contention in Get/Put/Remove hot paths
- Reduces allocations by reusing predefined state slices
- Lock-free reads allow better CPU cache utilization
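
A minimal Go sketch of the copy-on-write pattern described above (names and hook signatures are illustrative; the actual internal/pool types may differ):

    package pool

    import (
        "sync"
        "sync/atomic"
    )

    // PoolHook is a hypothetical hook interface used only for this sketch.
    type PoolHook interface {
        OnPut(connID uint64)
    }

    type PoolHookManager struct {
        hooks []PoolHook
    }

    // Clone returns a copy so the active manager is never mutated in place.
    func (m *PoolHookManager) Clone() *PoolHookManager {
        cp := &PoolHookManager{hooks: make([]PoolHook, len(m.hooks))}
        copy(cp.hooks, m.hooks)
        return cp
    }

    type ConnPool struct {
        hookManager atomic.Pointer[PoolHookManager] // lock-free reads in hot paths
        hookMu      sync.Mutex                      // serializes writers only
    }

    // AddPoolHook installs a hook using copy-on-write; readers never block.
    func (p *ConnPool) AddPoolHook(h PoolHook) {
        p.hookMu.Lock()
        defer p.hookMu.Unlock()
        next := &PoolHookManager{}
        if old := p.hookManager.Load(); old != nil {
            next = old.Clone()
        }
        next.hooks = append(next.hooks, h)
        p.hookManager.Store(next)
    }

    // onPut is the hot path: a single atomic Load, no mutex.
    func (p *ConnPool) onPut(connID uint64) {
        if m := p.hookManager.Load(); m != nil {
            for _, h := range m.hooks {
                h.OnPut(connID)
            }
        }
    }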

* perf(pool): eliminate mutex overhead in state machine hot path

The state machine was calling notifyWaiters() on EVERY Get/Put operation,
which acquired a mutex even when no waiters were present (the common case).

Fix: Use atomic waiterCount to check for waiters BEFORE acquiring mutex.
This eliminates mutex contention in the hot path (Get/Put operations).

Implementation:
- Added atomic.Int32 waiterCount field to ConnStateMachine
- Increment when adding waiter, decrement when removing
- Check waiterCount atomically before acquiring mutex in notifyWaiters()

Performance impact:
- Before: mutex lock/unlock on every Get/Put (even with no waiters)
- After: lock-free atomic check, only acquire mutex if waiters exist
- Expected improvement: ~30-50% for Get/Put operations
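
A sketch of the idea (hypothetical field and method names, not the exact implementation):

    package pool

    import (
        "sync"
        "sync/atomic"
    )

    type ConnStateMachine struct {
        state       atomic.Uint32
        waiterCount atomic.Int32 // goroutines currently blocked waiting for a transition
        mu          sync.Mutex   // protects waiters
        waiters     []chan struct{}
    }

    func (sm *ConnStateMachine) addWaiter() chan struct{} {
        ch := make(chan struct{})
        sm.mu.Lock()
        sm.waiters = append(sm.waiters, ch)
        sm.waiterCount.Add(1)
        sm.mu.Unlock()
        return ch
    }

    // notifyWaiters is called on every state transition. The atomic check
    // makes the common no-waiter case lock-free.
    func (sm *ConnStateMachine) notifyWaiters() {
        if sm.waiterCount.Load() == 0 {
            return // hot path: Get/Put with nobody waiting
        }
        sm.mu.Lock()
        for _, ch := range sm.waiters {
            close(ch)
        }
        sm.waiters = sm.waiters[:0]
        sm.waiterCount.Store(0)
        sm.mu.Unlock()
    }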

* perf(pool): use predefined state slices to eliminate allocations in hot path

The pool was creating new slice literals on EVERY Get/Put operation:
- popIdle(): []ConnState{StateCreated, StateIdle}
- putConn(): []ConnState{StateInUse}
- CompareAndSwapUsed(): []ConnState{StateIdle} and []ConnState{StateInUse}
- MarkUnusableForHandoff(): []ConnState{StateInUse, StateIdle, StateCreated}

These allocations were happening millions of times per second in the hot path.

Fix: Use predefined global slices defined in conn_state.go:
- validFromInUse
- validFromCreatedOrIdle
- validFromCreatedInUseOrIdle

Performance impact:
- Before: 4 slice allocations per Get/Put cycle
- After: 0 allocations (use predefined slices)
- Expected improvement: ~30-40% reduction in allocations and GC pressure
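
In code terms the change is just hoisting the slices to package level (illustrative, matching the names listed above):

    package pool

    type ConnState uint32

    const (
        StateCreated ConnState = iota
        StateIdle
        StateInUse
        StateUnusable
    )

    // Predefined transition sources, allocated once at package init.
    var (
        validFromInUse              = []ConnState{StateInUse}
        validFromCreatedOrIdle      = []ConnState{StateCreated, StateIdle}
        validFromCreatedInUseOrIdle = []ConnState{StateCreated, StateInUse, StateIdle}
    )

    // Hot-path callers pass the shared slice instead of a fresh literal, e.g.
    //   sm.TryTransition(validFromCreatedOrIdle, StateInUse) // popIdle
    //   sm.TryTransition(validFromInUse, StateIdle)          // putConn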

* perf(pool): optimize TryTransition to reduce atomic operations

Further optimize the hot path by:
1. Remove redundant GetState() call in the loop
2. Only check waiterCount after successful CAS (not before loop)
3. Inline the waiterCount check to avoid notifyWaiters() call overhead

This reduces atomic operations from 4-5 per Get/Put to 2-3:
- Before: GetState() + CAS + waiterCount.Load() + notifyWaiters mutex check
- After: CAS + waiterCount.Load() (only if CAS succeeds)

Performance impact:
- Eliminates 1-2 atomic operations per Get/Put
- Expected improvement: ~10-15% for Get/Put operations; see the sketch below
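
Roughly, building on the state machine sketch above (error value and exact signature are assumptions):

    import "errors"

    // errInvalidTransition is preallocated so the hot path never allocates on failure.
    var errInvalidTransition = errors.New("conn state: invalid transition")

    // TryTransition attempts from -> target using one CAS per candidate state.
    // Waiters are only checked after a successful CAS.
    func (sm *ConnStateMachine) TryTransition(validFrom []ConnState, target ConnState) (ConnState, error) {
        for _, from := range validFrom {
            if sm.state.CompareAndSwap(uint32(from), uint32(target)) {
                if sm.waiterCount.Load() > 0 {
                    sm.notifyWaiters()
                }
                return from, nil
            }
        }
        return ConnState(sm.state.Load()), errInvalidTransition
    }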

* perf(pool): add fast path for Get/Put to match master performance

Introduced TryTransitionFast() for the hot path (Get/Put operations):
- Single CAS operation (same as master's atomic bool)
- No waiter notification overhead
- No loop through valid states
- No error allocation

Hot path flow:
1. popIdle(): Try IDLE → IN_USE (fast), fall back to CREATED → IN_USE
2. putConn(): Try IN_USE → IDLE (fast)

This matches master's performance while preserving state machine for:
- Background operations (handoff/reauth use UNUSABLE state)
- State validation (TryTransition still available)
- Waiter notification (AwaitAndTransition for blocking)

Performance comparison per Get/Put cycle:
- Master: 2 atomic CAS operations
- State machine (before): 5 atomic operations (2.5x slower)
- State machine (after): 2 atomic CAS operations (same as master!)

Expected improvement: Restore to baseline ~11,373 ops/sec
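
With that, the fast path collapses to a single CAS (sketch, continuing the state machine above):

    // TryTransitionFast is the hot-path variant: one CAS, no waiter
    // notification, no loop through valid states, no error allocation.
    func (sm *ConnStateMachine) TryTransitionFast(from, to ConnState) bool {
        return sm.state.CompareAndSwap(uint32(from), uint32(to))
    }

    // Illustrative hot-path usage:
    //   popIdle: sm.TryTransitionFast(StateIdle, StateInUse) ||
    //            sm.TryTransitionFast(StateCreated, StateInUse)
    //   putConn: sm.TryTransitionFast(StateInUse, StateIdle)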

* combine cas

* fix linter

* try faster approach

* fast semaphore

* better inlining for hot path

* fix linter issues

* use new semaphore in auth as well

* linter should be happy now

* add comments

* Update internal/pool/conn_state.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* address comment

* slight reordering

* try to cache time for non-critical calculations
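
For example, a coarse cached clock along these lines (entirely hypothetical; the real helper and its refresh interval may differ) trades precision for avoiding a time.Now() call on non-critical paths:

    package pool

    import (
        "context"
        "sync/atomic"
        "time"
    )

    var cachedNowNanos atomic.Int64

    // startTimeCache refreshes the cached clock in the background.
    func startTimeCache(ctx context.Context, resolution time.Duration) {
        cachedNowNanos.Store(time.Now().UnixNano())
        go func() {
            t := time.NewTicker(resolution)
            defer t.Stop()
            for {
                select {
                case <-ctx.Done():
                    return
                case now := <-t.C:
                    cachedNowNanos.Store(now.UnixNano())
                }
            }
        }()
    }

    // nowCached is accurate only to the ticker resolution.
    func nowCached() time.Time { return time.Unix(0, cachedNowNanos.Load()) }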

* fix wrong benchmark

* add concurrent test

* fix benchmark report

* add additional expect to check output

* comment and variable rename

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* initConn sets IDLE state

- Handle unexpected conn state changes

* fix precision of time cache and usedAt

* allow e2e tests to run longer

* Fix broken initialization of idle connections

* optimize push notif

* 100ms -> 50ms

* use correct timer for last health check

* verify pass auth on conn creation

* fix assertion

* fix unsafe test

* fix benchmark test

* improve remove conn

* Redis Enterprise doesn't support requirepass

* wait more in e2e test

* flaky test

* add missed method in interface

* fix test assertions

* silence logs and faster hooks manager

* address linter comment

* fix flaky test

* use read instead of control

* use pool size for semaphore size

* CAS instead of reading the state

* preallocate errors and states

* preallocate state slices

* fix flaky test

* fix fast semaphore that could have been starved

* try to fix the semaphore

* should properly notify the waiters

- this way, a waiter that times out at the same moment a releaser is
releasing won't throw away the token: the releaser will fail to notify
that waiter and will pick another one instead.

This hybrid approach should be faster than channels and maintains FIFO ordering.

* waiter may double-release (if closed/times out)

* priority of operations

* use simple approach of fifo waiters

* use simple channel based semaphores
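
The final, simpler channel-based semaphore looks roughly like this (a sketch; the real type lives in internal/pool and its API may differ):

    package pool

    import "context"

    // semaphore is a counting semaphore backed by a buffered channel.
    type semaphore chan struct{}

    func newSemaphore(size int) semaphore { return make(semaphore, size) }

    // Acquire blocks until a slot is free or the context is done.
    func (s semaphore) Acquire(ctx context.Context) error {
        select {
        case s <- struct{}{}:
            return nil
        case <-ctx.Done():
            return ctx.Err()
        }
    }

    // Release frees a previously acquired slot.
    func (s semaphore) Release() { <-s }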

* address linter and tests

* remove unused benchs

* change log message

* address pr comments

* address pr comments

* fix data race

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-11 17:38:29 +02:00


package redis_test
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"sync"
"testing"
"time"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/auth"
)
type redisHookError struct{}
var _ redis.Hook = redisHookError{}
func (redisHookError) DialHook(hook redis.DialHook) redis.DialHook {
return hook
}
func (redisHookError) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
return errors.New("hook error")
}
}
func (redisHookError) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return hook
}
func TestHookError(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
})
rdb.AddHook(redisHookError{})
err := rdb.Ping(ctx).Err()
if err == nil {
t.Fatalf("got nil, expected an error")
}
wanted := "hook error"
if err.Error() != wanted {
t.Fatalf(`got %q, wanted %q`, err, wanted)
}
}
//------------------------------------------------------------------------------
var _ = Describe("Client", func() {
var client *redis.Client
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
client.Close()
})
It("should Stringer", func() {
Expect(client.String()).To(Equal(fmt.Sprintf("Redis<:%s db:0>", redisPort)))
})
It("supports context", func() {
ctx, cancel := context.WithCancel(ctx)
cancel()
err := client.Ping(ctx).Err()
Expect(err).To(MatchError("context canceled"))
})
It("supports WithTimeout", Label("NonRedisEnterprise"), func() {
err := client.ClientPause(ctx, time.Second).Err()
Expect(err).NotTo(HaveOccurred())
err = client.WithTimeout(10 * time.Millisecond).Ping(ctx).Err()
Expect(err).To(HaveOccurred())
err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
})
It("do", func() {
val, err := client.Do(ctx, "ping").Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG"))
})
It("should ping", func() {
val, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG"))
})
It("should return pool stats", func() {
Expect(client.PoolStats()).To(BeAssignableToTypeOf(&redis.PoolStats{}))
})
It("should support custom dialers", func() {
custom := redis.NewClient(&redis.Options{
Network: "tcp",
Addr: redisAddr,
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, addr)
},
})
val, err := custom.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG"))
Expect(custom.Close()).NotTo(HaveOccurred())
})
It("should close", func() {
Expect(client.Close()).NotTo(HaveOccurred())
err := client.Ping(ctx).Err()
Expect(err).To(MatchError("redis: client is closed"))
})
It("should close pubsub without closing the client", func() {
pubsub := client.Subscribe(ctx)
Expect(pubsub.Close()).NotTo(HaveOccurred())
_, err := pubsub.Receive(ctx)
Expect(err).To(MatchError("redis: client is closed"))
Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
})
It("should close Tx without closing the client", func() {
err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
return err
})
Expect(err).NotTo(HaveOccurred())
Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
})
It("should close pubsub when client is closed", func() {
pubsub := client.Subscribe(ctx)
Expect(client.Close()).NotTo(HaveOccurred())
_, err := pubsub.Receive(ctx)
Expect(err).To(MatchError("redis: client is closed"))
Expect(pubsub.Close()).NotTo(HaveOccurred())
})
It("should select DB", Label("NonRedisEnterprise"), func() {
db2 := redis.NewClient(&redis.Options{
Addr: redisAddr,
DB: 2,
})
Expect(db2.FlushDB(ctx).Err()).NotTo(HaveOccurred())
Expect(db2.Get(ctx, "db").Err()).To(Equal(redis.Nil))
Expect(db2.Set(ctx, "db", 2, 0).Err()).NotTo(HaveOccurred())
n, err := db2.Get(ctx, "db").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(2)))
Expect(client.Get(ctx, "db").Err()).To(Equal(redis.Nil))
Expect(db2.FlushDB(ctx).Err()).NotTo(HaveOccurred())
Expect(db2.Close()).NotTo(HaveOccurred())
})
It("should client setname", func() {
opt := redisOptions()
opt.ClientName = "hi"
db := redis.NewClient(opt)
defer func() {
Expect(db.Close()).NotTo(HaveOccurred())
}()
Expect(db.Ping(ctx).Err()).NotTo(HaveOccurred())
val, err := db.ClientList(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).Should(ContainSubstring("name=hi"))
})
It("should attempt to set client name in HELLO", func() {
opt := redisOptions()
opt.ClientName = "hi"
db := redis.NewClient(opt)
defer func() {
Expect(db.Close()).NotTo(HaveOccurred())
}()
// Client name should be already set on any successfully initialized connection
name, err := db.ClientGetName(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(name).Should(Equal("hi"))
// HELLO should be able to explicitly overwrite the client name
conn := db.Conn()
hello, err := conn.Hello(ctx, 3, "", "", "hi2").Result()
Expect(err).NotTo(HaveOccurred())
Expect(hello["proto"]).Should(Equal(int64(3)))
name, err = conn.ClientGetName(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(name).Should(Equal("hi2"))
err = conn.Close()
Expect(err).NotTo(HaveOccurred())
})
It("should client PROTO 2", func() {
opt := redisOptions()
opt.Protocol = 2
db := redis.NewClient(opt)
defer func() {
Expect(db.Close()).NotTo(HaveOccurred())
}()
val, err := db.Do(ctx, "HELLO").Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).Should(ContainElements("proto", int64(2)))
})
It("should client PROTO 3", func() {
opt := redisOptions()
db := redis.NewClient(opt)
defer func() {
Expect(db.Close()).NotTo(HaveOccurred())
}()
val, err := db.Do(ctx, "HELLO").Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).Should(HaveKeyWithValue("proto", int64(3)))
})
It("should initialize idle connections created by MinIdleConns", Label("NonRedisEnterprise"), func() {
opt := redisOptions()
passwrd := "asdf"
db0 := redis.NewClient(opt)
// set password
err := db0.Do(ctx, "CONFIG", "SET", "requirepass", passwrd).Err()
Expect(err).NotTo(HaveOccurred())
defer func() {
err = db0.Do(ctx, "CONFIG", "SET", "requirepass", "").Err()
Expect(err).NotTo(HaveOccurred())
Expect(db0.Close()).NotTo(HaveOccurred())
}()
opt.MinIdleConns = 5
opt.Password = passwrd
opt.DB = 1 // Set DB to require SELECT
db := redis.NewClient(opt)
defer func() {
Expect(db.Close()).NotTo(HaveOccurred())
}()
// Wait for minIdle connections to be created
time.Sleep(100 * time.Millisecond)
// Verify that idle connections were created
stats := db.PoolStats()
Expect(stats.IdleConns).To(BeNumerically(">=", 5))
// Now use these connections - they should be properly initialized
// If they're not initialized, we'll get NOAUTH or WRONGDB errors
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// Each goroutine performs multiple operations
for j := 0; j < 5; j++ {
key := fmt.Sprintf("test_key_%d_%d", id, j)
err := db.Set(ctx, key, "value", 0).Err()
Expect(err).NotTo(HaveOccurred())
val, err := db.Get(ctx, key).Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("value"))
err = db.Del(ctx, key).Err()
Expect(err).NotTo(HaveOccurred())
}
}(i)
}
wg.Wait()
// Verify no errors occurred
Expect(db.Ping(ctx).Err()).NotTo(HaveOccurred())
})
It("processes custom commands", func() {
cmd := redis.NewCmd(ctx, "PING")
_ = client.Process(ctx, cmd)
// Flush buffers.
Expect(client.Echo(ctx, "hello").Err()).NotTo(HaveOccurred())
Expect(cmd.Err()).NotTo(HaveOccurred())
Expect(cmd.Val()).To(Equal("PONG"))
})
It("should retry command on network error", func() {
Expect(client.Close()).NotTo(HaveOccurred())
client = redis.NewClient(&redis.Options{
Addr: redisAddr,
MaxRetries: 1,
})
// Put bad connection in the pool.
cn, err := client.Pool().Get(ctx)
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
client.Pool().Put(ctx, cn)
err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
})
It("should retry with backoff", func() {
clientNoRetry := redis.NewClient(&redis.Options{
Addr: ":1234",
MaxRetries: -1,
})
defer clientNoRetry.Close()
clientRetry := redis.NewClient(&redis.Options{
Addr: ":1234",
MaxRetries: 5,
MaxRetryBackoff: 128 * time.Millisecond,
})
defer clientRetry.Close()
startNoRetry := time.Now()
err := clientNoRetry.Ping(ctx).Err()
Expect(err).To(HaveOccurred())
elapseNoRetry := time.Since(startNoRetry)
startRetry := time.Now()
err = clientRetry.Ping(ctx).Err()
Expect(err).To(HaveOccurred())
elapseRetry := time.Since(startRetry)
Expect(elapseRetry).To(BeNumerically(">", elapseNoRetry, 10*time.Millisecond))
})
It("should update conn.UsedAt on read/write", func() {
cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred())
Expect(cn.UsedAt).NotTo(BeZero())
// set cn.SetUsedAt(time) or time.Sleep(>1*time.Second)
// simulate the last time Conn was used
// time.Sleep() is not the standard sleep time
// link: https://go-review.googlesource.com/c/go/+/232298
cn.SetUsedAt(time.Now().Add(-1 * time.Second))
createdAt := cn.UsedAt()
client.Pool().Put(ctx, cn)
Expect(cn.UsedAt().Equal(createdAt)).To(BeTrue())
err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
cn, err = client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt().UnixNano()).To(BeNumerically(">", createdAt.UnixNano()))
Expect(cn.UsedAt().After(createdAt)).To(BeTrue())
})
It("should process command with special chars", func() {
set := client.Set(ctx, "key", "hello1\r\nhello2\r\n", 0)
Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK"))
get := client.Get(ctx, "key")
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n"))
})
It("should handle big vals", func() {
bigVal := bytes.Repeat([]byte{'*'}, 2e6)
err := client.Set(ctx, "key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection.
Expect(client.Close()).NotTo(HaveOccurred())
client = redis.NewClient(redisOptions())
got, err := client.Get(ctx, "key").Bytes()
Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal))
})
It("should set and scan time", func() {
tm := time.Now()
err := client.Set(ctx, "now", tm, 0).Err()
Expect(err).NotTo(HaveOccurred())
var tm2 time.Time
err = client.Get(ctx, "now").Scan(&tm2)
Expect(err).NotTo(HaveOccurred())
Expect(tm2).To(BeTemporally("==", tm))
})
It("should set and scan durations", func() {
duration := 10 * time.Minute
err := client.Set(ctx, "duration", duration, 0).Err()
Expect(err).NotTo(HaveOccurred())
var duration2 time.Duration
err = client.Get(ctx, "duration").Scan(&duration2)
Expect(err).NotTo(HaveOccurred())
Expect(duration2).To(Equal(duration))
})
It("should Conn", func() {
err := client.Conn().Get(ctx, "this-key-does-not-exist").Err()
Expect(err).To(Equal(redis.Nil))
})
It("should set and scan net.IP", func() {
ip := net.ParseIP("192.168.1.1")
err := client.Set(ctx, "ip", ip, 0).Err()
Expect(err).NotTo(HaveOccurred())
var ip2 net.IP
err = client.Get(ctx, "ip").Scan(&ip2)
Expect(err).NotTo(HaveOccurred())
Expect(ip2).To(Equal(ip))
})
})
var _ = Describe("Client timeout", func() {
var opt *redis.Options
var client *redis.Client
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
testTimeout := func() {
It("SETINFO timeouts", func() {
conn := client.Conn()
err := conn.Ping(ctx).Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Ping timeouts", func() {
err := client.Ping(ctx).Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Pipeline timeouts", func() {
_, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Subscribe timeouts", func() {
if opt.WriteTimeout == 0 {
return
}
pubsub := client.Subscribe(ctx)
defer pubsub.Close()
err := pubsub.Subscribe(ctx, "_")
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx timeouts", func() {
err := client.Watch(ctx, func(tx *redis.Tx) error {
return tx.Ping(ctx).Err()
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx Pipeline timeouts", func() {
err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
return err
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
}
Context("read timeout", func() {
BeforeEach(func() {
opt = redisOptions()
opt.ReadTimeout = time.Nanosecond
opt.WriteTimeout = -1
client = redis.NewClient(opt)
})
testTimeout()
})
Context("write timeout", func() {
BeforeEach(func() {
opt = redisOptions()
opt.ReadTimeout = -1
opt.WriteTimeout = time.Nanosecond
client = redis.NewClient(opt)
})
testTimeout()
})
})
var _ = Describe("Client OnConnect", func() {
var client *redis.Client
BeforeEach(func() {
opt := redisOptions()
opt.DB = 0
opt.OnConnect = func(ctx context.Context, cn *redis.Conn) error {
return cn.ClientSetName(ctx, "on_connect").Err()
}
client = redis.NewClient(opt)
})
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
It("calls OnConnect", func() {
name, err := client.ClientGetName(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(name).To(Equal("on_connect"))
})
})
var _ = Describe("Client context cancellation", func() {
var opt *redis.Options
var client *redis.Client
BeforeEach(func() {
opt = redisOptions()
opt.ReadTimeout = -1
opt.WriteTimeout = -1
client = redis.NewClient(opt)
})
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
It("Blocking operation cancellation", func() {
ctx, cancel := context.WithCancel(ctx)
cancel()
err := client.BLPop(ctx, 1*time.Second, "test").Err()
Expect(err).To(HaveOccurred())
Expect(err).To(BeIdenticalTo(context.Canceled))
})
})
var _ = Describe("Conn", func() {
var client *redis.Client
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
err := client.Close()
Expect(err).NotTo(HaveOccurred())
})
It("TxPipeline", Label("NonRedisEnterprise"), func() {
tx := client.Conn().TxPipeline()
tx.SwapDB(ctx, 0, 2)
tx.SwapDB(ctx, 1, 0)
_, err := tx.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
})
})
var _ = Describe("Hook", func() {
var client *redis.Client
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
err := client.Close()
Expect(err).NotTo(HaveOccurred())
})
It("fifo", func() {
var res []string
client.AddHook(&hook{
processHook: func(hook redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
res = append(res, "hook-1-process-start")
err := hook(ctx, cmd)
res = append(res, "hook-1-process-end")
return err
}
},
})
client.AddHook(&hook{
processHook: func(hook redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
res = append(res, "hook-2-process-start")
err := hook(ctx, cmd)
res = append(res, "hook-2-process-end")
return err
}
},
})
err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
Expect(res).To(Equal([]string{
"hook-1-process-start",
"hook-2-process-start",
"hook-2-process-end",
"hook-1-process-end",
}))
})
It("wrapped error in a hook", func() {
client.AddHook(&hook{
processHook: func(hook redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
if err := hook(ctx, cmd); err != nil {
return fmt.Errorf("wrapped error: %w", err)
}
return nil
}
},
})
client.ScriptFlush(ctx)
script := redis.NewScript(`return 'Script and hook'`)
cmd := script.Run(ctx, client, nil)
Expect(cmd.Err()).NotTo(HaveOccurred())
Expect(cmd.Val()).To(Equal("Script and hook"))
})
})
var _ = Describe("Hook with MinIdleConns", func() {
var client *redis.Client
BeforeEach(func() {
options := redisOptions()
options.MinIdleConns = 1
client = redis.NewClient(options)
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
err := client.Close()
Expect(err).NotTo(HaveOccurred())
})
It("fifo", func() {
var res []string
client.AddHook(&hook{
processHook: func(hook redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
res = append(res, "hook-1-process-start")
err := hook(ctx, cmd)
res = append(res, "hook-1-process-end")
return err
}
},
})
client.AddHook(&hook{
processHook: func(hook redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
res = append(res, "hook-2-process-start")
err := hook(ctx, cmd)
res = append(res, "hook-2-process-end")
return err
}
},
})
err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
Expect(res).To(Equal([]string{
"hook-1-process-start",
"hook-2-process-start",
"hook-2-process-end",
"hook-1-process-end",
}))
})
})
var _ = Describe("Dialer connection timeouts", func() {
var client *redis.Client
const dialSimulatedDelay = 1 * time.Second
BeforeEach(func() {
options := redisOptions()
options.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
// Simulated slow dialer.
// Note that the following sleep is deliberately not context-aware.
time.Sleep(dialSimulatedDelay)
return net.Dial("tcp", options.Addr)
}
options.MinIdleConns = 1
client = redis.NewClient(options)
})
AfterEach(func() {
err := client.Close()
Expect(err).NotTo(HaveOccurred())
})
It("does not contend on connection dial for concurrent commands", func() {
var wg sync.WaitGroup
const concurrency = 10
durations := make(chan time.Duration, concurrency)
errs := make(chan error, concurrency)
start := time.Now()
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
start := time.Now()
err := client.Ping(ctx).Err()
durations <- time.Since(start)
errs <- err
}()
}
wg.Wait()
close(durations)
close(errs)
// All commands should eventually succeed, after acquiring a connection.
for err := range errs {
Expect(err).NotTo(HaveOccurred())
}
// Each individual command should complete within the simulated dial duration bound.
for duration := range durations {
Expect(duration).To(BeNumerically("<", 2*dialSimulatedDelay))
}
// Due to concurrent execution, the entire test suite should also complete within
// the same dial duration bound applied for individual commands.
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
})
})
var _ = Describe("Credentials Provider Priority", func() {
var client *redis.Client
var opt *redis.Options
var recorder *commandRecorder
BeforeEach(func() {
recorder = newCommandRecorder(10)
})
AfterEach(func() {
if client != nil {
Expect(client.Close()).NotTo(HaveOccurred())
}
})
It("should use streaming provider when available", func() {
streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
username, password := ctxCreds.BasicAuth()
return username, password, nil
},
StreamingCredentialsProvider: &mockStreamingProvider{
credentials: streamingCreds,
updates: make(chan auth.Credentials, 1),
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
})
It("should use context provider when streaming provider is not available", func() {
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
username, password := ctxCreds.BasicAuth()
return username, password, nil
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
})
It("should use regular provider when streaming and context providers are not available", func() {
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
})
It("should use username/password fields when no providers are set", func() {
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
})
It("should use empty credentials when nothing is set", func() {
opt = &redis.Options{}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// no pass, ok
Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
Expect(recorder.Contains("AUTH")).To(BeFalse())
})
It("should handle credential updates from streaming provider", func() {
initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
updatesChan := make(chan auth.Credentials, 1)
opt = &redis.Options{
StreamingCredentialsProvider: &mockStreamingProvider{
credentials: initialCreds,
updates: updatesChan,
},
PoolSize: 1, // Force single connection to ensure reauth is tested
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
time.Sleep(10 * time.Millisecond)
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
// Update credentials
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
// Wait for reauth to complete and verify updated credentials are used
// We need to keep trying Ping until we see the updated AUTH command
// because the reauth happens asynchronously
Eventually(func() bool {
// wrongpass
_ = client.Ping(context.Background()).Err()
return recorder.Contains("AUTH updated_user")
}, "1s", "50ms").Should(BeTrue())
close(updatesChan)
})
})
type mockStreamingProvider struct {
mu sync.RWMutex
credentials auth.Credentials
err error
updates chan auth.Credentials
}
func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
if m.err != nil {
return nil, nil, m.err
}
if listener == nil {
return nil, nil, errors.New("listener cannot be nil")
}
// Create a done channel to stop the goroutine
done := make(chan struct{})
// Start goroutine to handle updates
go func() {
defer func() {
if r := recover(); r != nil {
// this is just a mock:
// allow panics to be caught without crashing
}
}()
for {
select {
case <-done:
return
case creds, ok := <-m.updates:
if !ok {
return
}
m.mu.Lock()
m.credentials = creds
m.mu.Unlock()
listener.OnNext(creds)
}
}
}()
m.mu.RLock()
currentCreds := m.credentials
m.mu.RUnlock()
return currentCreds, func() (err error) {
defer func() {
if r := recover(); r != nil {
// this is just a mock:
// allow multiple closes from multiple listeners
}
}()
close(done)
return
}, nil
}
var _ = Describe("Client creation", func() {
Context("simple client with nil options", func() {
It("panics", func() {
Expect(func() {
redis.NewClient(nil)
}).To(Panic())
})
})
Context("cluster client with nil options", func() {
It("panics", func() {
Expect(func() {
redis.NewClusterClient(nil)
}).To(Panic())
})
})
Context("ring client with nil options", func() {
It("panics", func() {
Expect(func() {
redis.NewRing(nil)
}).To(Panic())
})
})
Context("universal client with nil options", func() {
It("panics", func() {
Expect(func() {
redis.NewUniversalClient(nil)
}).To(Panic())
})
})
Context("failover client with nil options", func() {
It("panics", func() {
Expect(func() {
redis.NewFailoverClient(nil)
}).To(Panic())
})
})
Context("failover cluster client with nil options", func() {
It("panics", func() {
Expect(func() {
redis.NewFailoverClusterClient(nil)
}).To(Panic())
})
})
Context("sentinel client with nil options", func() {
It("panics", func() {
Expect(func() {
redis.NewSentinelClient(nil)
}).To(Panic())
})
})
})