From 212095540d9323dc250eea8dfc11e6b8a958f6bb Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 20 Nov 2025 11:01:57 +0200 Subject: [PATCH] wip --- osscluster.go | 19 ++++- osscluster_test.go | 186 ++++++++++++++++++++++++++++++--------------- 2 files changed, 144 insertions(+), 61 deletions(-) diff --git a/osscluster.go b/osscluster.go index 7925d2c6..999ae063 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1037,7 +1037,6 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { pipeline: c.processPipeline, txPipeline: c.processTxPipeline, }) - return c } @@ -1046,6 +1045,24 @@ func (c *ClusterClient) Options() *ClusterOptions { return c.opt } +// AddHook adds a hook to the client. +func (c *ClusterClient) AddHook(h Hook) { + // Add hook only to nodes, not to the cluster client itself. + // This prevents hooks from being called twice (once at cluster level, once at node level). + // The cluster client delegates all commands to nodes, so hooks on nodes will be called. + + if err := c.ForEachShard(context.Background(), func(ctx context.Context, node *Client) error { + node.AddHook(h) + return nil + }); err != nil { + return + } + + c.nodes.OnNewNode(func(rdb *Client) { + rdb.AddHook(h) + }) +} + // ReloadState reloads cluster state. If available it calls ClusterSlots func // to get cluster slots information. func (c *ClusterClient) ReloadState(ctx context.Context) { diff --git a/osscluster_test.go b/osscluster_test.go index 3659ec65..b63bb04e 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -1188,7 +1188,9 @@ var _ = Describe("ClusterClient", func() { var stack []string - clusterHook := &hook{ + // AddHook now only adds to nodes, not to cluster client itself + // This prevents hooks from being called twice + firstHook := &hook{ processHook: func(hook redis.ProcessHook) redis.ProcessHook { return func(ctx context.Context, cmd redis.Cmder) error { select { @@ -1198,20 +1200,20 @@ var _ = Describe("ClusterClient", func() { } Expect(cmd.String()).To(Equal("ping: ")) - stack = append(stack, "cluster.BeforeProcess") + stack = append(stack, "hook1.BeforeProcess") err := hook(ctx, cmd) Expect(cmd.String()).To(Equal("ping: PONG")) - stack = append(stack, "cluster.AfterProcess") + stack = append(stack, "hook1.AfterProcess") return err } }, } - client.AddHook(clusterHook) + client.AddHook(firstHook) - nodeHook := &hook{ + secondHook := &hook{ processHook: func(hook redis.ProcessHook) redis.ProcessHook { return func(ctx context.Context, cmd redis.Cmder) error { select { @@ -1221,30 +1223,27 @@ var _ = Describe("ClusterClient", func() { } Expect(cmd.String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcess") + stack = append(stack, "hook2.BeforeProcess") err := hook(ctx, cmd) Expect(cmd.String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcess") + stack = append(stack, "hook2.AfterProcess") return err } }, } - - _ = client.ForEachShard(ctx, func(ctx context.Context, node *redis.Client) error { - node.AddHook(nodeHook) - return nil - }) + client.AddHook(secondHook) err = client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) + // Both hooks should be called in FIFO order on the node Expect(stack).To(Equal([]string{ - "cluster.BeforeProcess", - "shard.BeforeProcess", - "shard.AfterProcess", - "cluster.AfterProcess", + "hook1.BeforeProcess", + "hook2.BeforeProcess", + "hook2.AfterProcess", + "hook1.AfterProcess", })) }) @@ -1259,43 +1258,41 @@ var _ = Describe("ClusterClient", func() { var stack []string + // AddHook now only adds to nodes, not to cluster client itself client.AddHook(&hook{ processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(1)) Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "cluster.BeforeProcessPipeline") + stack = append(stack, "hook1.BeforeProcessPipeline") err := hook(ctx, cmds) Expect(cmds).To(HaveLen(1)) Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = append(stack, "cluster.AfterProcessPipeline") + stack = append(stack, "hook1.AfterProcessPipeline") return err } }, }) - _ = client.ForEachShard(ctx, func(ctx context.Context, node *redis.Client) error { - node.AddHook(&hook{ - processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { - return func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcessPipeline") + client.AddHook(&hook{ + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "hook2.BeforeProcessPipeline") - err := hook(ctx, cmds) + err := hook(ctx, cmds) - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcessPipeline") + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "hook2.AfterProcessPipeline") - return err - } - }, - }) - return nil + return err + } + }, }) _, err = client.Pipelined(ctx, func(pipe redis.Pipeliner) error { @@ -1303,11 +1300,12 @@ var _ = Describe("ClusterClient", func() { return nil }) Expect(err).NotTo(HaveOccurred()) + // Both hooks should be called in FIFO order on the node Expect(stack).To(Equal([]string{ - "cluster.BeforeProcessPipeline", - "shard.BeforeProcessPipeline", - "shard.AfterProcessPipeline", - "cluster.AfterProcessPipeline", + "hook1.BeforeProcessPipeline", + "hook2.BeforeProcessPipeline", + "hook2.AfterProcessPipeline", + "hook1.AfterProcessPipeline", })) }) @@ -1322,43 +1320,41 @@ var _ = Describe("ClusterClient", func() { var stack []string + // AddHook now only adds to nodes, not to cluster client itself client.AddHook(&hook{ processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(3)) Expect(cmds[1].String()).To(Equal("ping: ")) - stack = append(stack, "cluster.BeforeProcessPipeline") + stack = append(stack, "hook1.BeforeProcessPipeline") err := hook(ctx, cmds) Expect(cmds).To(HaveLen(3)) Expect(cmds[1].String()).To(Equal("ping: PONG")) - stack = append(stack, "cluster.AfterProcessPipeline") + stack = append(stack, "hook1.AfterProcessPipeline") return err } }, }) - _ = client.ForEachShard(ctx, func(ctx context.Context, node *redis.Client) error { - node.AddHook(&hook{ - processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { - return func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcessPipeline") + client.AddHook(&hook{ + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) + stack = append(stack, "hook2.BeforeProcessPipeline") - err := hook(ctx, cmds) + err := hook(ctx, cmds) - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcessPipeline") + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) + stack = append(stack, "hook2.AfterProcessPipeline") - return err - } - }, - }) - return nil + return err + } + }, }) _, err = client.TxPipelined(ctx, func(pipe redis.Pipeliner) error { @@ -1366,14 +1362,84 @@ var _ = Describe("ClusterClient", func() { return nil }) Expect(err).NotTo(HaveOccurred()) + // Both hooks should be called in FIFO order on the node Expect(stack).To(Equal([]string{ - "cluster.BeforeProcessPipeline", - "shard.BeforeProcessPipeline", - "shard.AfterProcessPipeline", - "cluster.AfterProcessPipeline", + "hook1.BeforeProcessPipeline", + "hook2.BeforeProcessPipeline", + "hook2.AfterProcessPipeline", + "hook1.AfterProcessPipeline", })) }) + It("passes hooks to cluster nodes", func() { + // Ensure cluster is initialized + err := client.Ping(ctx).Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachShard(ctx, func(ctx context.Context, node *redis.Client) error { + return node.Ping(ctx).Err() + }) + Expect(err).NotTo(HaveOccurred()) + + // Track hook calls to detect if hooks are called multiple times + var mu sync.Mutex + var hookCallCount int + + // Create a hook that counts how many times it's called + testHook := &hook{ + processHook: func(next redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + // Only track PING commands to avoid noise from other operations + if cmd.Name() == "ping" { + mu.Lock() + hookCallCount++ + mu.Unlock() + } + return next(ctx, cmd) + } + }, + } + + // Add hook to cluster client - this should propagate to all existing nodes + client.AddHook(testHook) + + // Reset counter before test + mu.Lock() + hookCallCount = 0 + mu.Unlock() + + // Execute a single PING command through the cluster client + // This should call the hook ONLY ONCE, not twice (cluster + node) + err = client.Ping(ctx).Err() + Expect(err).NotTo(HaveOccurred()) + + mu.Lock() + clusterPingCalls := hookCallCount + mu.Unlock() + + // Reset counter + mu.Lock() + hookCallCount = 0 + mu.Unlock() + + // Execute a PING command directly on a node + // This should call the hook ONLY ONCE + err = client.ForEachShard(ctx, func(ctx context.Context, node *redis.Client) error { + // Just test one node + return node.Ping(ctx).Err() + }) + Expect(err).NotTo(HaveOccurred()) + + mu.Lock() + nodePingCalls := hookCallCount + mu.Unlock() + + // Verify hook is called exactly once per command, not twice + // If hooks are called twice (cluster + node), this will fail + Expect(clusterPingCalls).To(Equal(1), "Hook should be called exactly once for cluster.Ping(), not twice") + Expect(nodePingCalls).To(BeNumerically(">=", 1), "Hook should be called at least once for node.Ping()") + }) + It("should return correct replica for key", func() { client, err := client.SlaveForKey(ctx, "test") Expect(err).ToNot(HaveOccurred())