1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-29 17:41:15 +03:00

Fix WithContext and add tests

This commit is contained in:
Vladimir Mihailenco
2019-07-04 11:18:06 +03:00
parent 73d3c18522
commit 2cbb5194fb
14 changed files with 114 additions and 90 deletions

View File

@ -1,6 +1,7 @@
package pool_test
import (
"context"
"fmt"
"testing"
"time"
@ -39,7 +40,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(context.Background())
if err != nil {
b.Fatal(err)
}
@ -81,7 +82,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(context.Background())
if err != nil {
b.Fatal(err)
}

View File

@ -250,38 +250,38 @@ func (p *ConnPool) getTurn() {
}
func (p *ConnPool) waitTurn(ctx context.Context) error {
var done <-chan struct{}
if ctx != nil {
done = ctx.Done()
select {
case <-ctx.Done():
return ctx.Err()
default:
}
select {
case <-done:
return ctx.Err()
case p.queue <- struct{}{}:
return nil
default:
timer := timers.Get().(*time.Timer)
timer.Reset(p.opt.PoolTimeout)
}
select {
case <-done:
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return ctx.Err()
case p.queue <- struct{}{}:
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return nil
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return ErrPoolTimeout
timer := timers.Get().(*time.Timer)
timer.Reset(p.opt.PoolTimeout)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return ctx.Err()
case p.queue <- struct{}{}:
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return nil
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return ErrPoolTimeout
}
}

View File

@ -1,6 +1,7 @@
package pool_test
import (
"context"
"sync"
"testing"
"time"
@ -12,6 +13,7 @@ import (
)
var _ = Describe("ConnPool", func() {
c := context.Background()
var connPool *pool.ConnPool
BeforeEach(func() {
@ -30,13 +32,13 @@ var _ = Describe("ConnPool", func() {
It("should unblock client when conn is removed", func() {
// Reserve one connection.
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
// Reserve all other connections.
var cns []*pool.Conn
for i := 0; i < 9; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn)
}
@ -47,7 +49,7 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover()
started <- true
_, err := connPool.Get(nil)
_, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
done <- true
@ -80,6 +82,7 @@ var _ = Describe("ConnPool", func() {
})
var _ = Describe("MinIdleConns", func() {
c := context.Background()
const poolSize = 100
var minIdleConns int
var connPool *pool.ConnPool
@ -110,7 +113,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() {
var err error
cn, err = connPool.Get(nil)
cn, err = connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
Eventually(func() int {
@ -145,7 +148,7 @@ var _ = Describe("MinIdleConns", func() {
perform(poolSize, func(_ int) {
defer GinkgoRecover()
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
mu.Lock()
cns = append(cns, cn)
@ -160,7 +163,7 @@ var _ = Describe("MinIdleConns", func() {
It("Get is blocked", func() {
done := make(chan struct{})
go func() {
connPool.Get(nil)
connPool.Get(c)
close(done)
}()
@ -247,6 +250,8 @@ var _ = Describe("MinIdleConns", func() {
})
var _ = Describe("conns reaper", func() {
c := context.Background()
const idleTimeout = time.Minute
const maxAge = time.Hour
@ -274,7 +279,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections
staleConns = nil
for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
switch typ {
case "idle":
@ -288,7 +293,7 @@ var _ = Describe("conns reaper", func() {
// add fresh connections
for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn)
}
@ -333,7 +338,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ {
var freeCns []*pool.Conn
for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn)
@ -342,7 +347,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3))
Expect(connPool.IdleLen()).To(Equal(0))
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
conns = append(conns, cn)
@ -370,6 +375,7 @@ var _ = Describe("conns reaper", func() {
})
var _ = Describe("race", func() {
c := context.Background()
var connPool *pool.ConnPool
var C, N int
@ -396,7 +402,7 @@ var _ = Describe("race", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Put(cn)
@ -404,7 +410,7 @@ var _ = Describe("race", func() {
}
}, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(cn)