mirror of
https://github.com/redis/go-redis.git
synced 2025-07-29 17:41:15 +03:00
Rework PubSub conn management
This commit is contained in:
152
pubsub.go
152
pubsub.go
@ -20,49 +20,14 @@ type PubSub struct {
|
||||
cn *pool.Conn
|
||||
closed bool
|
||||
|
||||
cmd *Cmd
|
||||
|
||||
subMu sync.Mutex
|
||||
channels []string
|
||||
patterns []string
|
||||
|
||||
cmd *Cmd
|
||||
}
|
||||
|
||||
func (c *PubSub) conn() (*pool.Conn, error) {
|
||||
cn, isNew, err := c._conn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isNew {
|
||||
if err := c.resubscribe(); err != nil {
|
||||
internal.Logf("resubscribe failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *PubSub) resubscribe() error {
|
||||
c.subMu.Lock()
|
||||
channels := c.channels
|
||||
patterns := c.patterns
|
||||
c.subMu.Unlock()
|
||||
|
||||
var firstErr error
|
||||
if len(channels) > 0 {
|
||||
if err := c.subscribe("subscribe", channels...); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if len(patterns) > 0 {
|
||||
if err := c.subscribe("psubscribe", patterns...); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (c *PubSub) _conn() (*pool.Conn, bool, error) {
|
||||
func (c *PubSub) conn() (*pool.Conn, bool, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@ -86,21 +51,81 @@ func (c *PubSub) _conn() (*pool.Conn, bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.resubscribe(cn); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
c.cn = cn
|
||||
return cn, true, nil
|
||||
}
|
||||
|
||||
func (c *PubSub) putConn(cn *pool.Conn, err error) {
|
||||
if internal.IsBadConn(err, true) {
|
||||
c.mu.Lock()
|
||||
if c.cn == cn {
|
||||
_ = c.closeConn()
|
||||
func (c *PubSub) resubscribe(cn *pool.Conn) error {
|
||||
c.subMu.Lock()
|
||||
defer c.subMu.Unlock()
|
||||
|
||||
var firstErr error
|
||||
if len(c.channels) > 0 {
|
||||
if err := c._subscribe(cn, "subscribe", c.channels...); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
if len(c.patterns) > 0 {
|
||||
if err := c._subscribe(cn, "psubscribe", c.patterns...); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (c *PubSub) putConn(cn *pool.Conn, err error) {
|
||||
if !internal.IsBadConn(err, true) {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.cn == cn {
|
||||
_ = c.closeConn()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *PubSub) closeConn() error {
|
||||
err := c.base.connPool.CloseConn(c.cn)
|
||||
c.cn = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *PubSub) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.closed {
|
||||
return pool.ErrClosed
|
||||
}
|
||||
c.closed = true
|
||||
|
||||
if c.cn != nil {
|
||||
return c.closeConn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
|
||||
cn, isNew, err := c.conn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isNew {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = c._subscribe(cn, redisCmd, channels...)
|
||||
c.putConn(cn, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error {
|
||||
args := make([]interface{}, 1+len(channels))
|
||||
args[0] = redisCmd
|
||||
for i, channel := range channels {
|
||||
@ -108,19 +133,8 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
|
||||
}
|
||||
cmd := NewSliceCmd(args...)
|
||||
|
||||
cn, isNew, err := c._conn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isNew {
|
||||
return c.resubscribe()
|
||||
}
|
||||
|
||||
cn.SetWriteTimeout(c.base.opt.WriteTimeout)
|
||||
err = writeCmd(cn, cmd)
|
||||
c.putConn(cn, err)
|
||||
return err
|
||||
return writeCmd(cn, cmd)
|
||||
}
|
||||
|
||||
// Subscribes the client to the specified channels.
|
||||
@ -157,28 +171,6 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error {
|
||||
return c.subscribe("punsubscribe", patterns...)
|
||||
}
|
||||
|
||||
func (c *PubSub) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.closed {
|
||||
return pool.ErrClosed
|
||||
}
|
||||
c.closed = true
|
||||
|
||||
if c.cn != nil {
|
||||
_ = c.closeConn()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PubSub) closeConn() error {
|
||||
err := c.base.connPool.CloseConn(c.cn)
|
||||
c.cn = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *PubSub) Ping(payload ...string) error {
|
||||
args := []interface{}{"ping"}
|
||||
if len(payload) == 1 {
|
||||
@ -186,7 +178,7 @@ func (c *PubSub) Ping(payload ...string) error {
|
||||
}
|
||||
cmd := NewCmd(args...)
|
||||
|
||||
cn, err := c.conn()
|
||||
cn, _, err := c.conn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -279,7 +271,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
|
||||
c.cmd = NewCmd()
|
||||
}
|
||||
|
||||
cn, err := c.conn()
|
||||
cn, _, err := c.conn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Reference in New Issue
Block a user