1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-28 06:42:00 +03:00

Add support for connection initialisation.

This commit is contained in:
Vladimir Mihailenco
2012-08-06 11:33:49 +03:00
parent 4e6fa48b48
commit c5c8ec6b0c
4 changed files with 116 additions and 20 deletions

View File

@ -28,6 +28,10 @@ func TLSConnector(addr string, tlsConfig *tls.Config) OpenConnFunc {
}
func AuthSelectFunc(password string, db int64) InitConnFunc {
if password == "" && db < 0 {
return nil
}
return func(client *Client) error {
if password != "" {
_, err := client.Auth(password).Reply()
@ -36,9 +40,11 @@ func AuthSelectFunc(password string, db int64) InitConnFunc {
}
}
_, err := client.Select(db).Reply()
if err != nil {
return err
if db >= 0 {
_, err := client.Select(db).Reply()
if err != nil {
return err
}
}
return nil
@ -51,7 +57,7 @@ func createReader() (*bufreader.Reader, error) {
type Client struct {
mtx sync.Mutex
ConnPool *ConnPool
ConnPool ConnPool
InitConn InitConnFunc
reqs []Req
@ -59,7 +65,7 @@ type Client struct {
func NewClient(openConn OpenConnFunc, closeConn CloseConnFunc, initConn InitConnFunc) *Client {
return &Client{
ConnPool: NewConnPool(openConn, closeConn, 10),
ConnPool: NewMultiConnPool(openConn, closeConn, 10),
InitConn: initConn,
}
}
@ -76,6 +82,23 @@ func NewTLSClient(addr string, tlsConfig *tls.Config, password string, db int64)
)
}
func (c *Client) conn() (*Conn, error) {
conn, isNew, err := c.ConnPool.Get()
if err != nil {
return nil, err
}
if isNew && c.InitConn != nil {
client := &Client{
ConnPool: NewOneConnPool(conn),
}
err = c.InitConn(client)
if err != nil {
return nil, err
}
}
return conn, nil
}
func (c *Client) WriteReq(buf []byte, conn *Conn) error {
_, err := conn.RW.Write(buf)
return err
@ -120,7 +143,7 @@ func (c *Client) Queue(req Req) {
}
func (c *Client) Run(req Req) {
conn, _, err := c.ConnPool.Get()
conn, err := c.conn()
if err != nil {
req.SetErr(err)
return
@ -154,7 +177,7 @@ func (c *Client) RunQueued() ([]Req, error) {
c.reqs = make([]Req, 0)
c.mtx.Unlock()
conn, _, err := c.ConnPool.Get()
conn, err := c.conn()
if err != nil {
return nil, err
}
@ -223,7 +246,7 @@ func (c *Client) Exec() ([]Req, error) {
c.reqs = make([]Req, 0)
c.mtx.Unlock()
conn, _, err := c.ConnPool.Get()
conn, err := c.conn()
if err != nil {
return nil, err
}