1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-29 17:41:15 +03:00

Fix race in PubSub

This commit is contained in:
Vladimir Mihailenco
2017-06-29 17:05:08 +03:00
parent fbc8000fd1
commit 0d94a7bc88
4 changed files with 25 additions and 51 deletions

View File

@ -238,30 +238,4 @@ var _ = Describe("race", func() {
} }
}) })
}) })
It("does not happen on Get and PopFree", func() {
connPool = pool.NewConnPool(
&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolTimeout: time.Minute,
IdleTimeout: time.Second,
IdleCheckFrequency: time.Millisecond,
})
perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil {
Expect(connPool.Put(cn)).NotTo(HaveOccurred())
}
cn = connPool.PopFree()
if cn != nil {
Expect(connPool.Put(cn)).NotTo(HaveOccurred())
}
}
})
})
}) })

View File

@ -3,7 +3,6 @@ package redis_test
import ( import (
"errors" "errors"
"fmt" "fmt"
"log"
"net" "net"
"os" "os"
"os/exec" "os/exec"
@ -52,7 +51,7 @@ var cluster = &clusterScenario{
} }
func init() { func init() {
redis.SetLogger(log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)) //redis.SetLogger(log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile))
} }
var _ = BeforeSuite(func() { var _ = BeforeSuite(func() {

View File

@ -28,37 +28,41 @@ type PubSub struct {
cmd *Cmd cmd *Cmd
} }
func (c *PubSub) conn() (*pool.Conn, bool, error) { func (c *PubSub) conn() (*pool.Conn, error) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() cn, err := c._conn()
c.mu.Unlock()
return cn, err
}
func (c *PubSub) _conn() (*pool.Conn, error) {
if c.closed { if c.closed {
return nil, false, pool.ErrClosed return nil, pool.ErrClosed
} }
if c.cn != nil { if c.cn != nil {
return c.cn, false, nil return c.cn, nil
} }
cn, err := c.base.connPool.NewConn() cn, err := c.base.connPool.NewConn()
if err != nil { if err != nil {
return nil, false, err return nil, err
} }
if !cn.Inited { if !cn.Inited {
if err := c.base.initConn(cn); err != nil { if err := c.base.initConn(cn); err != nil {
_ = c.base.connPool.CloseConn(cn) _ = c.base.connPool.CloseConn(cn)
return nil, false, err return nil, err
} }
} }
if err := c.resubscribe(cn); err != nil { if err := c.resubscribe(cn); err != nil {
_ = c.base.connPool.CloseConn(cn) _ = c.base.connPool.CloseConn(cn)
return nil, false, err return nil, err
} }
c.cn = cn c.cn = cn
return cn, true, nil return cn, nil
} }
func (c *PubSub) resubscribe(cn *pool.Conn) error { func (c *PubSub) resubscribe(cn *pool.Conn) error {
@ -125,48 +129,48 @@ func (c *PubSub) Close() error {
// empty subscription if there are no channels. // empty subscription if there are no channels.
func (c *PubSub) Subscribe(channels ...string) error { func (c *PubSub) Subscribe(channels ...string) error {
c.mu.Lock() c.mu.Lock()
err := c.subscribe("subscribe", channels...)
c.channels = appendIfNotExists(c.channels, channels...) c.channels = appendIfNotExists(c.channels, channels...)
c.mu.Unlock() c.mu.Unlock()
return c.subscribe("subscribe", channels...) return err
} }
// Subscribes the client to the given patterns. It returns // Subscribes the client to the given patterns. It returns
// empty subscription if there are no patterns. // empty subscription if there are no patterns.
func (c *PubSub) PSubscribe(patterns ...string) error { func (c *PubSub) PSubscribe(patterns ...string) error {
c.mu.Lock() c.mu.Lock()
err := c.subscribe("psubscribe", patterns...)
c.patterns = appendIfNotExists(c.patterns, patterns...) c.patterns = appendIfNotExists(c.patterns, patterns...)
c.mu.Unlock() c.mu.Unlock()
return c.subscribe("psubscribe", patterns...) return err
} }
// Unsubscribes the client from the given channels, or from all of // Unsubscribes the client from the given channels, or from all of
// them if none is given. // them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error { func (c *PubSub) Unsubscribe(channels ...string) error {
c.mu.Lock() c.mu.Lock()
err := c.subscribe("unsubscribe", channels...)
c.channels = remove(c.channels, channels...) c.channels = remove(c.channels, channels...)
c.mu.Unlock() c.mu.Unlock()
return c.subscribe("unsubscribe", channels...) return err
} }
// Unsubscribes the client from the given patterns, or from all of // Unsubscribes the client from the given patterns, or from all of
// them if none is given. // them if none is given.
func (c *PubSub) PUnsubscribe(patterns ...string) error { func (c *PubSub) PUnsubscribe(patterns ...string) error {
c.mu.Lock() c.mu.Lock()
err := c.subscribe("punsubscribe", patterns...)
c.patterns = remove(c.patterns, patterns...) c.patterns = remove(c.patterns, patterns...)
c.mu.Unlock() c.mu.Unlock()
return c.subscribe("punsubscribe", patterns...) return err
} }
func (c *PubSub) subscribe(redisCmd string, channels ...string) error { func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, isNew, err := c.conn() cn, err := c._conn()
if err != nil { if err != nil {
return err return err
} }
if isNew {
return nil
}
err = c._subscribe(cn, redisCmd, channels...) err = c._subscribe(cn, redisCmd, channels...)
c.putConn(cn, err) c.putConn(cn, err)
return err return err
@ -179,7 +183,7 @@ func (c *PubSub) Ping(payload ...string) error {
} }
cmd := NewCmd(args...) cmd := NewCmd(args...)
cn, _, err := c.conn() cn, err := c.conn()
if err != nil { if err != nil {
return err return err
} }
@ -272,7 +276,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
c.cmd = NewCmd() c.cmd = NewCmd()
} }
cn, _, err := c.conn() cn, err := c.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -387,10 +387,7 @@ func (c *Client) pubSub() *PubSub {
func (c *Client) Subscribe(channels ...string) *PubSub { func (c *Client) Subscribe(channels ...string) *PubSub {
pubsub := c.pubSub() pubsub := c.pubSub()
if len(channels) > 0 { if len(channels) > 0 {
err := pubsub.Subscribe(channels...) _ = pubsub.Subscribe(channels...)
if err != nil {
panic(err)
}
} }
return pubsub return pubsub
} }