1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-29 17:41:15 +03:00

Add connnection pool and improve API.

This commit is contained in:
Vladimir Mihailenco
2012-08-05 15:09:43 +03:00
parent 859c5fb03b
commit 19a5db6632
7 changed files with 962 additions and 761 deletions

189
redis.go
View File

@ -1,145 +1,156 @@
package redis
import (
"crypto/tls"
"fmt"
"io"
"net"
"sync"
"github.com/vmihailenco/bufreader"
)
type connectFunc func() (io.ReadWriter, error)
type disconnectFunc func(io.ReadWriter)
type OpenConnFunc func() (io.ReadWriter, error)
type CloseConnFunc func(io.ReadWriter)
type InitConnFunc func(*Client) error
func TCPConnector(addr string) OpenConnFunc {
return func() (io.ReadWriter, error) {
return net.Dial("tcp", addr)
}
}
func TLSConnector(addr string, tlsConfig *tls.Config) OpenConnFunc {
return func() (io.ReadWriter, error) {
return tls.Dial("tcp", addr, tlsConfig)
}
}
func AuthSelectFunc(password string, db int64) InitConnFunc {
return func(client *Client) error {
if password != "" {
_, err := client.Auth(password).Reply()
if err != nil {
return err
}
}
_, err := client.Select(db).Reply()
if err != nil {
return err
}
return nil
}
}
func createReader() (*bufreader.Reader, error) {
return bufreader.NewSizedReader(8192), nil
}
type Client struct {
mtx sync.Mutex
connect connectFunc
disconnect disconnectFunc
currConn io.ReadWriter
readerPool *bufreader.ReaderPool
mtx sync.Mutex
ConnPool *ConnPool
InitConn InitConnFunc
reqs []Req
}
func NewClient(connect connectFunc, disconnect disconnectFunc) *Client {
func NewClient(openConn OpenConnFunc, closeConn CloseConnFunc, initConn InitConnFunc) *Client {
return &Client{
readerPool: bufreader.NewReaderPool(100, createReader),
connect: connect,
disconnect: disconnect,
reqs: make([]Req, 0),
ConnPool: NewConnPool(openConn, closeConn, 10),
InitConn: initConn,
}
}
func NewMultiClient(connect connectFunc, disconnect disconnectFunc) *Client {
return &Client{
readerPool: bufreader.NewReaderPool(100, createReader),
connect: connect,
disconnect: disconnect,
reqs: make([]Req, 0),
}
func NewTCPClient(addr string, password string, db int64) *Client {
return NewClient(TCPConnector(addr), nil, AuthSelectFunc(password, db))
}
func (c *Client) Close() error {
if c.disconnect != nil {
c.disconnect(c.currConn)
}
c.currConn = nil
return nil
func NewTLSClient(addr string, tlsConfig *tls.Config, password string, db int64) *Client {
return NewClient(
TLSConnector(addr, tlsConfig),
nil,
AuthSelectFunc(password, db),
)
}
func (c *Client) conn() (io.ReadWriter, error) {
if c.currConn == nil {
currConn, err := c.connect()
if err != nil {
return nil, err
}
c.currConn = currConn
}
return c.currConn, nil
}
func (c *Client) WriteReq(buf []byte) error {
conn, err := c.conn()
if err != nil {
return err
}
_, err = conn.Write(buf)
if err != nil {
c.Close()
}
func (c *Client) WriteReq(buf []byte, conn *Conn) error {
_, err := conn.RW.Write(buf)
return err
}
func (c *Client) ReadReply(rd *bufreader.Reader) error {
conn, err := c.conn()
func (c *Client) ReadReply(conn *Conn) error {
_, err := conn.Rd.ReadFrom(conn.RW)
if err != nil {
return err
}
_, err = rd.ReadFrom(conn)
if err != nil {
c.Close()
return err
}
return nil
}
func (c *Client) WriteRead(buf []byte, rd *bufreader.Reader) error {
func (c *Client) WriteRead(buf []byte, conn *Conn) error {
c.mtx.Lock()
defer c.mtx.Unlock()
if err := c.WriteReq(buf); err != nil {
if err := c.WriteReq(buf, conn); err != nil {
return err
}
return c.ReadReply(rd)
return c.ReadReply(conn)
}
func (c *Client) Process(req Req) {
if c.reqs == nil {
c.Run(req)
} else {
c.Queue(req)
}
}
func (c *Client) Queue(req Req) {
req.SetClient(c)
c.mtx.Lock()
c.reqs = append(c.reqs, req)
c.mtx.Unlock()
}
func (c *Client) Run(req Req) {
rd, err := c.readerPool.Get()
if err != nil {
req.SetErr(err)
return
}
defer c.readerPool.Add(rd)
err = c.WriteRead(req.Req(), rd)
conn, _, err := c.ConnPool.Get()
if err != nil {
c.ConnPool.Remove(conn)
req.SetErr(err)
return
}
val, err := req.ParseReply(rd)
err = c.WriteRead(req.Req(), conn)
if err != nil {
c.ConnPool.Remove(conn)
req.SetErr(err)
return
}
val, err := req.ParseReply(conn.Rd)
if err != nil {
c.ConnPool.Remove(conn)
req.SetErr(err)
return
}
c.ConnPool.Add(conn)
req.SetVal(val)
}
func (c *Client) RunQueued() ([]Req, error) {
c.mtx.Lock()
if len(c.reqs) == 0 {
c.mtx.Unlock()
return c.reqs, nil
}
c.mtx.Lock()
reqs := c.reqs
c.reqs = make([]Req, 0)
c.mtx.Unlock()
return c.RunReqs(reqs)
}
func (c *Client) RunReqs(reqs []Req) ([]Req, error) {
var multiReq []byte
if len(reqs) == 1 {
multiReq = reqs[0].Req()
@ -150,18 +161,20 @@ func (c *Client) RunQueued() ([]Req, error) {
}
}
rd, err := c.readerPool.Get()
conn, _, err := c.ConnPool.Get()
if err != nil {
return nil, err
}
defer c.readerPool.Add(rd)
err = c.WriteRead(multiReq, rd)
err = c.WriteRead(multiReq, conn)
if err != nil {
return nil, err
}
for _, req := range reqs {
val, err := req.ParseReply(rd)
for i := 0; i < len(reqs); i++ {
req := reqs[i]
val, err := req.ParseReply(conn.Rd)
if err != nil {
req.SetErr(err)
} else {
@ -181,11 +194,11 @@ func (c *Client) Discard() {
}
func (c *Client) Exec() ([]Req, error) {
c.mtx.Lock()
if len(c.reqs) == 0 {
c.mtx.Unlock()
return c.reqs, nil
}
c.mtx.Lock()
reqs := c.reqs
c.reqs = make([]Req, 0)
c.mtx.Unlock()
@ -197,13 +210,12 @@ func (c *Client) Exec() ([]Req, error) {
}
multiReq = append(multiReq, PackReq([]string{"EXEC"})...)
rd, err := c.readerPool.Get()
conn, _, err := c.ConnPool.Get()
if err != nil {
return nil, err
}
defer c.readerPool.Add(rd)
err = c.WriteRead(multiReq, rd)
err = c.WriteRead(multiReq, conn)
if err != nil {
return nil, err
}
@ -211,31 +223,32 @@ func (c *Client) Exec() ([]Req, error) {
statusReq := NewStatusReq()
// Parse MULTI command reply.
_, err = statusReq.ParseReply(rd)
_, err = statusReq.ParseReply(conn.Rd)
if err != nil {
return nil, err
}
// Parse queued replies.
for _ = range reqs {
_, err = statusReq.ParseReply(rd)
_, err = statusReq.ParseReply(conn.Rd)
if err != nil {
return nil, err
}
}
// Parse number of replies.
line, err := rd.ReadLine('\n')
line, err := conn.Rd.ReadLine('\n')
if err != nil {
return nil, err
}
if line[0] != '*' {
return nil, fmt.Errorf("Expected '*', but got line %q of %q.", line, rd.Bytes())
return nil, fmt.Errorf("Expected '*', but got line %q of %q.", line, conn.Rd.Bytes())
}
// Parse replies.
for _, req := range reqs {
val, err := req.ParseReply(rd)
for i := 0; i < len(reqs); i++ {
req := reqs[i]
val, err := req.ParseReply(conn.Rd)
if err != nil {
req.SetErr(err)
} else {