diff --git a/pubsub.go b/pubsub.go index c0b0a735..080ad09c 100644 --- a/pubsub.go +++ b/pubsub.go @@ -7,21 +7,22 @@ import ( type PubSubClient struct { *Client - conn *Conn ch chan *Message once sync.Once } func newPubSubClient(client *Client) (*PubSubClient, error) { - conn, _, err := client.ConnPool.Get() + pubSubConn, _, err := client.ConnPool.Get() if err != nil { return nil, err } + client.ConnPool.Remove(pubSubConn) c := &PubSubClient{ - Client: client, - conn: conn, - ch: make(chan *Message), + Client: &Client{ + ConnPool: NewOneConnPool(pubSubConn), + }, + ch: make(chan *Message), } return c, nil } @@ -34,13 +35,17 @@ type Message struct { } func (c *PubSubClient) consumeMessages() { + conn, err := c.conn() + if err != nil { + panic(err) + } req := NewMultiBulkReq() for { // Replies can arrive in batches. // Read whole reply and parse messages one by one. - err := c.ReadReply(c.conn) + err := c.ReadReply(conn) if err != nil { msg := &Message{} msg.Err = err @@ -51,7 +56,7 @@ func (c *PubSubClient) consumeMessages() { for { msg := &Message{} - replyI, err := req.ParseReply(c.conn.Rd) + replyI, err := req.ParseReply(conn.Rd) if err != nil { msg.Err = err c.ch <- msg @@ -74,7 +79,7 @@ func (c *PubSubClient) consumeMessages() { } c.ch <- msg - if !c.conn.Rd.HasUnread() { + if !conn.Rd.HasUnread() { break } } @@ -85,7 +90,12 @@ func (c *PubSubClient) Subscribe(channels ...string) (chan *Message, error) { args := append([]string{"SUBSCRIBE"}, channels...) req := NewMultiBulkReq(args...) - if err := c.WriteReq(req.Req(), c.conn); err != nil { + conn, err := c.conn() + if err != nil { + return nil, err + } + + if err := c.WriteReq(req.Req(), conn); err != nil { return nil, err } @@ -99,5 +109,11 @@ func (c *PubSubClient) Subscribe(channels ...string) (chan *Message, error) { func (c *PubSubClient) Unsubscribe(channels ...string) error { args := append([]string{"UNSUBSCRIBE"}, channels...) req := NewMultiBulkReq(args...) - return c.WriteReq(req.Req(), c.conn) + + conn, err := c.conn() + if err != nil { + return err + } + + return c.WriteReq(req.Req(), conn) } diff --git a/request.go b/request.go index 563d64f6..f44b1002 100644 --- a/request.go +++ b/request.go @@ -3,7 +3,6 @@ package redis import ( "errors" "fmt" - "io" "strconv" "github.com/vmihailenco/bufreader" @@ -34,17 +33,19 @@ func ParseReq(rd *bufreader.Reader) ([]string, error) { if err != nil { return nil, err } + if line[0] != '*' { return []string{string(line)}, nil } + numReplies, err := strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + return nil, err + } args := make([]string, 0) - for { + for i := int64(0); i < numReplies; i++ { line, err = rd.ReadLine('\n') if err != nil { - if err == io.EOF { - break - } return nil, err } if line[0] != '$' {