diff --git a/cluster.go b/cluster.go index d3df823b..9077e8a5 100644 --- a/cluster.go +++ b/cluster.go @@ -838,7 +838,7 @@ type ClusterClient struct { state *clusterStateHolder cmdsInfoCache *cmdsInfoCache cmdable - hooks + hooksMixin } // NewClusterClient returns a Redis Cluster client as described in @@ -855,7 +855,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.cmdable = c.Process - c.hooks.setDefaultHook(defaultHook{ + c.initHooks(hooks{ dial: nil, process: c.process, pipeline: c.processPipeline, @@ -892,7 +892,7 @@ func (c *ClusterClient) Do(ctx context.Context, args ...interface{}) *Cmd { } func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error { - err := c.hooks.process(ctx, cmd) + err := c.processHook(ctx, cmd) cmd.SetErr(err) return err } @@ -1190,7 +1190,7 @@ func (c *ClusterClient) loadState(ctx context.Context) (*clusterState, error) { func (c *ClusterClient) Pipeline() Pipeliner { pipe := Pipeline{ - exec: pipelineExecer(c.hooks.processPipeline), + exec: pipelineExecer(c.processPipelineHook), } pipe.init() return &pipe @@ -1279,7 +1279,7 @@ func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool func (c *ClusterClient) processPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { - _ = node.Client.hooks.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { cn, err := node.Client.getConn(ctx) if err != nil { _ = c.mapCmdsByNode(ctx, failedCmds, cmds) @@ -1383,7 +1383,7 @@ func (c *ClusterClient) TxPipeline() Pipeliner { pipe := Pipeline{ exec: func(ctx context.Context, cmds []Cmder) error { cmds = wrapMultiExec(ctx, cmds) - return c.hooks.processTxPipeline(ctx, cmds) + return c.processTxPipelineHook(ctx, cmds) }, } pipe.init() @@ -1456,7 +1456,7 @@ func (c *ClusterClient) processTxPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { cmds = wrapMultiExec(ctx, cmds) - _ = node.Client.hooks.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { cn, err := node.Client.getConn(ctx) if err != nil { _ = c.mapCmdsByNode(ctx, failedCmds, cmds) diff --git a/cluster_commands.go b/cluster_commands.go index fc0a9cd4..b13f8e7e 100644 --- a/cluster_commands.go +++ b/cluster_commands.go @@ -8,7 +8,7 @@ import ( func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd { cmd := NewIntCmd(ctx, "dbsize") - _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { + _ = c.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { var size int64 err := c.ForEachMaster(ctx, func(ctx context.Context, master *Client) error { n, err := master.DBSize(ctx).Result() @@ -30,8 +30,8 @@ func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd { func (c *ClusterClient) ScriptLoad(ctx context.Context, script string) *StringCmd { cmd := NewStringCmd(ctx, "script", "load", script) - _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { - mu := &sync.Mutex{} + _ = c.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { + var mu sync.Mutex err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { val, err := shard.ScriptLoad(ctx, script).Result() if err != nil { @@ -56,7 +56,7 @@ func (c *ClusterClient) ScriptLoad(ctx context.Context, script string) *StringCm func (c *ClusterClient) ScriptFlush(ctx context.Context) *StatusCmd { cmd := NewStatusCmd(ctx, "script", "flush") - _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { + _ = c.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { return shard.ScriptFlush(ctx).Err() }) @@ -82,7 +82,7 @@ func (c *ClusterClient) ScriptExists(ctx context.Context, hashes ...string) *Boo result[i] = true } - _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { + _ = c.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { var mu sync.Mutex err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { val, err := shard.ScriptExists(ctx, hashes...).Result() diff --git a/redis.go b/redis.go index ff98611b..1fe48a90 100644 --- a/redis.go +++ b/redis.go @@ -40,45 +40,39 @@ type ( ProcessPipelineHook func(ctx context.Context, cmds []Cmder) error ) -var ( - nonDialHook = func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, nil } - nonProcessHook = func(ctx context.Context, cmd Cmder) error { return nil } - nonProcessPipelineHook = func(ctx context.Context, cmds []Cmder) error { return nil } - nonTxProcessPipelineHook = func(ctx context.Context, cmds []Cmder) error { return nil } -) +type hooksMixin struct { + slice []Hook + initial hooks + current hooks +} -type defaultHook struct { +func (hs *hooksMixin) initHooks(hooks hooks) { + hs.initial = hooks + hs.chain() +} + +type hooks struct { dial DialHook process ProcessHook pipeline ProcessPipelineHook txPipeline ProcessPipelineHook } -func (h *defaultHook) init() { +func (h *hooks) setDefaults() { if h.dial == nil { - h.dial = nonDialHook + h.dial = func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, nil } } if h.process == nil { - h.process = nonProcessHook + h.process = func(ctx context.Context, cmd Cmder) error { return nil } } if h.pipeline == nil { - h.pipeline = nonProcessPipelineHook + h.pipeline = func(ctx context.Context, cmds []Cmder) error { return nil } } if h.txPipeline == nil { - h.txPipeline = nonTxProcessPipelineHook + h.txPipeline = func(ctx context.Context, cmds []Cmder) error { return nil } } } -type hooks struct { - slice []Hook - defaultHook defaultHook - - dialHook DialHook - processHook ProcessHook - processPipelineHook ProcessPipelineHook - processTxPipelineHook ProcessPipelineHook -} - // AddHook is to add a hook to the queue. // Hook is a function executed during network connection, command execution, and pipeline, // it is a first-in-first-out stack queue (FIFO). @@ -115,48 +109,43 @@ type hooks struct { // // Please note: "next(ctx, cmd)" is very important, it will call the next hook, // if "next(ctx, cmd)" is not executed, the redis command will not be executed. -func (hs *hooks) AddHook(hook Hook) { +func (hs *hooksMixin) AddHook(hook Hook) { hs.slice = append(hs.slice, hook) hs.chain() } -func (hs *hooks) chain() { - hs.defaultHook.init() +func (hs *hooksMixin) chain() { + hs.initial.setDefaults() - hs.dialHook = hs.defaultHook.dial - hs.processHook = hs.defaultHook.process - hs.processPipelineHook = hs.defaultHook.pipeline - hs.processTxPipelineHook = hs.defaultHook.txPipeline + hs.current.dial = hs.initial.dial + hs.current.process = hs.initial.process + hs.current.pipeline = hs.initial.pipeline + hs.current.txPipeline = hs.initial.txPipeline for i := len(hs.slice) - 1; i >= 0; i-- { - if wrapped := hs.slice[i].DialHook(hs.dialHook); wrapped != nil { - hs.dialHook = wrapped + if wrapped := hs.slice[i].DialHook(hs.current.dial); wrapped != nil { + hs.current.dial = wrapped } - if wrapped := hs.slice[i].ProcessHook(hs.processHook); wrapped != nil { - hs.processHook = wrapped + if wrapped := hs.slice[i].ProcessHook(hs.current.process); wrapped != nil { + hs.current.process = wrapped } - if wrapped := hs.slice[i].ProcessPipelineHook(hs.processPipelineHook); wrapped != nil { - hs.processPipelineHook = wrapped + if wrapped := hs.slice[i].ProcessPipelineHook(hs.current.pipeline); wrapped != nil { + hs.current.pipeline = wrapped } - if wrapped := hs.slice[i].ProcessPipelineHook(hs.processTxPipelineHook); wrapped != nil { - hs.processTxPipelineHook = wrapped + if wrapped := hs.slice[i].ProcessPipelineHook(hs.current.txPipeline); wrapped != nil { + hs.current.txPipeline = wrapped } } } -func (hs *hooks) clone() hooks { +func (hs *hooksMixin) clone() hooksMixin { clone := *hs l := len(clone.slice) clone.slice = clone.slice[:l:l] return clone } -func (hs *hooks) setDefaultHook(d defaultHook) { - hs.defaultHook = d - hs.chain() -} - -func (hs *hooks) withProcessHook(ctx context.Context, cmd Cmder, hook ProcessHook) error { +func (hs *hooksMixin) withProcessHook(ctx context.Context, cmd Cmder, hook ProcessHook) error { for i := len(hs.slice) - 1; i >= 0; i-- { if wrapped := hs.slice[i].ProcessHook(hook); wrapped != nil { hook = wrapped @@ -165,7 +154,7 @@ func (hs *hooks) withProcessHook(ctx context.Context, cmd Cmder, hook ProcessHoo return hook(ctx, cmd) } -func (hs *hooks) withProcessPipelineHook( +func (hs *hooksMixin) withProcessPipelineHook( ctx context.Context, cmds []Cmder, hook ProcessPipelineHook, ) error { for i := len(hs.slice) - 1; i >= 0; i-- { @@ -176,20 +165,20 @@ func (hs *hooks) withProcessPipelineHook( return hook(ctx, cmds) } -func (hs *hooks) dial(ctx context.Context, network, addr string) (net.Conn, error) { - return hs.dialHook(ctx, network, addr) +func (hs *hooksMixin) dialHook(ctx context.Context, network, addr string) (net.Conn, error) { + return hs.current.dial(ctx, network, addr) } -func (hs *hooks) process(ctx context.Context, cmd Cmder) error { - return hs.processHook(ctx, cmd) +func (hs *hooksMixin) processHook(ctx context.Context, cmd Cmder) error { + return hs.current.process(ctx, cmd) } -func (hs *hooks) processPipeline(ctx context.Context, cmds []Cmder) error { - return hs.processPipelineHook(ctx, cmds) +func (hs *hooksMixin) processPipelineHook(ctx context.Context, cmds []Cmder) error { + return hs.current.pipeline(ctx, cmds) } -func (hs *hooks) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return hs.processTxPipelineHook(ctx, cmds) +func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) error { + return hs.current.txPipeline(ctx, cmds) } //------------------------------------------------------------------------------ @@ -595,7 +584,7 @@ func (c *baseClient) context(ctx context.Context) context.Context { type Client struct { *baseClient cmdable - hooks + hooksMixin } // NewClient returns a client to the Redis Server specified by Options. @@ -608,14 +597,14 @@ func NewClient(opt *Options) *Client { }, } c.init() - c.connPool = newConnPool(opt, c.hooks.dial) + c.connPool = newConnPool(opt, c.dialHook) return &c } func (c *Client) init() { c.cmdable = c.Process - c.hooks.setDefaultHook(defaultHook{ + c.initHooks(hooks{ dial: c.baseClient.dial, process: c.baseClient.process, pipeline: c.baseClient.processPipeline, @@ -642,7 +631,7 @@ func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd { } func (c *Client) Process(ctx context.Context, cmd Cmder) error { - err := c.hooks.process(ctx, cmd) + err := c.processHook(ctx, cmd) cmd.SetErr(err) return err } @@ -666,7 +655,7 @@ func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmd func (c *Client) Pipeline() Pipeliner { pipe := Pipeline{ - exec: pipelineExecer(c.hooks.processPipeline), + exec: pipelineExecer(c.processPipelineHook), } pipe.init() return &pipe @@ -681,7 +670,7 @@ func (c *Client) TxPipeline() Pipeliner { pipe := Pipeline{ exec: func(ctx context.Context, cmds []Cmder) error { cmds = wrapMultiExec(ctx, cmds) - return c.hooks.processTxPipeline(ctx, cmds) + return c.processTxPipelineHook(ctx, cmds) }, } pipe.init() @@ -764,7 +753,7 @@ type Conn struct { baseClient cmdable statefulCmdable - hooks + hooksMixin } func newConn(opt *Options, connPool pool.Pooler) *Conn { @@ -777,7 +766,7 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn { c.cmdable = c.Process c.statefulCmdable = c.Process - c.hooks.setDefaultHook(defaultHook{ + c.initHooks(hooks{ dial: c.baseClient.dial, process: c.baseClient.process, pipeline: c.baseClient.processPipeline, @@ -788,7 +777,7 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn { } func (c *Conn) Process(ctx context.Context, cmd Cmder) error { - err := c.hooks.process(ctx, cmd) + err := c.processHook(ctx, cmd) cmd.SetErr(err) return err } @@ -799,7 +788,7 @@ func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder func (c *Conn) Pipeline() Pipeliner { pipe := Pipeline{ - exec: c.hooks.processPipeline, + exec: c.processPipelineHook, } pipe.init() return &pipe @@ -814,7 +803,7 @@ func (c *Conn) TxPipeline() Pipeliner { pipe := Pipeline{ exec: func(ctx context.Context, cmds []Cmder) error { cmds = wrapMultiExec(ctx, cmds) - return c.hooks.processTxPipeline(ctx, cmds) + return c.processTxPipelineHook(ctx, cmds) }, } pipe.init() diff --git a/ring.go b/ring.go index 9c29fb5d..9fd5f442 100644 --- a/ring.go +++ b/ring.go @@ -487,7 +487,7 @@ func (c *ringSharding) Close() error { // Otherwise you should use Redis Cluster. type Ring struct { cmdable - hooks + hooksMixin opt *RingOptions sharding *ringSharding @@ -509,7 +509,7 @@ func NewRing(opt *RingOptions) *Ring { ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) ring.cmdable = ring.Process - ring.hooks.setDefaultHook(defaultHook{ + ring.initHooks(hooks{ process: ring.process, pipeline: func(ctx context.Context, cmds []Cmder) error { return ring.generalProcessPipeline(ctx, cmds, false) @@ -536,7 +536,7 @@ func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd { } func (c *Ring) Process(ctx context.Context, cmd Cmder) error { - err := c.hooks.process(ctx, cmd) + err := c.processHook(ctx, cmd) cmd.SetErr(err) return err } @@ -719,7 +719,7 @@ func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder func (c *Ring) Pipeline() Pipeliner { pipe := Pipeline{ - exec: pipelineExecer(c.hooks.processPipeline), + exec: pipelineExecer(c.processPipelineHook), } pipe.init() return &pipe @@ -733,7 +733,7 @@ func (c *Ring) TxPipeline() Pipeliner { pipe := Pipeline{ exec: func(ctx context.Context, cmds []Cmder) error { cmds = wrapMultiExec(ctx, cmds) - return c.hooks.processTxPipeline(ctx, cmds) + return c.processTxPipelineHook(ctx, cmds) }, } pipe.init() @@ -774,9 +774,9 @@ func (c *Ring) generalProcessPipeline( if tx { cmds = wrapMultiExec(ctx, cmds) - _ = shard.Client.hooks.processTxPipeline(ctx, cmds) + _ = shard.Client.processTxPipelineHook(ctx, cmds) } else { - _ = shard.Client.hooks.processPipeline(ctx, cmds) + _ = shard.Client.processPipelineHook(ctx, cmds) } }(hash, cmds) } diff --git a/sentinel.go b/sentinel.go index 533d97d2..a3834896 100644 --- a/sentinel.go +++ b/sentinel.go @@ -214,7 +214,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { } rdb.init() - connPool = newConnPool(opt, rdb.hooks.dial) + connPool = newConnPool(opt, rdb.dialHook) rdb.connPool = connPool rdb.onClose = failover.Close @@ -267,7 +267,7 @@ func masterReplicaDialer( // SentinelClient is a client for a Redis Sentinel. type SentinelClient struct { *baseClient - hooks + hooksMixin } func NewSentinelClient(opt *Options) *SentinelClient { @@ -278,17 +278,17 @@ func NewSentinelClient(opt *Options) *SentinelClient { }, } - c.hooks.setDefaultHook(defaultHook{ + c.initHooks(hooks{ dial: c.baseClient.dial, process: c.baseClient.process, }) - c.connPool = newConnPool(opt, c.hooks.dial) + c.connPool = newConnPool(opt, c.dialHook) return c } func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error { - err := c.hooks.process(ctx, cmd) + err := c.processHook(ctx, cmd) cmd.SetErr(err) return err } diff --git a/tx.go b/tx.go index f686877f..a772e9e2 100644 --- a/tx.go +++ b/tx.go @@ -19,7 +19,7 @@ type Tx struct { baseClient cmdable statefulCmdable - hooks + hooksMixin } func (c *Client) newTx() *Tx { @@ -28,7 +28,7 @@ func (c *Client) newTx() *Tx { opt: c.opt, connPool: pool.NewStickyConnPool(c.connPool), }, - hooks: c.hooks.clone(), + hooksMixin: c.hooksMixin.clone(), } tx.init() return &tx @@ -38,7 +38,7 @@ func (c *Tx) init() { c.cmdable = c.Process c.statefulCmdable = c.Process - c.hooks.setDefaultHook(defaultHook{ + c.initHooks(hooks{ dial: c.baseClient.dial, process: c.baseClient.process, pipeline: c.baseClient.processPipeline, @@ -47,7 +47,7 @@ func (c *Tx) init() { } func (c *Tx) Process(ctx context.Context, cmd Cmder) error { - err := c.hooks.process(ctx, cmd) + err := c.processHook(ctx, cmd) cmd.SetErr(err) return err } @@ -102,7 +102,7 @@ func (c *Tx) Unwatch(ctx context.Context, keys ...string) *StatusCmd { func (c *Tx) Pipeline() Pipeliner { pipe := Pipeline{ exec: func(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds) + return c.processPipelineHook(ctx, cmds) }, } pipe.init() @@ -132,7 +132,7 @@ func (c *Tx) TxPipeline() Pipeliner { pipe := Pipeline{ exec: func(ctx context.Context, cmds []Cmder) error { cmds = wrapMultiExec(ctx, cmds) - return c.hooks.processTxPipeline(ctx, cmds) + return c.processTxPipelineHook(ctx, cmds) }, } pipe.init()