diff --git a/cluster.go b/cluster.go index 872baad4..d3df823b 100644 --- a/cluster.go +++ b/cluster.go @@ -855,9 +855,12 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.cmdable = c.Process - c.hooks.setProcess(c.process) - c.hooks.setProcessPipeline(c.processPipeline) - c.hooks.setProcessTxPipeline(c.processTxPipeline) + c.hooks.setDefaultHook(defaultHook{ + dial: nil, + process: c.process, + pipeline: c.processPipeline, + txPipeline: c.processTxPipeline, + }) return c } diff --git a/redis.go b/redis.go index 9fe0cd1a..b7f4ff36 100644 --- a/redis.go +++ b/redis.go @@ -40,8 +40,39 @@ type ( ProcessPipelineHook func(ctx context.Context, cmds []Cmder) error ) +var ( + nonDialHook = func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, nil } + nonProcessHook = func(ctx context.Context, cmd Cmder) error { return nil } + nonProcessPipelineHook = func(ctx context.Context, cmds []Cmder) error { return nil } + nonTxProcessPipelineHook = func(ctx context.Context, cmds []Cmder) error { return nil } +) + +type defaultHook struct { + dial DialHook + process ProcessHook + pipeline ProcessPipelineHook + txPipeline ProcessPipelineHook +} + +func (h *defaultHook) init() { + if h.dial == nil { + h.dial = nonDialHook + } + if h.process == nil { + h.process = nonProcessHook + } + if h.pipeline == nil { + h.pipeline = nonProcessPipelineHook + } + if h.txPipeline == nil { + h.txPipeline = nonTxProcessPipelineHook + } +} + type hooks struct { - slice []Hook + slice []Hook + defaultHook defaultHook + dialHook DialHook processHook ProcessHook processPipelineHook ProcessPipelineHook @@ -87,10 +118,31 @@ type hooks struct { // if "next(ctx, cmd)" is not executed in hook-1, the redis command will not be executed. func (hs *hooks) AddHook(hook Hook) { hs.slice = append(hs.slice, hook) - hs.dialHook = hook.DialHook(hs.dialHook) - hs.processHook = hook.ProcessHook(hs.processHook) - hs.processPipelineHook = hook.ProcessPipelineHook(hs.processPipelineHook) - hs.processTxPipelineHook = hook.ProcessPipelineHook(hs.processTxPipelineHook) + hs.chain() +} + +func (hs *hooks) chain() { + hs.defaultHook.init() + + hs.dialHook = hs.defaultHook.dial + hs.processHook = hs.defaultHook.process + hs.processPipelineHook = hs.defaultHook.pipeline + hs.processTxPipelineHook = hs.defaultHook.txPipeline + + for i := len(hs.slice) - 1; i >= 0; i-- { + if wrapped := hs.slice[i].DialHook(hs.dialHook); wrapped != nil { + hs.dialHook = wrapped + } + if wrapped := hs.slice[i].ProcessHook(hs.processHook); wrapped != nil { + hs.processHook = wrapped + } + if wrapped := hs.slice[i].ProcessPipelineHook(hs.processPipelineHook); wrapped != nil { + hs.processPipelineHook = wrapped + } + if wrapped := hs.slice[i].ProcessPipelineHook(hs.processTxPipelineHook); wrapped != nil { + hs.processTxPipelineHook = wrapped + } + } } func (hs *hooks) clone() hooks { @@ -100,40 +152,9 @@ func (hs *hooks) clone() hooks { return clone } -func (hs *hooks) setDial(dial DialHook) { - hs.dialHook = dial - for _, h := range hs.slice { - if wrapped := h.DialHook(hs.dialHook); wrapped != nil { - hs.dialHook = wrapped - } - } -} - -func (hs *hooks) setProcess(process ProcessHook) { - hs.processHook = process - for _, h := range hs.slice { - if wrapped := h.ProcessHook(hs.processHook); wrapped != nil { - hs.processHook = wrapped - } - } -} - -func (hs *hooks) setProcessPipeline(processPipeline ProcessPipelineHook) { - hs.processPipelineHook = processPipeline - for _, h := range hs.slice { - if wrapped := h.ProcessPipelineHook(hs.processPipelineHook); wrapped != nil { - hs.processPipelineHook = wrapped - } - } -} - -func (hs *hooks) setProcessTxPipeline(processTxPipeline ProcessPipelineHook) { - hs.processTxPipelineHook = processTxPipeline - for _, h := range hs.slice { - if wrapped := h.ProcessPipelineHook(hs.processTxPipelineHook); wrapped != nil { - hs.processTxPipelineHook = wrapped - } - } +func (hs *hooks) setDefaultHook(d defaultHook) { + hs.defaultHook = d + hs.chain() } func (hs *hooks) withProcessHook(ctx context.Context, cmd Cmder, hook ProcessHook) error { @@ -595,10 +616,12 @@ func NewClient(opt *Options) *Client { func (c *Client) init() { c.cmdable = c.Process - c.hooks.setDial(c.baseClient.dial) - c.hooks.setProcess(c.baseClient.process) - c.hooks.setProcessPipeline(c.baseClient.processPipeline) - c.hooks.setProcessTxPipeline(c.baseClient.processTxPipeline) + c.hooks.setDefaultHook(defaultHook{ + dial: c.baseClient.dial, + process: c.baseClient.process, + pipeline: c.baseClient.processPipeline, + txPipeline: c.baseClient.processTxPipeline, + }) } func (c *Client) WithTimeout(timeout time.Duration) *Client { @@ -755,11 +778,12 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn { c.cmdable = c.Process c.statefulCmdable = c.Process - - c.hooks.setDial(c.baseClient.dial) - c.hooks.setProcess(c.baseClient.process) - c.hooks.setProcessPipeline(c.baseClient.processPipeline) - c.hooks.setProcessTxPipeline(c.baseClient.processTxPipeline) + c.hooks.setDefaultHook(defaultHook{ + dial: c.baseClient.dial, + process: c.baseClient.process, + pipeline: c.baseClient.processPipeline, + txPipeline: c.baseClient.processTxPipeline, + }) return &c } diff --git a/ring.go b/ring.go index bc299da0..9c29fb5d 100644 --- a/ring.go +++ b/ring.go @@ -509,12 +509,14 @@ func NewRing(opt *RingOptions) *Ring { ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) ring.cmdable = ring.Process - ring.hooks.setProcess(ring.process) - ring.hooks.setProcessPipeline(func(ctx context.Context, cmds []Cmder) error { - return ring.generalProcessPipeline(ctx, cmds, false) - }) - ring.hooks.setProcessTxPipeline(func(ctx context.Context, cmds []Cmder) error { - return ring.generalProcessPipeline(ctx, cmds, true) + ring.hooks.setDefaultHook(defaultHook{ + process: ring.process, + pipeline: func(ctx context.Context, cmds []Cmder) error { + return ring.generalProcessPipeline(ctx, cmds, false) + }, + txPipeline: func(ctx context.Context, cmds []Cmder) error { + return ring.generalProcessPipeline(ctx, cmds, true) + }, }) go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency) diff --git a/sentinel.go b/sentinel.go index 1feeb039..533d97d2 100644 --- a/sentinel.go +++ b/sentinel.go @@ -278,8 +278,10 @@ func NewSentinelClient(opt *Options) *SentinelClient { }, } - c.hooks.setDial(c.baseClient.dial) - c.hooks.setProcess(c.baseClient.process) + c.hooks.setDefaultHook(defaultHook{ + dial: c.baseClient.dial, + process: c.baseClient.process, + }) c.connPool = newConnPool(opt, c.hooks.dial) return c diff --git a/tx.go b/tx.go index e720e687..f686877f 100644 --- a/tx.go +++ b/tx.go @@ -38,10 +38,12 @@ func (c *Tx) init() { c.cmdable = c.Process c.statefulCmdable = c.Process - c.hooks.setDial(c.baseClient.dial) - c.hooks.setProcess(c.baseClient.process) - c.hooks.setProcessPipeline(c.baseClient.processPipeline) - c.hooks.setProcessTxPipeline(c.baseClient.processTxPipeline) + c.hooks.setDefaultHook(defaultHook{ + dial: c.baseClient.dial, + process: c.baseClient.process, + pipeline: c.baseClient.processPipeline, + txPipeline: c.baseClient.processTxPipeline, + }) } func (c *Tx) Process(ctx context.Context, cmd Cmder) error {