mirror of https://github.com/redis/go-redis.git synced 2025-07-28 06:42:00 +03:00

Add ctx as first arg

Author: Vladimir Mihailenco
Date:   2020-03-11 16:26:42 +02:00
parent 64bb0b7f3a
commit f5593121e0
36 changed files with 3200 additions and 2970 deletions
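
Before the hunks, a minimal usage sketch of the context-first API this commit introduces. It is illustrative only: the import path, cluster addresses, and key names are assumptions, and regular commands such as Get are assumed to follow the same ctx-first convention the commit applies across the client.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/go-redis/redis" // illustrative import path; use the module version that contains this change
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	rdb := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs: []string{":7000", ":7001", ":7002"}, // hypothetical cluster addresses
	})
	defer rdb.Close()

	// Every command now takes a context as its first argument.
	if err := rdb.Ping(ctx).Err(); err != nil {
		panic(err)
	}

	val, err := rdb.Get(ctx, "example:key").Result()
	switch {
	case err == redis.Nil:
		fmt.Println("example:key does not exist")
	case err != nil:
		panic(err)
	default:
		fmt.Println("example:key =", val)
	}
}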


@@ -191,7 +191,7 @@ func (n *clusterNode) updateLatency() {
var latency uint32
for i := 0; i < probes; i++ {
start := time.Now()
n.Client.Ping()
n.Client.Ping(context.TODO())
probe := uint32(time.Since(start) / time.Microsecond)
latency = (latency + probe) / 2
}
@@ -588,20 +588,20 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode {
//------------------------------------------------------------------------------
type clusterStateHolder struct {
load func() (*clusterState, error)
load func(ctx context.Context) (*clusterState, error)
state atomic.Value
reloading uint32 // atomic
}
func newClusterStateHolder(fn func() (*clusterState, error)) *clusterStateHolder {
func newClusterStateHolder(fn func(ctx context.Context) (*clusterState, error)) *clusterStateHolder {
return &clusterStateHolder{
load: fn,
}
}
func (c *clusterStateHolder) Reload() (*clusterState, error) {
state, err := c.load()
func (c *clusterStateHolder) Reload(ctx context.Context) (*clusterState, error) {
state, err := c.load(ctx)
if err != nil {
return nil, err
}
@@ -609,14 +609,14 @@ func (c *clusterStateHolder) Reload() (*clusterState, error) {
return state, nil
}
func (c *clusterStateHolder) LazyReload() {
func (c *clusterStateHolder) LazyReload(ctx context.Context) {
if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) {
return
}
go func() {
defer atomic.StoreUint32(&c.reloading, 0)
_, err := c.Reload()
_, err := c.Reload(ctx)
if err != nil {
return
}
@@ -624,24 +624,24 @@ func (c *clusterStateHolder) LazyReload() {
}()
}
func (c *clusterStateHolder) Get() (*clusterState, error) {
func (c *clusterStateHolder) Get(ctx context.Context) (*clusterState, error) {
v := c.state.Load()
if v != nil {
state := v.(*clusterState)
if time.Since(state.createdAt) > time.Minute {
c.LazyReload()
c.LazyReload(ctx)
}
return state, nil
}
return c.Reload()
return c.Reload(ctx)
}
func (c *clusterStateHolder) ReloadOrGet() (*clusterState, error) {
state, err := c.Reload()
func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, error) {
state, err := c.Reload(ctx)
if err == nil {
return state, nil
}
return c.Get()
return c.Get(ctx)
}
//------------------------------------------------------------------------------
@@ -708,8 +708,8 @@ func (c *ClusterClient) Options() *ClusterOptions {
// ReloadState reloads cluster state. If available it calls ClusterSlots func
// to get cluster slots information.
func (c *ClusterClient) ReloadState() error {
_, err := c.state.Reload()
func (c *ClusterClient) ReloadState(ctx context.Context) error {
_, err := c.state.Reload(ctx)
return err
}
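
ReloadState now requires a context. A tiny sketch against the signature above; the helper name is made up and `rdb` is a *redis.ClusterClient as in the first sketch.

// Force a cluster-state refresh, e.g. after a known topology change.
func refreshTopology(ctx context.Context, rdb *redis.ClusterClient) error {
	return rdb.ReloadState(ctx)
}
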
@@ -722,21 +722,13 @@ func (c *ClusterClient) Close() error {
}
// Do creates a Cmd from the args and processes the cmd.
func (c *ClusterClient) Do(args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...)
}
func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.ProcessContext(ctx, cmd)
func (c *ClusterClient) Do(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
func (c *ClusterClient) Process(cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *ClusterClient) ProcessContext(ctx context.Context, cmd Cmder) error {
func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.process)
}
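
A short sketch of the reworked Do/Process entry points, matching the signatures in the hunk above. The helper name and keys are made up; imports are as in the first sketch.

func doAndProcess(ctx context.Context, rdb *redis.ClusterClient) error {
	// Do builds a Cmd from the args and processes it with the supplied context.
	if err := rdb.Do(ctx, "set", "example:key", "value").Err(); err != nil {
		return err
	}

	// Process runs an explicit Cmder; NewCmd also takes ctx first now.
	cmd := redis.NewCmd(ctx, "get", "example:key")
	if err := rdb.Process(ctx, cmd); err != nil {
		return err
	}
	_, err := cmd.Result()
	return err
}
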
@@ -765,7 +757,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
if node == nil {
var err error
node, err = c.cmdNode(cmdInfo, slot)
node, err = c.cmdNode(ctx, cmdInfo, slot)
if err != nil {
return err
}
@@ -773,13 +765,13 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
if ask {
pipe := node.Client.Pipeline()
_ = pipe.Process(NewCmd("asking"))
_ = pipe.Process(cmd)
_, lastErr = pipe.ExecContext(ctx)
_ = pipe.Process(ctx, NewCmd(ctx, "asking"))
_ = pipe.Process(ctx, cmd)
_, lastErr = pipe.Exec(ctx)
_ = pipe.Close()
ask = false
} else {
lastErr = node.Client.ProcessContext(ctx, cmd)
lastErr = node.Client.Process(ctx, cmd)
}
// If there is no error - we are done.
@@ -787,7 +779,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
return nil
}
if lastErr != Nil {
c.state.LazyReload()
c.state.LazyReload(ctx)
}
if lastErr == pool.ErrClosed || isReadOnlyError(lastErr) {
node = nil
@@ -832,8 +824,11 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
// ForEachMaster concurrently calls the fn on each master node in the cluster.
// It returns the first error if any.
func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
state, err := c.state.ReloadOrGet()
func (c *ClusterClient) ForEachMaster(
ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
state, err := c.state.ReloadOrGet(ctx)
if err != nil {
return err
}
@@ -845,7 +840,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
wg.Add(1)
go func(node *clusterNode) {
defer wg.Done()
err := fn(node.Client)
err := fn(ctx, node.Client)
if err != nil {
select {
case errCh <- err:
@@ -867,8 +862,11 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
// ForEachSlave concurrently calls the fn on each slave node in the cluster.
// It returns the first error if any.
func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
state, err := c.state.ReloadOrGet()
func (c *ClusterClient) ForEachSlave(
ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
state, err := c.state.ReloadOrGet(ctx)
if err != nil {
return err
}
@@ -880,7 +878,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
wg.Add(1)
go func(node *clusterNode) {
defer wg.Done()
err := fn(node.Client)
err := fn(ctx, node.Client)
if err != nil {
select {
case errCh <- err:
@@ -902,8 +900,11 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
// ForEachNode concurrently calls the fn on each known node in the cluster.
// It returns the first error if any.
func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
state, err := c.state.ReloadOrGet()
func (c *ClusterClient) ForEachNode(
ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
state, err := c.state.ReloadOrGet(ctx)
if err != nil {
return err
}
@@ -913,7 +914,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
worker := func(node *clusterNode) {
defer wg.Done()
err := fn(node.Client)
err := fn(ctx, node.Client)
if err != nil {
select {
case errCh <- err:
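
The ForEachMaster, ForEachSlave, and ForEachNode callbacks now receive the context as their first argument, as the three hunks above show. A hedged sketch against the new ForEachMaster signature; the helper name is made up.

func pingAllMasters(ctx context.Context, rdb *redis.ClusterClient) error {
	// The callback gets the same ctx that was passed to ForEachMaster.
	return rdb.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
		return master.Ping(ctx).Err()
	})
}
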
@@ -945,7 +946,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
func (c *ClusterClient) PoolStats() *PoolStats {
var acc PoolStats
state, _ := c.state.Get()
state, _ := c.state.Get(context.TODO())
if state == nil {
return &acc
}
@@ -975,7 +976,7 @@ func (c *ClusterClient) PoolStats() *PoolStats {
return &acc
}
func (c *ClusterClient) loadState() (*clusterState, error) {
func (c *ClusterClient) loadState(ctx context.Context) (*clusterState, error) {
if c.opt.ClusterSlots != nil {
slots, err := c.opt.ClusterSlots()
if err != nil {
@@ -999,7 +1000,7 @@ func (c *ClusterClient) loadState() (*clusterState, error) {
continue
}
slots, err := node.Client.ClusterSlots().Result()
slots, err := node.Client.ClusterSlots(ctx).Result()
if err != nil {
if firstErr == nil {
firstErr = err
@@ -1042,8 +1043,8 @@ func (c *ClusterClient) Pipeline() Pipeliner {
return &pipe
}
func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn)
func (c *ClusterClient) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
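
Pipelined keeps the func(Pipeliner) error callback but now takes ctx up front. A minimal sketch that assumes pipeline commands such as Set and Incr follow the same ctx-first convention; the helper name and keys are made up.

func queueTwoCommands(ctx context.Context, rdb *redis.ClusterClient) error {
	cmds, err := rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.Set(ctx, "example:key", "value", 0)
		pipe.Incr(ctx, "example:counter")
		return nil
	})
	if err != nil {
		return err
	}
	// Each queued command carries its own result after the pipeline runs.
	for _, cmd := range cmds {
		_ = cmd.Err() // inspect per-command results here
	}
	return nil
}
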
func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
@@ -1052,7 +1053,7 @@ func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error
func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {
cmdsMap := newCmdsMap()
err := c.mapCmdsByNode(cmdsMap, cmds)
err := c.mapCmdsByNode(ctx, cmdsMap, cmds)
if err != nil {
setCmdsErr(cmds, err)
return err
@@ -1079,7 +1080,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
return
}
if attempt < c.opt.MaxRedirects {
if err := c.mapCmdsByNode(failedCmds, cmds); err != nil {
if err := c.mapCmdsByNode(ctx, failedCmds, cmds); err != nil {
setCmdsErr(cmds, err)
}
} else {
@@ -1098,8 +1099,8 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
return cmdsFirstErr(cmds)
}
func (c *ClusterClient) mapCmdsByNode(cmdsMap *cmdsMap, cmds []Cmder) error {
state, err := c.state.Get()
func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmds []Cmder) error {
state, err := c.state.Get(ctx)
if err != nil {
return err
}
@@ -1150,21 +1151,25 @@ func (c *ClusterClient) _processPipelineNode(
}
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return c.pipelineReadCmds(node, rd, cmds, failedCmds)
return c.pipelineReadCmds(ctx, node, rd, cmds, failedCmds)
})
})
})
}
func (c *ClusterClient) pipelineReadCmds(
node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context,
node *clusterNode,
rd *proto.Reader,
cmds []Cmder,
failedCmds *cmdsMap,
) error {
for _, cmd := range cmds {
err := cmd.readReply(rd)
if err == nil {
continue
}
if c.checkMovedErr(cmd, err, failedCmds) {
if c.checkMovedErr(ctx, cmd, err, failedCmds) {
continue
}
@@ -1181,7 +1186,7 @@ func (c *ClusterClient) pipelineReadCmds(
}
func (c *ClusterClient) checkMovedErr(
cmd Cmder, err error, failedCmds *cmdsMap,
ctx context.Context, cmd Cmder, err error, failedCmds *cmdsMap,
) bool {
moved, ask, addr := isMovedError(err)
if !moved && !ask {
@@ -1194,13 +1199,13 @@ func (c *ClusterClient) checkMovedErr(
}
if moved {
c.state.LazyReload()
c.state.LazyReload(ctx)
failedCmds.Add(node, cmd)
return true
}
if ask {
failedCmds.Add(node, NewCmd("asking"), cmd)
failedCmds.Add(node, NewCmd(ctx, "asking"), cmd)
return true
}
@@ -1217,8 +1222,8 @@ func (c *ClusterClient) TxPipeline() Pipeliner {
return &pipe
}
func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn)
func (c *ClusterClient) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(ctx, fn)
}
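
TxPipelined gets the same treatment; the queued commands are still wrapped in MULTI/EXEC. A brief sketch under the same assumptions as the pipeline example above.

func incrInTx(ctx context.Context, rdb *redis.ClusterClient) error {
	_, err := rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.Incr(ctx, "example:tx-counter")
		return nil
	})
	return err
}
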
func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
@@ -1226,7 +1231,7 @@ func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) err
}
func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error {
state, err := c.state.Get()
state, err := c.state.Get(ctx)
if err != nil {
setCmdsErr(cmds, err)
return err
@@ -1262,7 +1267,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
return
}
if attempt < c.opt.MaxRedirects {
if err := c.mapCmdsByNode(failedCmds, cmds); err != nil {
if err := c.mapCmdsByNode(ctx, failedCmds, cmds); err != nil {
setCmdsErr(cmds, err)
}
} else {
@@ -1308,11 +1313,11 @@ func (c *ClusterClient) _processTxPipelineNode(
// Trim multi and exec.
cmds = cmds[1 : len(cmds)-1]
err := c.txPipelineReadQueued(rd, statusCmd, cmds, failedCmds)
err := c.txPipelineReadQueued(ctx, rd, statusCmd, cmds, failedCmds)
if err != nil {
moved, ask, addr := isMovedError(err)
if moved || ask {
return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
return c.cmdsMoved(ctx, cmds, moved, ask, addr, failedCmds)
}
return err
}
@@ -1324,7 +1329,11 @@ func (c *ClusterClient) _processTxPipelineNode(
}
func (c *ClusterClient) txPipelineReadQueued(
rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context,
rd *proto.Reader,
statusCmd *StatusCmd,
cmds []Cmder,
failedCmds *cmdsMap,
) error {
// Parse queued replies.
if err := statusCmd.readReply(rd); err != nil {
@@ -1333,7 +1342,7 @@ func (c *ClusterClient) txPipelineReadQueued(
for _, cmd := range cmds {
err := statusCmd.readReply(rd)
if err == nil || c.checkMovedErr(cmd, err, failedCmds) || isRedisError(err) {
if err == nil || c.checkMovedErr(ctx, cmd, err, failedCmds) || isRedisError(err) {
continue
}
return err
@@ -1361,7 +1370,10 @@ func (c *ClusterClient) txPipelineReadQueued(
}
func (c *ClusterClient) cmdsMoved(
cmds []Cmder, moved, ask bool, addr string, failedCmds *cmdsMap,
ctx context.Context, cmds []Cmder,
moved, ask bool,
addr string,
failedCmds *cmdsMap,
) error {
node, err := c.nodes.Get(addr)
if err != nil {
@@ -1369,7 +1381,7 @@ func (c *ClusterClient) cmdsMoved(
}
if moved {
c.state.LazyReload()
c.state.LazyReload(ctx)
for _, cmd := range cmds {
failedCmds.Add(node, cmd)
}
@@ -1378,7 +1390,7 @@ func (c *ClusterClient) cmdsMoved(
if ask {
for _, cmd := range cmds {
failedCmds.Add(node, NewCmd("asking"), cmd)
failedCmds.Add(node, NewCmd(ctx, "asking"), cmd)
}
return nil
}
@@ -1386,11 +1398,7 @@ func (c *ClusterClient) cmdsMoved(
return nil
}
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
return c.WatchContext(c.ctx, fn, keys...)
}
func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, keys ...string) error {
func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
if len(keys) == 0 {
return fmt.Errorf("redis: Watch requires at least one key")
}
@@ -1403,7 +1411,7 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke
}
}
node, err := c.slotMasterNode(slot)
node, err := c.slotMasterNode(ctx, slot)
if err != nil {
return err
}
@@ -1415,12 +1423,12 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke
}
}
err = node.Client.WatchContext(ctx, fn, keys...)
err = node.Client.Watch(ctx, fn, keys...)
if err == nil {
break
}
if err != Nil {
c.state.LazyReload()
c.state.LazyReload(ctx)
}
moved, ask, addr := isMovedError(err)
@@ -1433,7 +1441,7 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke
}
if err == pool.ErrClosed || isReadOnlyError(err) {
node, err = c.slotMasterNode(slot)
node, err = c.slotMasterNode(ctx, slot)
if err != nil {
return err
}
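
WatchContext is folded into Watch, which now takes ctx first. A hedged sketch of the usual optimistic-locking pattern built on the new signature; it assumes Tx commands and Tx.TxPipelined follow the same ctx-first convention, and the helper name is made up.

func incrWithWatch(ctx context.Context, rdb *redis.ClusterClient, key string) error {
	return rdb.Watch(ctx, func(tx *redis.Tx) error {
		n, err := tx.Get(ctx, key).Int64()
		if err != nil && err != redis.Nil {
			return err
		}
		// Queue the write; EXEC fails if the watched key changed in the meantime.
		_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
			pipe.Set(ctx, key, n+1, 0)
			return nil
		})
		return err
	}, key)
}
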
@@ -1455,7 +1463,7 @@ func (c *ClusterClient) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt.clientOptions(),
newConn: func(channels []string) (*pool.Conn, error) {
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
if node != nil {
panic("node != nil")
}
@@ -1463,7 +1471,7 @@ func (c *ClusterClient) pubSub() *PubSub {
var err error
if len(channels) > 0 {
slot := hashtag.Slot(channels[0])
node, err = c.slotMasterNode(slot)
node, err = c.slotMasterNode(ctx, slot)
} else {
node, err = c.nodes.Random()
}
@@ -1493,20 +1501,20 @@ func (c *ClusterClient) pubSub() *PubSub {
// Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create empty subscription.
func (c *ClusterClient) Subscribe(channels ...string) *PubSub {
func (c *ClusterClient) Subscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.Subscribe(channels...)
_ = pubsub.Subscribe(ctx, channels...)
}
return pubsub
}
// PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create empty subscription.
func (c *ClusterClient) PSubscribe(channels ...string) *PubSub {
func (c *ClusterClient) PSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.PSubscribe(channels...)
_ = pubsub.PSubscribe(ctx, channels...)
}
return pubsub
}
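
Subscribe and PSubscribe now take ctx as well; the returned *PubSub is assumed to be consumed as before via Channel(). A small sketch with a made-up channel name and helper.

func consumeMessages(ctx context.Context, rdb *redis.ClusterClient) {
	pubsub := rdb.Subscribe(ctx, "example:channel")
	defer pubsub.Close()

	// Channel() delivers messages until the PubSub is closed.
	for msg := range pubsub.Channel() {
		_ = msg.Payload // handle the message here
	}
}
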
@@ -1531,7 +1539,7 @@ func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) {
continue
}
info, err := node.Client.Command().Result()
info, err := node.Client.Command(context.TODO()).Result()
if err == nil {
return info, nil
}
@@ -1573,8 +1581,12 @@ func cmdSlot(cmd Cmder, pos int) int {
return hashtag.Slot(firstKey)
}
func (c *ClusterClient) cmdNode(cmdInfo *CommandInfo, slot int) (*clusterNode, error) {
state, err := c.state.Get()
func (c *ClusterClient) cmdNode(
ctx context.Context,
cmdInfo *CommandInfo,
slot int,
) (*clusterNode, error) {
state, err := c.state.Get(ctx)
if err != nil {
return nil, err
}
@@ -1595,8 +1607,8 @@ func (c *ClusterClient) slotReadOnlyNode(state *clusterState, slot int) (*cluste
return state.slotSlaveNode(slot)
}
func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) {
state, err := c.state.Get()
func (c *ClusterClient) slotMasterNode(ctx context.Context, slot int) (*clusterNode, error) {
state, err := c.state.Get(ctx)
if err != nil {
return nil, err
}