1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-28 06:42:00 +03:00

Fix race in PubSub

This commit is contained in:
Vladimir Mihailenco
2017-06-29 17:05:08 +03:00
parent fbc8000fd1
commit 0d94a7bc88
4 changed files with 25 additions and 51 deletions

View File

@ -28,37 +28,41 @@ type PubSub struct {
cmd *Cmd
}
func (c *PubSub) conn() (*pool.Conn, bool, error) {
func (c *PubSub) conn() (*pool.Conn, error) {
c.mu.Lock()
defer c.mu.Unlock()
cn, err := c._conn()
c.mu.Unlock()
return cn, err
}
func (c *PubSub) _conn() (*pool.Conn, error) {
if c.closed {
return nil, false, pool.ErrClosed
return nil, pool.ErrClosed
}
if c.cn != nil {
return c.cn, false, nil
return c.cn, nil
}
cn, err := c.base.connPool.NewConn()
if err != nil {
return nil, false, err
return nil, err
}
if !cn.Inited {
if err := c.base.initConn(cn); err != nil {
_ = c.base.connPool.CloseConn(cn)
return nil, false, err
return nil, err
}
}
if err := c.resubscribe(cn); err != nil {
_ = c.base.connPool.CloseConn(cn)
return nil, false, err
return nil, err
}
c.cn = cn
return cn, true, nil
return cn, nil
}
func (c *PubSub) resubscribe(cn *pool.Conn) error {
@ -125,48 +129,48 @@ func (c *PubSub) Close() error {
// empty subscription if there are no channels.
func (c *PubSub) Subscribe(channels ...string) error {
c.mu.Lock()
err := c.subscribe("subscribe", channels...)
c.channels = appendIfNotExists(c.channels, channels...)
c.mu.Unlock()
return c.subscribe("subscribe", channels...)
return err
}
// Subscribes the client to the given patterns. It returns
// empty subscription if there are no patterns.
func (c *PubSub) PSubscribe(patterns ...string) error {
c.mu.Lock()
err := c.subscribe("psubscribe", patterns...)
c.patterns = appendIfNotExists(c.patterns, patterns...)
c.mu.Unlock()
return c.subscribe("psubscribe", patterns...)
return err
}
// Unsubscribes the client from the given channels, or from all of
// them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error {
c.mu.Lock()
err := c.subscribe("unsubscribe", channels...)
c.channels = remove(c.channels, channels...)
c.mu.Unlock()
return c.subscribe("unsubscribe", channels...)
return err
}
// Unsubscribes the client from the given patterns, or from all of
// them if none is given.
func (c *PubSub) PUnsubscribe(patterns ...string) error {
c.mu.Lock()
err := c.subscribe("punsubscribe", patterns...)
c.patterns = remove(c.patterns, patterns...)
c.mu.Unlock()
return c.subscribe("punsubscribe", patterns...)
return err
}
func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, isNew, err := c.conn()
cn, err := c._conn()
if err != nil {
return err
}
if isNew {
return nil
}
err = c._subscribe(cn, redisCmd, channels...)
c.putConn(cn, err)
return err
@ -179,7 +183,7 @@ func (c *PubSub) Ping(payload ...string) error {
}
cmd := NewCmd(args...)
cn, _, err := c.conn()
cn, err := c.conn()
if err != nil {
return err
}
@ -272,7 +276,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
c.cmd = NewCmd()
}
cn, _, err := c.conn()
cn, err := c.conn()
if err != nil {
return nil, err
}