diff --git a/cluster.go b/cluster.go index 68194c8d..dbb8deff 100644 --- a/cluster.go +++ b/cluster.go @@ -645,16 +645,13 @@ func (c *clusterStateHolder) ReloadOrGet() (*clusterState, error) { type ClusterClient struct { cmdable - ctx context.Context - opt *ClusterOptions nodes *clusterNodes state *clusterStateHolder cmdsInfoCache *cmdsInfoCache - process func(Cmder) error - processPipeline func([]Cmder) error - processTxPipeline func([]Cmder) error + ctx context.Context + hooks } // NewClusterClient returns a Redis Cluster client as described in @@ -669,10 +666,6 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.state = newClusterStateHolder(c.loadState) c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) - c.process = c.defaultProcess - c.processPipeline = c.defaultProcessPipeline - c.processTxPipeline = c.defaultProcessTxPipeline - c.init() if opt.IdleCheckFrequency > 0 { go c.reaper(opt.IdleCheckFrequency) @@ -685,13 +678,6 @@ func (c *ClusterClient) init() { c.cmdable.setProcessor(c.Process) } -// ReloadState reloads cluster state. If available it calls ClusterSlots func -// to get cluster slots information. -func (c *ClusterClient) ReloadState() error { - _, err := c.state.Reload() - return err -} - func (c *ClusterClient) Context() context.Context { if c.ctx != nil { return c.ctx @@ -709,9 +695,10 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient { } func (c *ClusterClient) clone() *ClusterClient { - cp := *c - cp.init() - return &cp + clone := *c + clone.hooks.copy() + clone.init() + return &clone } // Options returns read-only Options that were used to create the client. @@ -719,164 +706,10 @@ func (c *ClusterClient) Options() *ClusterOptions { return c.opt } -func (c *ClusterClient) retryBackoff(attempt int) time.Duration { - return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) -} - -func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) { - addrs, err := c.nodes.Addrs() - if err != nil { - return nil, err - } - - var firstErr error - for _, addr := range addrs { - node, err := c.nodes.Get(addr) - if err != nil { - return nil, err - } - if node == nil { - continue - } - - info, err := node.Client.Command().Result() - if err == nil { - return info, nil - } - if firstErr == nil { - firstErr = err - } - } - return nil, firstErr -} - -func (c *ClusterClient) cmdInfo(name string) *CommandInfo { - cmdsInfo, err := c.cmdsInfoCache.Get() - if err != nil { - return nil - } - - info := cmdsInfo[name] - if info == nil { - internal.Logf("info for cmd=%s not found", name) - } - return info -} - -func cmdSlot(cmd Cmder, pos int) int { - if pos == 0 { - return hashtag.RandomSlot() - } - firstKey := cmd.stringArg(pos) - return hashtag.Slot(firstKey) -} - -func (c *ClusterClient) cmdSlot(cmd Cmder) int { - args := cmd.Args() - if args[0] == "cluster" && args[1] == "getkeysinslot" { - return args[2].(int) - } - - cmdInfo := c.cmdInfo(cmd.Name()) - return cmdSlot(cmd, cmdFirstKeyPos(cmd, cmdInfo)) -} - -func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode, error) { - state, err := c.state.Get() - if err != nil { - return 0, nil, err - } - - cmdInfo := c.cmdInfo(cmd.Name()) - slot := c.cmdSlot(cmd) - - if c.opt.ReadOnly && cmdInfo != nil && cmdInfo.ReadOnly { - if c.opt.RouteByLatency { - node, err := state.slotClosestNode(slot) - return slot, node, err - } - - if c.opt.RouteRandomly { - node := state.slotRandomNode(slot) - return slot, node, nil - } - - node, err := state.slotSlaveNode(slot) - return slot, node, err - } - - node, err := state.slotMasterNode(slot) - return slot, node, err -} - -func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) { - state, err := c.state.Get() - if err != nil { - return nil, err - } - - nodes := state.slotNodes(slot) - if len(nodes) > 0 { - return nodes[0], nil - } - return c.nodes.Random() -} - -func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { - if len(keys) == 0 { - return fmt.Errorf("redis: Watch requires at least one key") - } - - slot := hashtag.Slot(keys[0]) - for _, key := range keys[1:] { - if hashtag.Slot(key) != slot { - err := fmt.Errorf("redis: Watch requires all keys to be in the same slot") - return err - } - } - - node, err := c.slotMasterNode(slot) - if err != nil { - return err - } - - for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { - if attempt > 0 { - time.Sleep(c.retryBackoff(attempt)) - } - - err = node.Client.Watch(fn, keys...) - if err == nil { - break - } - if err != Nil { - c.state.LazyReload() - } - - moved, ask, addr := internal.IsMovedError(err) - if moved || ask { - node, err = c.nodes.GetOrCreate(addr) - if err != nil { - return err - } - continue - } - - if err == pool.ErrClosed || internal.IsReadOnlyError(err) { - node, err = c.slotMasterNode(slot) - if err != nil { - return err - } - continue - } - - if internal.IsRetryableError(err, true) { - continue - } - - return err - } - +// ReloadState reloads cluster state. If available it calls ClusterSlots func +// to get cluster slots information. +func (c *ClusterClient) ReloadState() error { + _, err := c.state.Reload() return err } @@ -895,17 +728,11 @@ func (c *ClusterClient) Do(args ...interface{}) *Cmd { return cmd } -func (c *ClusterClient) WrapProcess( - fn func(oldProcess func(Cmder) error) func(Cmder) error, -) { - c.process = fn(c.process) -} - func (c *ClusterClient) Process(cmd Cmder) error { - return c.process(cmd) + return c.hooks.process(c.ctx, cmd, c.process) } -func (c *ClusterClient) defaultProcess(cmd Cmder) error { +func (c *ClusterClient) process(cmd Cmder) error { var node *clusterNode var ask bool for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { @@ -1194,14 +1021,11 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { return c.Pipeline().Pipelined(fn) } -func (c *ClusterClient) WrapProcessPipeline( - fn func(oldProcess func([]Cmder) error) func([]Cmder) error, -) { - c.processPipeline = fn(c.processPipeline) - c.processTxPipeline = fn(c.processTxPipeline) +func (c *ClusterClient) processPipeline(cmds []Cmder) error { + return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline) } -func (c *ClusterClient) defaultProcessPipeline(cmds []Cmder) error { +func (c *ClusterClient) _processPipeline(cmds []Cmder) error { cmdsMap := newCmdsMap() err := c.mapCmdsByNode(cmds, cmdsMap) if err != nil { @@ -1391,7 +1215,11 @@ func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { return c.TxPipeline().Pipelined(fn) } -func (c *ClusterClient) defaultProcessTxPipeline(cmds []Cmder) error { +func (c *ClusterClient) processTxPipeline(cmds []Cmder) error { + return c.hooks.processPipeline(c.ctx, cmds, c._processTxPipeline) +} + +func (c *ClusterClient) _processTxPipeline(cmds []Cmder) error { state, err := c.state.Get() if err != nil { return err @@ -1529,6 +1357,64 @@ func (c *ClusterClient) txPipelineReadQueued( return nil } +func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { + if len(keys) == 0 { + return fmt.Errorf("redis: Watch requires at least one key") + } + + slot := hashtag.Slot(keys[0]) + for _, key := range keys[1:] { + if hashtag.Slot(key) != slot { + err := fmt.Errorf("redis: Watch requires all keys to be in the same slot") + return err + } + } + + node, err := c.slotMasterNode(slot) + if err != nil { + return err + } + + for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { + if attempt > 0 { + time.Sleep(c.retryBackoff(attempt)) + } + + err = node.Client.Watch(fn, keys...) + if err == nil { + break + } + if err != Nil { + c.state.LazyReload() + } + + moved, ask, addr := internal.IsMovedError(err) + if moved || ask { + node, err = c.nodes.GetOrCreate(addr) + if err != nil { + return err + } + continue + } + + if err == pool.ErrClosed || internal.IsReadOnlyError(err) { + node, err = c.slotMasterNode(slot) + if err != nil { + return err + } + continue + } + + if internal.IsRetryableError(err, true) { + continue + } + + return err + } + + return err +} + func (c *ClusterClient) pubSub() *PubSub { var node *clusterNode pubsub := &PubSub{ @@ -1590,6 +1476,109 @@ func (c *ClusterClient) PSubscribe(channels ...string) *PubSub { return pubsub } +func (c *ClusterClient) retryBackoff(attempt int) time.Duration { + return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) +} + +func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) { + addrs, err := c.nodes.Addrs() + if err != nil { + return nil, err + } + + var firstErr error + for _, addr := range addrs { + node, err := c.nodes.Get(addr) + if err != nil { + return nil, err + } + if node == nil { + continue + } + + info, err := node.Client.Command().Result() + if err == nil { + return info, nil + } + if firstErr == nil { + firstErr = err + } + } + return nil, firstErr +} + +func (c *ClusterClient) cmdInfo(name string) *CommandInfo { + cmdsInfo, err := c.cmdsInfoCache.Get() + if err != nil { + return nil + } + + info := cmdsInfo[name] + if info == nil { + internal.Logf("info for cmd=%s not found", name) + } + return info +} + +func cmdSlot(cmd Cmder, pos int) int { + if pos == 0 { + return hashtag.RandomSlot() + } + firstKey := cmd.stringArg(pos) + return hashtag.Slot(firstKey) +} + +func (c *ClusterClient) cmdSlot(cmd Cmder) int { + args := cmd.Args() + if args[0] == "cluster" && args[1] == "getkeysinslot" { + return args[2].(int) + } + + cmdInfo := c.cmdInfo(cmd.Name()) + return cmdSlot(cmd, cmdFirstKeyPos(cmd, cmdInfo)) +} + +func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode, error) { + state, err := c.state.Get() + if err != nil { + return 0, nil, err + } + + cmdInfo := c.cmdInfo(cmd.Name()) + slot := c.cmdSlot(cmd) + + if c.opt.ReadOnly && cmdInfo != nil && cmdInfo.ReadOnly { + if c.opt.RouteByLatency { + node, err := state.slotClosestNode(slot) + return slot, node, err + } + + if c.opt.RouteRandomly { + node := state.slotRandomNode(slot) + return slot, node, nil + } + + node, err := state.slotSlaveNode(slot) + return slot, node, err + } + + node, err := state.slotMasterNode(slot) + return slot, node, err +} + +func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) { + state, err := c.state.Get() + if err != nil { + return nil, err + } + + nodes := state.slotNodes(slot) + if len(nodes) > 0 { + return nodes[0], nil + } + return c.nodes.Random() +} + func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { for _, n := range nodes { if n == node { diff --git a/command.go b/command.go index 6f058959..177459ed 100644 --- a/command.go +++ b/command.go @@ -100,8 +100,14 @@ type baseCmd struct { var _ Cmder = (*Cmd)(nil) -func (cmd *baseCmd) Err() error { - return cmd.err +func (cmd *baseCmd) Name() string { + if len(cmd._args) > 0 { + // Cmd name must be lower cased. + s := internal.ToLower(cmd.stringArg(0)) + cmd._args[0] = s + return s + } + return "" } func (cmd *baseCmd) Args() []interface{} { @@ -116,14 +122,8 @@ func (cmd *baseCmd) stringArg(pos int) string { return s } -func (cmd *baseCmd) Name() string { - if len(cmd._args) > 0 { - // Cmd name must be lower cased. - s := internal.ToLower(cmd.stringArg(0)) - cmd._args[0] = s - return s - } - return "" +func (cmd *baseCmd) Err() error { + return cmd.err } func (cmd *baseCmd) readTimeout() *time.Duration { diff --git a/example_instrumentation_test.go b/example_instrumentation_test.go index f9444a6e..65ef59cc 100644 --- a/example_instrumentation_test.go +++ b/example_instrumentation_test.go @@ -1,44 +1,54 @@ package redis_test import ( + "context" "fmt" "github.com/go-redis/redis" ) +type redisHook struct{} + +var _ redis.Hook = redisHook{} + +func (redisHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + fmt.Printf("starting processing: <%s>\n", cmd) + return ctx, nil +} + +func (redisHook) AfterProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + fmt.Printf("finished processing: <%s>\n", cmd) + return ctx, nil +} + +func (redisHook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + fmt.Printf("pipeline starting processing: %v\n", cmds) + return ctx, nil +} + +func (redisHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + fmt.Printf("pipeline finished processing: %v\n", cmds) + return ctx, nil +} + func Example_instrumentation() { - redisdb := redis.NewClient(&redis.Options{ + rdb := redis.NewClient(&redis.Options{ Addr: ":6379", }) - redisdb.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error { - return func(cmd redis.Cmder) error { - fmt.Printf("starting processing: <%s>\n", cmd) - err := old(cmd) - fmt.Printf("finished processing: <%s>\n", cmd) - return err - } - }) + rdb.AddHook(redisHook{}) - redisdb.Ping() + rdb.Ping() // Output: starting processing: // finished processing: } func ExamplePipeline_instrumentation() { - redisdb := redis.NewClient(&redis.Options{ + rdb := redis.NewClient(&redis.Options{ Addr: ":6379", }) + rdb.AddHook(redisHook{}) - redisdb.WrapProcessPipeline(func(old func([]redis.Cmder) error) func([]redis.Cmder) error { - return func(cmds []redis.Cmder) error { - fmt.Printf("pipeline starting processing: %v\n", cmds) - err := old(cmds) - fmt.Printf("pipeline finished processing: %v\n", cmds) - return err - } - }) - - redisdb.Pipelined(func(pipe redis.Pipeliner) error { + rdb.Pipelined(func(pipe redis.Pipeliner) error { pipe.Ping() pipe.Ping() return nil diff --git a/redis.go b/redis.go index a7673760..61bbe966 100644 --- a/redis.go +++ b/redis.go @@ -23,24 +23,114 @@ func SetLogger(logger *log.Logger) { internal.Logger = logger } +//------------------------------------------------------------------------------ + +type Hook interface { + BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) + AfterProcess(ctx context.Context, cmd Cmder) (context.Context, error) + + BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) + AfterProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) +} + +type hooks struct { + hooks []Hook +} + +func (hs *hooks) AddHook(hook Hook) { + hs.hooks = append(hs.hooks, hook) +} + +func (hs *hooks) copy() { + hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)] +} + +func (hs hooks) process(ctx context.Context, cmd Cmder, fn func(Cmder) error) error { + ctx, err := hs.beforeProcess(ctx, cmd) + if err != nil { + return err + } + + cmdErr := fn(cmd) + + _, err = hs.afterProcess(ctx, cmd) + if err != nil { + return err + } + + return cmdErr +} + +func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) { + for _, h := range hs.hooks { + var err error + ctx, err = h.BeforeProcess(ctx, cmd) + if err != nil { + return nil, err + } + } + return ctx, nil +} + +func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) (context.Context, error) { + for _, h := range hs.hooks { + var err error + ctx, err = h.AfterProcess(ctx, cmd) + if err != nil { + return nil, err + } + } + return ctx, nil +} + +func (hs hooks) processPipeline(ctx context.Context, cmds []Cmder, fn func([]Cmder) error) error { + ctx, err := hs.beforeProcessPipeline(ctx, cmds) + if err != nil { + return err + } + + cmdsErr := fn(cmds) + + _, err = hs.afterProcessPipeline(ctx, cmds) + if err != nil { + return err + } + + return cmdsErr +} + +func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) { + for _, h := range hs.hooks { + var err error + ctx, err = h.BeforeProcessPipeline(ctx, cmds) + if err != nil { + return nil, err + } + } + return ctx, nil +} + +func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) { + for _, h := range hs.hooks { + var err error + ctx, err = h.AfterProcessPipeline(ctx, cmds) + if err != nil { + return nil, err + } + } + return ctx, nil +} + +//------------------------------------------------------------------------------ + type baseClient struct { opt *Options connPool pool.Pooler limiter Limiter - process func(Cmder) error - processPipeline func([]Cmder) error - processTxPipeline func([]Cmder) error - onClose func() error // hook called when client is closed } -func (c *baseClient) init() { - c.process = c.defaultProcess - c.processPipeline = c.defaultProcessPipeline - c.processTxPipeline = c.defaultProcessTxPipeline -} - func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) } @@ -159,22 +249,11 @@ func (c *baseClient) initConn(cn *pool.Conn) error { // Do creates a Cmd from the args and processes the cmd. func (c *baseClient) Do(args ...interface{}) *Cmd { cmd := NewCmd(args...) - _ = c.Process(cmd) + _ = c.process(cmd) return cmd } -// WrapProcess wraps function that processes Redis commands. -func (c *baseClient) WrapProcess( - fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error, -) { - c.process = fn(c.process) -} - -func (c *baseClient) Process(cmd Cmder) error { - return c.process(cmd) -} - -func (c *baseClient) defaultProcess(cmd Cmder) error { +func (c *baseClient) process(cmd Cmder) error { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { time.Sleep(c.retryBackoff(attempt)) @@ -249,18 +328,11 @@ func (c *baseClient) getAddr() string { return c.opt.Addr } -func (c *baseClient) WrapProcessPipeline( - fn func(oldProcess func([]Cmder) error) func([]Cmder) error, -) { - c.processPipeline = fn(c.processPipeline) - c.processTxPipeline = fn(c.processTxPipeline) -} - -func (c *baseClient) defaultProcessPipeline(cmds []Cmder) error { +func (c *baseClient) processPipeline(cmds []Cmder) error { return c.generalProcessPipeline(cmds, c.pipelineProcessCmds) } -func (c *baseClient) defaultProcessTxPipeline(cmds []Cmder) error { +func (c *baseClient) processTxPipeline(cmds []Cmder) error { return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds) } @@ -388,6 +460,7 @@ type Client struct { cmdable ctx context.Context + hooks } // NewClient returns a client to the Redis Server specified by Options. @@ -400,7 +473,6 @@ func NewClient(opt *Options) *Client { connPool: newConnPool(opt), }, } - c.baseClient.init() c.init() return &c @@ -427,9 +499,22 @@ func (c *Client) WithContext(ctx context.Context) *Client { } func (c *Client) clone() *Client { - cp := *c - cp.init() - return &cp + clone := *c + clone.hooks.copy() + clone.init() + return &clone +} + +func (c *Client) Process(cmd Cmder) error { + return c.hooks.process(c.ctx, cmd, c.baseClient.process) +} + +func (c *Client) processPipeline(cmds []Cmder) error { + return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processPipeline) +} + +func (c *Client) processTxPipeline(cmds []Cmder) error { + return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processTxPipeline) } // Options returns read-only Options that were used to create the client. @@ -547,11 +632,14 @@ func newConn(opt *Options, cn *pool.Conn) *Conn { connPool: pool.NewSingleConnPool(cn), }, } - c.baseClient.init() c.statefulCmdable.setProcessor(c.Process) return &c } +func (c *Conn) Process(cmd Cmder) error { + return c.baseClient.process(cmd) +} + func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { return c.Pipeline().Pipelined(fn) } diff --git a/redis_test.go b/redis_test.go index f46728f9..d600894b 100644 --- a/redis_test.go +++ b/redis_test.go @@ -224,43 +224,6 @@ var _ = Describe("Client", func() { Expect(err).NotTo(HaveOccurred()) Expect(got).To(Equal(bigVal)) }) - - It("should call WrapProcess", func() { - var fnCalled bool - - client.WrapProcess(func(old func(redis.Cmder) error) func(redis.Cmder) error { - return func(cmd redis.Cmder) error { - fnCalled = true - return old(cmd) - } - }) - - Expect(client.Ping().Err()).NotTo(HaveOccurred()) - Expect(fnCalled).To(BeTrue()) - }) - - It("should call WrapProcess after WithContext", func() { - var fn1Called, fn2Called bool - - client.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error { - return func(cmd redis.Cmder) error { - fn1Called = true - return old(cmd) - } - }) - - client2 := client.WithContext(client.Context()) - client2.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error { - return func(cmd redis.Cmder) error { - fn2Called = true - return old(cmd) - } - }) - - Expect(client2.Ping().Err()).NotTo(HaveOccurred()) - Expect(fn2Called).To(BeTrue()) - Expect(fn1Called).To(BeTrue()) - }) }) var _ = Describe("Client timeout", func() { diff --git a/ring.go b/ring.go index 5956b71a..a98bee13 100644 --- a/ring.go +++ b/ring.go @@ -340,14 +340,12 @@ func (c *ringShards) Close() error { type Ring struct { cmdable - ctx context.Context - opt *RingOptions shards *ringShards cmdsInfoCache *cmdsInfoCache - process func(Cmder) error - processPipeline func([]Cmder) error + ctx context.Context + hooks } func NewRing(opt *RingOptions) *Ring { @@ -358,10 +356,6 @@ func NewRing(opt *RingOptions) *Ring { shards: newRingShards(opt), } ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) - - ring.process = ring.defaultProcess - ring.processPipeline = ring.defaultProcessPipeline - ring.init() for name, addr := range opt.Addrs { @@ -536,17 +530,11 @@ func (c *Ring) Do(args ...interface{}) *Cmd { return cmd } -func (c *Ring) WrapProcess( - fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error, -) { - c.process = fn(c.process) -} - func (c *Ring) Process(cmd Cmder) error { - return c.process(cmd) + return c.hooks.process(c.ctx, cmd, c.process) } -func (c *Ring) defaultProcess(cmd Cmder) error { +func (c *Ring) process(cmd Cmder) error { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { time.Sleep(c.retryBackoff(attempt)) @@ -581,13 +569,11 @@ func (c *Ring) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { return c.Pipeline().Pipelined(fn) } -func (c *Ring) WrapProcessPipeline( - fn func(oldProcess func([]Cmder) error) func([]Cmder) error, -) { - c.processPipeline = fn(c.processPipeline) +func (c *Ring) processPipeline(cmds []Cmder) error { + return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline) } -func (c *Ring) defaultProcessPipeline(cmds []Cmder) error { +func (c *Ring) _processPipeline(cmds []Cmder) error { cmdsMap := make(map[string][]Cmder) for _, cmd := range cmds { cmdInfo := c.cmdInfo(cmd.Name()) diff --git a/ring_test.go b/ring_test.go index d498e034..4ff08986 100644 --- a/ring_test.go +++ b/ring_test.go @@ -1,7 +1,6 @@ package redis_test import ( - "context" "crypto/rand" "fmt" "net" @@ -105,27 +104,6 @@ var _ = Describe("Redis Ring", func() { Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=100")) }) - It("propagates process for WithContext", func() { - var fromWrap []string - wrapper := func(oldProcess func(cmd redis.Cmder) error) func(cmd redis.Cmder) error { - return func(cmd redis.Cmder) error { - fromWrap = append(fromWrap, cmd.Name()) - - return oldProcess(cmd) - } - } - - ctx := context.Background() - ring = ring.WithContext(ctx) - ring.WrapProcess(wrapper) - - ring.Ping() - Expect(fromWrap).To(Equal([]string{"ping"})) - - ring.Ping() - Expect(fromWrap).To(Equal([]string{"ping", "ping"})) - }) - Describe("pipeline", func() { It("distributes keys", func() { pipe := ring.Pipeline() diff --git a/sentinel.go b/sentinel.go index 3043fbda..3f605c72 100644 --- a/sentinel.go +++ b/sentinel.go @@ -1,6 +1,7 @@ package redis import ( + "context" "crypto/tls" "errors" "net" @@ -93,7 +94,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { onClose: failover.Close, }, } - c.baseClient.init() c.cmdable.setProcessor(c.Process) return &c @@ -103,6 +103,8 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { type SentinelClient struct { baseClient + + ctx context.Context } func NewSentinelClient(opt *Options) *SentinelClient { @@ -113,10 +115,34 @@ func NewSentinelClient(opt *Options) *SentinelClient { connPool: newConnPool(opt), }, } - c.baseClient.init() return c } +func (c *SentinelClient) Context() context.Context { + if c.ctx != nil { + return c.ctx + } + return context.Background() +} + +func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient { + if ctx == nil { + panic("nil context") + } + c2 := c.clone() + c2.ctx = ctx + return c2 +} + +func (c *SentinelClient) clone() *SentinelClient { + clone := *c + return &clone +} + +func (c *SentinelClient) Process(cmd Cmder) error { + return c.baseClient.process(cmd) +} + func (c *SentinelClient) pubSub() *PubSub { pubsub := &PubSub{ opt: c.opt, diff --git a/tx.go b/tx.go index fb3e6331..afeed147 100644 --- a/tx.go +++ b/tx.go @@ -1,6 +1,8 @@ package redis import ( + "context" + "github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/proto" ) @@ -15,6 +17,8 @@ const TxFailedErr = proto.RedisError("redis: transaction failed") type Tx struct { statefulCmdable baseClient + + ctx context.Context } func (c *Client) newTx() *Tx { @@ -23,12 +27,42 @@ func (c *Client) newTx() *Tx { opt: c.opt, connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true), }, + ctx: c.ctx, } - tx.baseClient.init() - tx.statefulCmdable.setProcessor(tx.Process) + tx.init() return &tx } +func (c *Tx) init() { + c.statefulCmdable.setProcessor(c.Process) +} + +func (c *Tx) Context() context.Context { + if c.ctx != nil { + return c.ctx + } + return context.Background() +} + +func (c *Tx) WithContext(ctx context.Context) *Tx { + if ctx == nil { + panic("nil context") + } + c2 := c.clone() + c2.ctx = ctx + return c2 +} + +func (c *Tx) clone() *Tx { + clone := *c + clone.init() + return &clone +} + +func (c *Tx) Process(cmd Cmder) error { + return c.baseClient.process(cmd) +} + // Watch prepares a transaction and marks the keys to be watched // for conditional execution if there are any keys. // diff --git a/universal.go b/universal.go index 03bfa0fa..76bc1b2f 100644 --- a/universal.go +++ b/universal.go @@ -1,6 +1,7 @@ package redis import ( + "context" "crypto/tls" "time" ) @@ -147,15 +148,14 @@ func (o *UniversalOptions) simple() *Options { // -------------------------------------------------------------------- // UniversalClient is an abstract client which - based on the provided options - -// can connect to either clusters, or sentinel-backed failover instances or simple -// single-instance servers. This can be useful for testing cluster-specific -// applications locally. +// can connect to either clusters, or sentinel-backed failover instances +// or simple single-instance servers. This can be useful for testing +// cluster-specific applications locally. type UniversalClient interface { Cmdable + Context() context.Context Watch(fn func(*Tx) error, keys ...string) error Process(cmd Cmder) error - WrapProcess(fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error) - WrapProcessPipeline(fn func(oldProcess func([]Cmder) error) func([]Cmder) error) Subscribe(channels ...string) *PubSub PSubscribe(channels ...string) *PubSub Close() error