diff --git a/redis.go b/redis.go index 472b3247..e15da91e 100644 --- a/redis.go +++ b/redis.go @@ -49,13 +49,13 @@ func (hs hooks) process( ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, ) error { if len(hs.hooks) == 0 { - return hs.withContext(ctx, func() error { - err := fn(ctx, cmd) - if err != nil { - cmd.SetErr(err) - } - return err + err, canceled := hs.withContext(ctx, func() error { + return fn(ctx, cmd) }) + if canceled { + cmd.SetErr(err) + } + return err } var hookIndex int @@ -69,13 +69,13 @@ func (hs hooks) process( } if retErr == nil { - retErr = hs.withContext(ctx, func() error { - err := fn(ctx, cmd) - if err != nil { - cmd.SetErr(err) - } - return err + var canceled bool + retErr, canceled = hs.withContext(ctx, func() error { + return fn(ctx, cmd) }) + if canceled { + cmd.SetErr(retErr) + } } for hookIndex--; hookIndex >= 0; hookIndex-- { @@ -92,13 +92,13 @@ func (hs hooks) processPipeline( ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, ) error { if len(hs.hooks) == 0 { - return hs.withContext(ctx, func() error { - err := fn(ctx, cmds) - if err != nil { - setCmdsErr(cmds, err) - } - return err + err, canceled := hs.withContext(ctx, func() error { + return fn(ctx, cmds) }) + if canceled { + setCmdsErr(cmds, err) + } + return err } var hookIndex int @@ -112,13 +112,13 @@ func (hs hooks) processPipeline( } if retErr == nil { - retErr = hs.withContext(ctx, func() error { - err := fn(ctx, cmds) - if err != nil { - setCmdsErr(cmds, err) - } - return err + var canceled bool + retErr, canceled = hs.withContext(ctx, func() error { + return fn(ctx, cmds) }) + if canceled { + setCmdsErr(cmds, retErr) + } } for hookIndex--; hookIndex >= 0; hookIndex-- { @@ -138,19 +138,20 @@ func (hs hooks) processTxPipeline( return hs.processPipeline(ctx, cmds, fn) } -func (hs hooks) withContext(ctx context.Context, fn func() error) error { - if ctx.Done() == nil { - return fn() +func (hs hooks) withContext(ctx context.Context, fn func() error) (_ error, canceled bool) { + done := ctx.Done() + if done == nil { + return fn(), false } errc := make(chan error, 1) go func() { errc <- fn() }() select { - case <-ctx.Done(): - return ctx.Err() + case <-done: + return ctx.Err(), true case err := <-errc: - return err + return err, false } }