From 20363d149b416d974462deff808f0499e0ea0d13 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Tue, 13 Mar 2018 15:51:38 +0200 Subject: [PATCH] Fix WithContext followed by WrapProcess --- redis.go | 52 ++++++++++++++++++++++++++++----------------------- redis_test.go | 32 ++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 28 deletions(-) diff --git a/redis.go b/redis.go index e0b64644..7a606b70 100644 --- a/redis.go +++ b/redis.go @@ -96,16 +96,7 @@ func (c *baseClient) initConn(cn *pool.Conn) error { return nil } - // Temp client to initialize connection. - conn := &Conn{ - baseClient: baseClient{ - opt: c.opt, - connPool: pool.NewSingleConnPool(cn), - }, - } - conn.baseClient.init() - conn.statefulCmdable.setProcessor(conn.Process) - + conn := newConn(c.opt, cn) _, err := conn.Pipelined(func(pipe Pipeliner) error { if c.opt.Password != "" { pipe.Auth(c.opt.Password) @@ -351,22 +342,24 @@ type Client struct { ctx context.Context } -func newClient(opt *Options, pool pool.Pooler) *Client { - c := Client{ - baseClient: baseClient{ - opt: opt, - connPool: pool, - }, - } - c.baseClient.init() - c.cmdable.setProcessor(c.Process) - return &c -} - // NewClient returns a client to the Redis Server specified by Options. func NewClient(opt *Options) *Client { opt.init() - return newClient(opt, newConnPool(opt)) + + c := Client{ + baseClient: baseClient{ + opt: opt, + connPool: newConnPool(opt), + }, + } + c.baseClient.init() + c.init() + + return &c +} + +func (c *Client) init() { + c.cmdable.setProcessor(c.Process) } func (c *Client) Context() context.Context { @@ -387,6 +380,7 @@ func (c *Client) WithContext(ctx context.Context) *Client { func (c *Client) copy() *Client { cp := *c + cp.init() return &cp } @@ -467,6 +461,18 @@ type Conn struct { statefulCmdable } +func newConn(opt *Options, cn *pool.Conn) *Conn { + c := Conn{ + baseClient: baseClient{ + opt: opt, + connPool: pool.NewSingleConnPool(cn), + }, + } + c.baseClient.init() + c.statefulCmdable.setProcessor(c.Process) + return &c +} + func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { return c.Pipeline().Pipelined(fn) } diff --git a/redis_test.go b/redis_test.go index ad8aa11a..df2485d3 100644 --- a/redis_test.go +++ b/redis_test.go @@ -228,18 +228,40 @@ var _ = Describe("Client", func() { }) It("should call WrapProcess", func() { - var wrapperFnCalled bool + var fnCalled bool - client.WrapProcess(func(oldProcess func(redis.Cmder) error) func(redis.Cmder) error { + client.WrapProcess(func(old func(redis.Cmder) error) func(redis.Cmder) error { return func(cmd redis.Cmder) error { - wrapperFnCalled = true - return oldProcess(cmd) + fnCalled = true + return old(cmd) } }) Expect(client.Ping().Err()).NotTo(HaveOccurred()) + Expect(fnCalled).To(BeTrue()) + }) - Expect(wrapperFnCalled).To(BeTrue()) + It("should call WrapProcess after WithContext", func() { + var fn1Called, fn2Called bool + + client.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error { + return func(cmd redis.Cmder) error { + fn1Called = true + return old(cmd) + } + }) + + client2 := client.WithContext(client.Context()) + client2.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error { + return func(cmd redis.Cmder) error { + fn2Called = true + return old(cmd) + } + }) + + Expect(client2.Ping().Err()).NotTo(HaveOccurred()) + Expect(fn2Called).To(BeTrue()) + Expect(fn1Called).To(BeTrue()) }) })