diff --git a/error.go b/error.go index 6f47f7cf..8c811966 100644 --- a/error.go +++ b/error.go @@ -15,6 +15,13 @@ import ( // ErrClosed performs any operation on the closed client will return this error. var ErrClosed = pool.ErrClosed +// ErrPoolExhausted is returned from a pool connection method +// when the maximum number of database connections in the pool has been reached. +var ErrPoolExhausted = pool.ErrPoolExhausted + +// ErrPoolTimeout timed out waiting to get a connection from the connection pool. +var ErrPoolTimeout = pool.ErrPoolTimeout + // HasErrorPrefix checks if the err is a Redis error and the message contains a prefix. func HasErrorPrefix(err error, prefix string) bool { var rErr Error diff --git a/export_test.go b/export_test.go index 10d8f23c..c1b77683 100644 --- a/export_test.go +++ b/export_test.go @@ -11,8 +11,6 @@ import ( "github.com/redis/go-redis/v9/internal/pool" ) -var ErrPoolTimeout = pool.ErrPoolTimeout - func (c *baseClient) Pool() pool.Pooler { return c.connPool } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index d198ba54..0f366cc7 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -387,4 +387,33 @@ var _ = Describe("race", func() { Expect(stats.WaitCount).To(Equal(uint32(1))) Expect(stats.WaitDurationNs).To(BeNumerically("~", time.Second.Nanoseconds(), 100*time.Millisecond.Nanoseconds())) }) + + It("timeout", func() { + testPoolTimeout := 1 * time.Second + opt := &pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + // Artificial delay to force pool timeout + time.Sleep(3 * testPoolTimeout) + + return &net.TCPConn{}, nil + }, + PoolSize: 1, + PoolTimeout: testPoolTimeout, + } + p := pool.NewConnPool(opt) + + stats := p.Stats() + Expect(stats.Timeouts).To(Equal(uint32(0))) + + conn, err := p.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + _, err = p.Get(ctx) + Expect(err).To(MatchError(pool.ErrPoolTimeout)) + p.Put(ctx, conn) + conn, err = p.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + + stats = p.Stats() + Expect(stats.Timeouts).To(Equal(uint32(1))) + }) }) diff --git a/internal/util.go b/internal/util.go index cc1bff24..f77775ff 100644 --- a/internal/util.go +++ b/internal/util.go @@ -49,22 +49,7 @@ func isLower(s string) bool { } func ReplaceSpaces(s string) string { - // Pre-allocate a builder with the same length as s to minimize allocations. - // This is a basic optimization; adjust the initial size based on your use case. - var builder strings.Builder - builder.Grow(len(s)) - - for _, char := range s { - if char == ' ' { - // Replace space with a hyphen. - builder.WriteRune('-') - } else { - // Copy the character as-is. - builder.WriteRune(char) - } - } - - return builder.String() + return strings.ReplaceAll(s, " ", "-") } func GetAddr(addr string) string { diff --git a/internal/util_test.go b/internal/util_test.go index 57f7f9fa..0bc46646 100644 --- a/internal/util_test.go +++ b/internal/util_test.go @@ -1,6 +1,7 @@ package internal import ( + "runtime" "strings" "testing" @@ -72,3 +73,36 @@ func TestGetAddr(t *testing.T) { Expect(GetAddr("127")).To(Equal("")) }) } + +func BenchmarkReplaceSpaces(b *testing.B) { + version := runtime.Version() + for i := 0; i < b.N; i++ { + _ = ReplaceSpaces(version) + } +} + +func ReplaceSpacesUseBuilder(s string) string { + // Pre-allocate a builder with the same length as s to minimize allocations. + // This is a basic optimization; adjust the initial size based on your use case. + var builder strings.Builder + builder.Grow(len(s)) + + for _, char := range s { + if char == ' ' { + // Replace space with a hyphen. + builder.WriteRune('-') + } else { + // Copy the character as-is. + builder.WriteRune(char) + } + } + + return builder.String() +} + +func BenchmarkReplaceSpacesUseBuilder(b *testing.B) { + version := runtime.Version() + for i := 0; i < b.N; i++ { + _ = ReplaceSpacesUseBuilder(version) + } +} diff --git a/internal_test.go b/internal_test.go index a61b5c02..8f1f1f31 100644 --- a/internal_test.go +++ b/internal_test.go @@ -364,14 +364,14 @@ var _ = Describe("ClusterClient", func() { It("select slot from args for GETKEYSINSLOT command", func() { cmd := NewStringSliceCmd(ctx, "cluster", "getkeysinslot", 100, 200) - slot := client.cmdSlot(context.Background(), cmd) + slot := client.cmdSlot(cmd) Expect(slot).To(Equal(100)) }) It("select slot from args for COUNTKEYSINSLOT command", func() { cmd := NewStringSliceCmd(ctx, "cluster", "countkeysinslot", 100) - slot := client.cmdSlot(context.Background(), cmd) + slot := client.cmdSlot(cmd) Expect(slot).To(Equal(100)) }) }) diff --git a/osscluster.go b/osscluster.go index 0ccb087d..5dbcc936 100644 --- a/osscluster.go +++ b/osscluster.go @@ -984,7 +984,7 @@ func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error { } func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { - slot := c.cmdSlot(ctx, cmd) + slot := c.cmdSlot(cmd) var node *clusterNode var moved bool var ask bool @@ -1332,7 +1332,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { for _, cmd := range cmds { - slot := c.cmdSlot(ctx, cmd) + slot := c.cmdSlot(cmd) node, err := c.slotReadOnlyNode(state, slot) if err != nil { return err @@ -1343,7 +1343,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd } for _, cmd := range cmds { - slot := c.cmdSlot(ctx, cmd) + slot := c.cmdSlot(cmd) node, err := state.slotMasterNode(slot) if err != nil { return err @@ -1543,7 +1543,7 @@ func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) err func (c *ClusterClient) mapCmdsBySlot(ctx context.Context, cmds []Cmder) map[int][]Cmder { cmdsMap := make(map[int][]Cmder) for _, cmd := range cmds { - slot := c.cmdSlot(ctx, cmd) + slot := c.cmdSlot(cmd) cmdsMap[slot] = append(cmdsMap[slot], cmd) } return cmdsMap @@ -1572,7 +1572,7 @@ func (c *ClusterClient) processTxPipelineNode( } func (c *ClusterClient) processTxPipelineNodeConn( - ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, + ctx context.Context, _ *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, ) error { if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) @@ -1861,7 +1861,7 @@ func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { return info } -func (c *ClusterClient) cmdSlot(ctx context.Context, cmd Cmder) int { +func (c *ClusterClient) cmdSlot(cmd Cmder) int { args := cmd.Args() if args[0] == "cluster" && (args[1] == "getkeysinslot" || args[1] == "countkeysinslot") { return args[2].(int) diff --git a/ring.go b/ring.go index 555ea2a1..ab3d0626 100644 --- a/ring.go +++ b/ring.go @@ -349,17 +349,16 @@ func (c *ringSharding) newRingShards( return } +// Warning: External exposure of `c.shards.list` may cause data races. +// So keep internal or implement deep copy if exposed. func (c *ringSharding) List() []*ringShard { - var list []*ringShard - c.mu.RLock() - if !c.closed { - list = make([]*ringShard, len(c.shards.list)) - copy(list, c.shards.list) - } - c.mu.RUnlock() + defer c.mu.RUnlock() - return list + if c.closed { + return nil + } + return c.shards.list } func (c *ringSharding) Hash(key string) string { @@ -423,6 +422,7 @@ func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) { case <-ticker.C: var rebalance bool + // note: `c.List()` return a shadow copy of `[]*ringShard`. for _, shard := range c.List() { err := shard.Client.Ping(ctx).Err() isUp := err == nil || err == pool.ErrPoolTimeout @@ -582,6 +582,7 @@ func (c *Ring) retryBackoff(attempt int) time.Duration { // PoolStats returns accumulated connection pool stats. func (c *Ring) PoolStats() *PoolStats { + // note: `c.List()` return a shadow copy of `[]*ringShard`. shards := c.sharding.List() var acc PoolStats for _, shard := range shards { @@ -651,6 +652,7 @@ func (c *Ring) ForEachShard( ctx context.Context, fn func(ctx context.Context, client *Client) error, ) error { + // note: `c.List()` return a shadow copy of `[]*ringShard`. shards := c.sharding.List() var wg sync.WaitGroup errCh := make(chan error, 1) @@ -682,6 +684,7 @@ func (c *Ring) ForEachShard( } func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { + // note: `c.List()` return a shadow copy of `[]*ringShard`. shards := c.sharding.List() var firstErr error for _, shard := range shards { @@ -810,7 +813,7 @@ func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) er for _, key := range keys { if key != "" { - shard, err := c.sharding.GetByKey(hashtag.Key(key)) + shard, err := c.sharding.GetByKey(key) if err != nil { return err }