1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-28 06:42:00 +03:00

Merge branch 'master' into ndyakov/token-based-auth

This commit is contained in:
Nedyalko Dyakov
2025-05-19 20:13:51 +03:00
committed by GitHub
8 changed files with 91 additions and 35 deletions

View File

@ -15,6 +15,13 @@ import (
// ErrClosed performs any operation on the closed client will return this error. // ErrClosed performs any operation on the closed client will return this error.
var ErrClosed = pool.ErrClosed var ErrClosed = pool.ErrClosed
// ErrPoolExhausted is returned from a pool connection method
// when the maximum number of database connections in the pool has been reached.
var ErrPoolExhausted = pool.ErrPoolExhausted
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
var ErrPoolTimeout = pool.ErrPoolTimeout
// HasErrorPrefix checks if the err is a Redis error and the message contains a prefix. // HasErrorPrefix checks if the err is a Redis error and the message contains a prefix.
func HasErrorPrefix(err error, prefix string) bool { func HasErrorPrefix(err error, prefix string) bool {
var rErr Error var rErr Error

View File

@ -11,8 +11,6 @@ import (
"github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/pool"
) )
var ErrPoolTimeout = pool.ErrPoolTimeout
func (c *baseClient) Pool() pool.Pooler { func (c *baseClient) Pool() pool.Pooler {
return c.connPool return c.connPool
} }

View File

@ -387,4 +387,33 @@ var _ = Describe("race", func() {
Expect(stats.WaitCount).To(Equal(uint32(1))) Expect(stats.WaitCount).To(Equal(uint32(1)))
Expect(stats.WaitDurationNs).To(BeNumerically("~", time.Second.Nanoseconds(), 100*time.Millisecond.Nanoseconds())) Expect(stats.WaitDurationNs).To(BeNumerically("~", time.Second.Nanoseconds(), 100*time.Millisecond.Nanoseconds()))
}) })
It("timeout", func() {
testPoolTimeout := 1 * time.Second
opt := &pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
// Artificial delay to force pool timeout
time.Sleep(3 * testPoolTimeout)
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolTimeout: testPoolTimeout,
}
p := pool.NewConnPool(opt)
stats := p.Stats()
Expect(stats.Timeouts).To(Equal(uint32(0)))
conn, err := p.Get(ctx)
Expect(err).NotTo(HaveOccurred())
_, err = p.Get(ctx)
Expect(err).To(MatchError(pool.ErrPoolTimeout))
p.Put(ctx, conn)
conn, err = p.Get(ctx)
Expect(err).NotTo(HaveOccurred())
stats = p.Stats()
Expect(stats.Timeouts).To(Equal(uint32(1)))
})
}) })

View File

@ -49,22 +49,7 @@ func isLower(s string) bool {
} }
func ReplaceSpaces(s string) string { func ReplaceSpaces(s string) string {
// Pre-allocate a builder with the same length as s to minimize allocations. return strings.ReplaceAll(s, " ", "-")
// This is a basic optimization; adjust the initial size based on your use case.
var builder strings.Builder
builder.Grow(len(s))
for _, char := range s {
if char == ' ' {
// Replace space with a hyphen.
builder.WriteRune('-')
} else {
// Copy the character as-is.
builder.WriteRune(char)
}
}
return builder.String()
} }
func GetAddr(addr string) string { func GetAddr(addr string) string {

View File

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"runtime"
"strings" "strings"
"testing" "testing"
@ -72,3 +73,36 @@ func TestGetAddr(t *testing.T) {
Expect(GetAddr("127")).To(Equal("")) Expect(GetAddr("127")).To(Equal(""))
}) })
} }
func BenchmarkReplaceSpaces(b *testing.B) {
version := runtime.Version()
for i := 0; i < b.N; i++ {
_ = ReplaceSpaces(version)
}
}
func ReplaceSpacesUseBuilder(s string) string {
// Pre-allocate a builder with the same length as s to minimize allocations.
// This is a basic optimization; adjust the initial size based on your use case.
var builder strings.Builder
builder.Grow(len(s))
for _, char := range s {
if char == ' ' {
// Replace space with a hyphen.
builder.WriteRune('-')
} else {
// Copy the character as-is.
builder.WriteRune(char)
}
}
return builder.String()
}
func BenchmarkReplaceSpacesUseBuilder(b *testing.B) {
version := runtime.Version()
for i := 0; i < b.N; i++ {
_ = ReplaceSpacesUseBuilder(version)
}
}

View File

@ -364,14 +364,14 @@ var _ = Describe("ClusterClient", func() {
It("select slot from args for GETKEYSINSLOT command", func() { It("select slot from args for GETKEYSINSLOT command", func() {
cmd := NewStringSliceCmd(ctx, "cluster", "getkeysinslot", 100, 200) cmd := NewStringSliceCmd(ctx, "cluster", "getkeysinslot", 100, 200)
slot := client.cmdSlot(context.Background(), cmd) slot := client.cmdSlot(cmd)
Expect(slot).To(Equal(100)) Expect(slot).To(Equal(100))
}) })
It("select slot from args for COUNTKEYSINSLOT command", func() { It("select slot from args for COUNTKEYSINSLOT command", func() {
cmd := NewStringSliceCmd(ctx, "cluster", "countkeysinslot", 100) cmd := NewStringSliceCmd(ctx, "cluster", "countkeysinslot", 100)
slot := client.cmdSlot(context.Background(), cmd) slot := client.cmdSlot(cmd)
Expect(slot).To(Equal(100)) Expect(slot).To(Equal(100))
}) })
}) })

View File

@ -984,7 +984,7 @@ func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error {
} }
func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
slot := c.cmdSlot(ctx, cmd) slot := c.cmdSlot(cmd)
var node *clusterNode var node *clusterNode
var moved bool var moved bool
var ask bool var ask bool
@ -1332,7 +1332,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd
if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) {
for _, cmd := range cmds { for _, cmd := range cmds {
slot := c.cmdSlot(ctx, cmd) slot := c.cmdSlot(cmd)
node, err := c.slotReadOnlyNode(state, slot) node, err := c.slotReadOnlyNode(state, slot)
if err != nil { if err != nil {
return err return err
@ -1343,7 +1343,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd
} }
for _, cmd := range cmds { for _, cmd := range cmds {
slot := c.cmdSlot(ctx, cmd) slot := c.cmdSlot(cmd)
node, err := state.slotMasterNode(slot) node, err := state.slotMasterNode(slot)
if err != nil { if err != nil {
return err return err
@ -1543,7 +1543,7 @@ func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) err
func (c *ClusterClient) mapCmdsBySlot(ctx context.Context, cmds []Cmder) map[int][]Cmder { func (c *ClusterClient) mapCmdsBySlot(ctx context.Context, cmds []Cmder) map[int][]Cmder {
cmdsMap := make(map[int][]Cmder) cmdsMap := make(map[int][]Cmder)
for _, cmd := range cmds { for _, cmd := range cmds {
slot := c.cmdSlot(ctx, cmd) slot := c.cmdSlot(cmd)
cmdsMap[slot] = append(cmdsMap[slot], cmd) cmdsMap[slot] = append(cmdsMap[slot], cmd)
} }
return cmdsMap return cmdsMap
@ -1572,7 +1572,7 @@ func (c *ClusterClient) processTxPipelineNode(
} }
func (c *ClusterClient) processTxPipelineNodeConn( func (c *ClusterClient) processTxPipelineNodeConn(
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, ctx context.Context, _ *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
) error { ) error {
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds) return writeCmds(wr, cmds)
@ -1861,7 +1861,7 @@ func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo {
return info return info
} }
func (c *ClusterClient) cmdSlot(ctx context.Context, cmd Cmder) int { func (c *ClusterClient) cmdSlot(cmd Cmder) int {
args := cmd.Args() args := cmd.Args()
if args[0] == "cluster" && (args[1] == "getkeysinslot" || args[1] == "countkeysinslot") { if args[0] == "cluster" && (args[1] == "getkeysinslot" || args[1] == "countkeysinslot") {
return args[2].(int) return args[2].(int)

21
ring.go
View File

@ -349,17 +349,16 @@ func (c *ringSharding) newRingShards(
return return
} }
// Warning: External exposure of `c.shards.list` may cause data races.
// So keep internal or implement deep copy if exposed.
func (c *ringSharding) List() []*ringShard { func (c *ringSharding) List() []*ringShard {
var list []*ringShard
c.mu.RLock() c.mu.RLock()
if !c.closed { defer c.mu.RUnlock()
list = make([]*ringShard, len(c.shards.list))
copy(list, c.shards.list)
}
c.mu.RUnlock()
return list if c.closed {
return nil
}
return c.shards.list
} }
func (c *ringSharding) Hash(key string) string { func (c *ringSharding) Hash(key string) string {
@ -423,6 +422,7 @@ func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) {
case <-ticker.C: case <-ticker.C:
var rebalance bool var rebalance bool
// note: `c.List()` return a shadow copy of `[]*ringShard`.
for _, shard := range c.List() { for _, shard := range c.List() {
err := shard.Client.Ping(ctx).Err() err := shard.Client.Ping(ctx).Err()
isUp := err == nil || err == pool.ErrPoolTimeout isUp := err == nil || err == pool.ErrPoolTimeout
@ -582,6 +582,7 @@ func (c *Ring) retryBackoff(attempt int) time.Duration {
// PoolStats returns accumulated connection pool stats. // PoolStats returns accumulated connection pool stats.
func (c *Ring) PoolStats() *PoolStats { func (c *Ring) PoolStats() *PoolStats {
// note: `c.List()` return a shadow copy of `[]*ringShard`.
shards := c.sharding.List() shards := c.sharding.List()
var acc PoolStats var acc PoolStats
for _, shard := range shards { for _, shard := range shards {
@ -651,6 +652,7 @@ func (c *Ring) ForEachShard(
ctx context.Context, ctx context.Context,
fn func(ctx context.Context, client *Client) error, fn func(ctx context.Context, client *Client) error,
) error { ) error {
// note: `c.List()` return a shadow copy of `[]*ringShard`.
shards := c.sharding.List() shards := c.sharding.List()
var wg sync.WaitGroup var wg sync.WaitGroup
errCh := make(chan error, 1) errCh := make(chan error, 1)
@ -682,6 +684,7 @@ func (c *Ring) ForEachShard(
} }
func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) {
// note: `c.List()` return a shadow copy of `[]*ringShard`.
shards := c.sharding.List() shards := c.sharding.List()
var firstErr error var firstErr error
for _, shard := range shards { for _, shard := range shards {
@ -810,7 +813,7 @@ func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) er
for _, key := range keys { for _, key := range keys {
if key != "" { if key != "" {
shard, err := c.sharding.GetByKey(hashtag.Key(key)) shard, err := c.sharding.GetByKey(key)
if err != nil { if err != nil {
return err return err
} }