diff --git a/cluster.go b/cluster.go index c33b5bcb..72d65df9 100644 --- a/cluster.go +++ b/cluster.go @@ -1012,7 +1012,7 @@ func (c *ClusterClient) reaper(idleCheckFrequency time.Duration) { for _, node := range nodes { _, err := node.Client.connPool.(*pool.ConnPool).ReapStaleConns() if err != nil { - internal.Logf("ReapStaleConns failed: %s", err) + internal.Logger.Printf("ReapStaleConns failed: %s", err) } } } @@ -1524,7 +1524,7 @@ func (c *ClusterClient) cmdInfo(name string) *CommandInfo { info := cmdsInfo[name] if info == nil { - internal.Logf("info for cmd=%s not found", name) + internal.Logger.Printf("info for cmd=%s not found", name) } return info } diff --git a/commands.go b/commands.go index aa224b57..3ce3a805 100644 --- a/commands.go +++ b/commands.go @@ -14,7 +14,7 @@ func usePrecise(dur time.Duration) bool { func formatMs(dur time.Duration) int64 { if dur > 0 && dur < time.Millisecond { - internal.Logf( + internal.Logger.Printf( "specified duration is %s, but minimal supported value is %s", dur, time.Millisecond, ) @@ -24,7 +24,7 @@ func formatMs(dur time.Duration) int64 { func formatSec(dur time.Duration) int64 { if dur > 0 && dur < time.Second { - internal.Logf( + internal.Logger.Printf( "specified duration is %s, but minimal supported value is %s", dur, time.Second, ) diff --git a/internal/log.go b/internal/log.go index fd14222e..405a2728 100644 --- a/internal/log.go +++ b/internal/log.go @@ -1,15 +1,8 @@ package internal import ( - "fmt" "log" + "os" ) -var Logger *log.Logger - -func Logf(s string, args ...interface{}) { - if Logger == nil { - return - } - Logger.Output(2, fmt.Sprintf(s, args...)) -} +var Logger = log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index d18b2566..dacf7be8 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -139,10 +139,11 @@ func (p *ConnPool) _NewConn(ctx context.Context, pooled bool) (*Conn, error) { p.connsMu.Lock() p.conns = append(p.conns, cn) if pooled { - if p.poolSize < p.opt.PoolSize { - p.poolSize++ - } else { + // If pool is full remove the cn on next Put. + if p.poolSize >= p.opt.PoolSize { cn.pooled = false + } else { + p.poolSize++ } } p.connsMu.Unlock() @@ -315,18 +316,23 @@ func (p *ConnPool) Put(cn *Conn) { } func (p *ConnPool) Remove(cn *Conn) { - p.removeConn(cn) + p.removeConnWithLock(cn) p.freeTurn() _ = p.closeConn(cn) } func (p *ConnPool) CloseConn(cn *Conn) error { - p.removeConn(cn) + p.removeConnWithLock(cn) return p.closeConn(cn) } -func (p *ConnPool) removeConn(cn *Conn) { +func (p *ConnPool) removeConnWithLock(cn *Conn) { p.connsMu.Lock() + p.removeConn(cn) + p.connsMu.Unlock() +} + +func (p *ConnPool) removeConn(cn *Conn) { for i, c := range p.conns { if c == cn { p.conns = append(p.conns[:i], p.conns[i+1:]...) @@ -334,10 +340,9 @@ func (p *ConnPool) removeConn(cn *Conn) { p.poolSize-- p.checkMinIdleConns() } - break + return } } - p.connsMu.Unlock() } func (p *ConnPool) closeConn(cn *Conn) error { @@ -415,20 +420,21 @@ func (p *ConnPool) Close() error { return firstErr } -func (p *ConnPool) reapStaleConn() *Conn { - if len(p.idleConns) == 0 { - return nil +func (p *ConnPool) reaper(frequency time.Duration) { + ticker := time.NewTicker(frequency) + defer ticker.Stop() + + for range ticker.C { + if p.closed() { + break + } + n, err := p.ReapStaleConns() + if err != nil { + internal.Logger.Printf("ReapStaleConns failed: %s", err) + continue + } + atomic.AddUint32(&p.stats.StaleConns, uint32(n)) } - - cn := p.idleConns[0] - if !p.isStaleConn(cn) { - return nil - } - - p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...) - p.idleConnsLen-- - - return cn } func (p *ConnPool) ReapStaleConns() (int, error) { @@ -439,11 +445,6 @@ func (p *ConnPool) ReapStaleConns() (int, error) { p.connsMu.Lock() cn := p.reapStaleConn() p.connsMu.Unlock() - - if cn != nil { - p.removeConn(cn) - } - p.freeTurn() if cn != nil { @@ -456,21 +457,21 @@ func (p *ConnPool) ReapStaleConns() (int, error) { return n, nil } -func (p *ConnPool) reaper(frequency time.Duration) { - ticker := time.NewTicker(frequency) - defer ticker.Stop() - - for range ticker.C { - if p.closed() { - break - } - n, err := p.ReapStaleConns() - if err != nil { - internal.Logf("ReapStaleConns failed: %s", err) - continue - } - atomic.AddUint32(&p.stats.StaleConns, uint32(n)) +func (p *ConnPool) reapStaleConn() *Conn { + if len(p.idleConns) == 0 { + return nil } + + cn := p.idleConns[0] + if !p.isStaleConn(cn) { + return nil + } + + p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...) + p.idleConnsLen-- + p.removeConn(cn) + + return cn } func (p *ConnPool) isStaleConn(cn *Conn) bool { diff --git a/pubsub.go b/pubsub.go index e3df4f40..7bb3872a 100644 --- a/pubsub.go +++ b/pubsub.go @@ -52,14 +52,14 @@ func (c *PubSub) init() { c.exit = make(chan struct{}) } -func (c *PubSub) conn() (*pool.Conn, error) { +func (c *PubSub) connWithLock() (*pool.Conn, error) { c.mu.Lock() - cn, err := c._conn(nil) + cn, err := c.conn(nil) c.mu.Unlock() return cn, err } -func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) { +func (c *PubSub) conn(newChannels []string) (*pool.Conn, error) { if c.closed { return nil, pool.ErrClosed } @@ -132,32 +132,32 @@ func (c *PubSub) _subscribe( return c.writeCmd(context.TODO(), cn, cmd) } -func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { +func (c *PubSub) releaseConnWithLock(cn *pool.Conn, err error, allowTimeout bool) { c.mu.Lock() - c._releaseConn(cn, err, allowTimeout) + c.releaseConn(cn, err, allowTimeout) c.mu.Unlock() } -func (c *PubSub) _releaseConn(cn *pool.Conn, err error, allowTimeout bool) { +func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { if c.cn != cn { return } if internal.IsBadConn(err, allowTimeout) { - c._reconnect(err) + c.reconnect(err) } } -func (c *PubSub) _reconnect(reason error) { - _ = c._closeTheCn(reason) - _, _ = c._conn(nil) +func (c *PubSub) reconnect(reason error) { + _ = c.closeTheCn(reason) + _, _ = c.conn(nil) } -func (c *PubSub) _closeTheCn(reason error) error { +func (c *PubSub) closeTheCn(reason error) error { if c.cn == nil { return nil } if !c.closed { - internal.Logf("redis: discarding bad PubSub connection: %s", reason) + internal.Logger.Printf("redis: discarding bad PubSub connection: %s", reason) } err := c.closeConn(c.cn) c.cn = nil @@ -174,8 +174,7 @@ func (c *PubSub) Close() error { c.closed = true close(c.exit) - err := c._closeTheCn(pool.ErrClosed) - return err + return c.closeTheCn(pool.ErrClosed) } // Subscribe the client to the specified channels. It returns @@ -237,13 +236,13 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error { } func (c *PubSub) subscribe(redisCmd string, channels ...string) error { - cn, err := c._conn(channels) + cn, err := c.conn(channels) if err != nil { return err } err = c._subscribe(cn, redisCmd, channels) - c._releaseConn(cn, err, false) + c.releaseConn(cn, err, false) return err } @@ -254,13 +253,13 @@ func (c *PubSub) Ping(payload ...string) error { } cmd := NewCmd(args...) - cn, err := c.conn() + cn, err := c.connWithLock() if err != nil { return err } err = c.writeCmd(context.TODO(), cn, cmd) - c.releaseConn(cn, err, false) + c.releaseConnWithLock(cn, err, false) return err } @@ -346,7 +345,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { c.cmd = NewCmd() } - cn, err := c.conn() + cn, err := c.connWithLock() if err != nil { return nil, err } @@ -355,7 +354,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { return c.cmd.readReply(rd) }) - c.releaseConn(cn, err, timeout > 0) + c.releaseConnWithLock(cn, err, timeout > 0) if err != nil { return nil, err } @@ -467,12 +466,12 @@ func (c *PubSub) initChannel(size int) { <-timer.C } case <-timer.C: - internal.Logf( + internal.Logger.Printf( "redis: %s channel is full for %s (message is dropped)", c, timeout) } default: - internal.Logf("redis: unknown message type: %T", msg) + internal.Logger.Printf("redis: unknown message type: %T", msg) } } }() @@ -499,7 +498,7 @@ func (c *PubSub) initChannel(size int) { pingErr = errPingTimeout } c.mu.Lock() - c._reconnect(pingErr) + c.reconnect(pingErr) c.mu.Unlock() } case <-c.exit: diff --git a/redis.go b/redis.go index 54b85411..d9db7e10 100644 --- a/redis.go +++ b/redis.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log" - "os" "time" "github.com/go-redis/redis/internal" @@ -15,10 +14,6 @@ import ( // Nil reply Redis returns when key does not exist. const Nil = proto.Nil -func init() { - SetLogger(log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)) -} - func SetLogger(logger *log.Logger) { internal.Logger = logger } diff --git a/ring.go b/ring.go index 3e9e0c49..93066ffa 100644 --- a/ring.go +++ b/ring.go @@ -260,7 +260,7 @@ func (c *ringShards) Heartbeat(frequency time.Duration) { for _, shard := range shards { err := shard.Client.Ping().Err() if shard.Vote(err == nil || err == pool.ErrPoolTimeout) { - internal.Logf("ring shard state changed: %s", shard) + internal.Logger.Printf("ring shard state changed: %s", shard) rebalance = true } } @@ -525,7 +525,7 @@ func (c *Ring) cmdInfo(name string) *CommandInfo { } info := cmdsInfo[name] if info == nil { - internal.Logf("info for cmd=%s not found", name) + internal.Logger.Printf("info for cmd=%s not found", name) } return info } diff --git a/sentinel.go b/sentinel.go index 331d0048..1f906d66 100644 --- a/sentinel.go +++ b/sentinel.go @@ -360,7 +360,7 @@ func (c *sentinelFailover) masterAddr() (string, error) { masterAddr, err := sentinel.GetMasterAddrByName(c.masterName).Result() if err != nil { - internal.Logf("sentinel: GetMasterAddrByName master=%q failed: %s", + internal.Logger.Printf("sentinel: GetMasterAddrByName master=%q failed: %s", c.masterName, err) _ = sentinel.Close() continue @@ -388,7 +388,7 @@ func (c *sentinelFailover) getMasterAddr() string { addr, err := sentinel.GetMasterAddrByName(c.masterName).Result() if err != nil { - internal.Logf("sentinel: GetMasterAddrByName name=%q failed: %s", + internal.Logger.Printf("sentinel: GetMasterAddrByName name=%q failed: %s", c.masterName, err) c.mu.Lock() if c.sentinel == sentinel { @@ -412,7 +412,7 @@ func (c *sentinelFailover) switchMaster(addr string) { c.mu.Lock() defer c.mu.Unlock() - internal.Logf("sentinel: new master=%q addr=%q", + internal.Logger.Printf("sentinel: new master=%q addr=%q", c.masterName, addr) _ = c.Pool().Filter(func(cn *pool.Conn) bool { return cn.RemoteAddr().String() != addr @@ -449,7 +449,7 @@ func (c *sentinelFailover) closeSentinel() error { func (c *sentinelFailover) discoverSentinels(sentinel *SentinelClient) { sentinels, err := sentinel.Sentinels(c.masterName).Result() if err != nil { - internal.Logf("sentinel: Sentinels master=%q failed: %s", c.masterName, err) + internal.Logger.Printf("sentinel: Sentinels master=%q failed: %s", c.masterName, err) return } for _, sentinel := range sentinels { @@ -459,7 +459,7 @@ func (c *sentinelFailover) discoverSentinels(sentinel *SentinelClient) { if key == "name" { sentinelAddr := vals[i+1].(string) if !contains(c.sentinelAddrs, sentinelAddr) { - internal.Logf("sentinel: discovered new sentinel=%q for master=%q", + internal.Logger.Printf("sentinel: discovered new sentinel=%q for master=%q", sentinelAddr, c.masterName) c.sentinelAddrs = append(c.sentinelAddrs, sentinelAddr) } @@ -479,7 +479,7 @@ func (c *sentinelFailover) listen(pubsub *PubSub) { if msg.Channel == "+switch-master" { parts := strings.Split(msg.Payload, " ") if parts[0] != c.masterName { - internal.Logf("sentinel: ignore addr for master=%q", parts[0]) + internal.Logger.Printf("sentinel: ignore addr for master=%q", parts[0]) continue } addr := net.JoinHostPort(parts[3], parts[4])