From a519918aaa71edf7138135a8ca4f8f49fb968fd9 Mon Sep 17 00:00:00 2001 From: "Pascal S. de Kloe" Date: Mon, 8 Feb 2021 20:55:50 +0100 Subject: [PATCH] Support messages beyond the read buffer capacity. --- client.go | 112 +++++++++++++++++++++++++------- cmd/mqttc/main.go | 37 +++++++---- example_test.go | 5 +- integration/integration_test.go | 1 - request_test.go | 20 ++++++ 5 files changed, 139 insertions(+), 36 deletions(-) diff --git a/client.go b/client.go index 8d638c4..ec51918 100644 --- a/client.go +++ b/client.go @@ -13,6 +13,11 @@ import ( "time" ) +// ReadBufSize covers inbound packet reception. BigMessage still uses the buffer +// to parse everything up until the message payload, which makes a worst-case of +// 2 B size prefix + 64 KiB topic + 2 B packet identifier. +var readBufSize = 128 * 1024 + // ErrDown signals no-service after a failed connect attempt. // The error state will clear once a connect retry succeeds. var ErrDown = errors.New("mqtt: connection unavailable") @@ -71,15 +76,11 @@ func NewTLSDialer(network, address string, config *tls.Config) Dialer { } // Config is a Client configuration. Dialer and Store are the only required -// fields, although a specific BufSize and a non-zero WireTimeout comes highly -// recommended. +// fields, although the WireTimeout comes highly recommended. type Config struct { Dialer // chooses the broker Store // persists the session - // BufSize defines the read buffer capacity, which goes up to 256 MiB. - BufSize int - // WireTimeout sets the minimim transfer rate as one byte per duration. // Zero disables timeout protection, which leaves the Client vulnerable // to blocking on stale connections. @@ -259,6 +260,8 @@ type Client struct { // The read routine uses this reusable buffer for packet submission. pendingAck [4]byte + // The read routine parks reception beyond readBufSize. + bigMessage *BigMessage } // NewClient returns a new Client. Configuration errors result in IsDeny on @@ -625,14 +628,14 @@ func (c *Client) peekPacket() (head byte, err error) { } lastN := len(c.peek) - c.peek, err = c.r.Peek(int(size)) - if err == nil { // OK - return head, nil + c.peek, err = c.r.Peek(size) + switch { + case err == nil: // OK + return head, err + case head>>4 == typePUBLISH && errors.Is(err, bufio.ErrBufferFull): + return head, &BigMessage{Client: c, Size: size} } - // TODO(pascaldekloe): if errors.Is(err, bufio.ErrBufferFull) { - // return head, BigPacketError{c, io.MultiReader(bytes.NewReader(c.peek), io.LimitReader(c.r, l-len(c.peek)))} - // Allow deadline expiry if at least one byte was transferred. var ne net.Error if len(c.peek) > lastN && errors.As(err, &ne) && ne.Timeout() { @@ -706,7 +709,7 @@ func (c *Client) handshake(conn net.Conn, requestPacket []byte) (*bufio.Reader, return nil, err } - r := bufio.NewReaderSize(conn, c.BufSize) + r := bufio.NewReaderSize(conn, readBufSize) // Apply the deadline to the "entire" 4-byte response. if c.WireTimeout != 0 { @@ -747,9 +750,18 @@ func (c *Client) handshake(conn net.Conn, requestPacket []byte) (*bufio.Reader, // Alternatively, use either Disconnect or Close to prevent a confirmation from // being send. func (c *Client) ReadSlices() (message, topic []byte, err error) { - // skip previous packet, if any - c.r.Discard(len(c.peek)) // no errors guaranteed - c.peek = nil // flush + if len(c.peek) != 0 { + // skip previous packet, if any + c.r.Discard(len(c.peek)) // no errors guaranteed + c.peek = nil // flush + } + if c.bigMessage != nil { + _, err = c.r.Discard(c.bigMessage.Size) + if err != nil { + c.block() + return nil, nil, err + } + } if c.readConn == nil { if err := c.connect(); err != nil { @@ -780,18 +792,40 @@ func (c *Client) ReadSlices() (message, topic []byte, err error) { // process packets until a PUBLISH appears for { + var bigp *BigMessage head, err := c.peekPacket() - if err != nil { - if errors.Is(err, net.ErrClosed) || errors.Is(err, io.ErrClosedPipe) { - // got interrupted - if err := c.connect(); err != nil { - c.readConn = nil - return nil, nil, err - } + switch { + case err == nil: + break - continue // with new connection + case errors.Is(err, net.ErrClosed) || errors.Is(err, io.ErrClosedPipe): + // got interrupted + if err := c.connect(); err != nil { + c.readConn = nil + return nil, nil, err } + continue // with new connection + + case errors.As(err, &bigp): + if head>>4 == typePUBLISH { + message, topic, err = c.onPUBLISH(head) + if err != nil { + // If the packet is malformed then + // bigp is not the issue anymore. + c.block() + return nil, nil, err + } + bigp.Topic = string(topic) // copy + done := readBufSize - len(message) + bigp.Size -= done + c.r.Discard(done) // no errors guaranteed + } + c.peek = nil + c.bigMessage = bigp + return nil, nil, bigp + + default: c.block() return nil, nil, err } @@ -861,6 +895,38 @@ func (c *Client) block() { c.readConn = nil } +// BigMessage signals reception beyond the read buffer capacity. +// Receivers may or may not allocate the memory with ReadAll. +// The next ReadSlices will acknowledge reception either way. +type BigMessage struct { + *Client // source + Topic string // destinition + Size int // byte count +} + +// Error implements the standard error interface. +func (e *BigMessage) Error() string { + return fmt.Sprintf("mqtt: %d B message exceeds read buffer capacity", e.Size) +} + +// ReadAll returns the message in a new/dedicated buffer. Messages can be read +// only once, after reception (from ReadSlices), and before the next ReadSlices. +// The invocation must occur from within the read-routine. +func (e *BigMessage) ReadAll() ([]byte, error) { + if e.bigMessage != e { + return nil, errors.New("mqtt: read window expired for a big message") + } + e.bigMessage = nil + + message := make([]byte, e.Size) + _, err := io.ReadFull(e.Client.r, message) + if err != nil { + e.Client.block() + return nil, err + } + return message, nil +} + var errDupe = errors.New("mqtt: duplicate reception") // OnPUBLISH slices an inbound message from Client.peek. diff --git a/cmd/mqttc/main.go b/cmd/mqttc/main.go index d37f24a..dd91fd5 100644 --- a/cmd/mqttc/main.go +++ b/cmd/mqttc/main.go @@ -218,19 +218,11 @@ func main() { // read routine for { + var big *mqtt.BigMessage message, topic, err := client.ReadSlices() switch { case err == nil: - switch { - case *topicFlag && *quoteFlag: - fmt.Printf("%q%s%q%s", topic, *prefixFlag, message, *suffixFlag) - case *topicFlag: - fmt.Printf("%s%s%s%s", topic, *prefixFlag, message, *suffixFlag) - case *quoteFlag: - fmt.Printf("%s%q%s", *prefixFlag, message, *suffixFlag) - default: - fmt.Printf("%s%s%s", *prefixFlag, message, *suffixFlag) - } + printMessage(message, topic) case errors.Is(err, mqtt.ErrClosed): os.Exit(<-exitStatus) @@ -238,6 +230,15 @@ func main() { case mqtt.IsDeny(err): // illegal configuration log.Fatal(err) + case errors.As(err, &big): + message, err := big.ReadAll() + if err != nil { + log.Print(err) + exit(1) + return + } + printMessage(message, big.Topic) + default: log.Print(err) @@ -254,11 +255,25 @@ func main() { os.Exit(9) } - go exit(1) + exit(1) + return } } } +func printMessage(message, topic interface{}) { + switch { + case *topicFlag && *quoteFlag: + fmt.Printf("%q%s%q%s", topic, *prefixFlag, message, *suffixFlag) + case *topicFlag: + fmt.Printf("%s%s%s%s", topic, *prefixFlag, message, *suffixFlag) + case *quoteFlag: + fmt.Printf("%s%q%s", *prefixFlag, message, *suffixFlag) + default: + fmt.Printf("%s%s%s", *prefixFlag, message, *suffixFlag) + } +} + func applySignals() { signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) diff --git a/example_test.go b/example_test.go index c28741e..4a1fe05 100644 --- a/example_test.go +++ b/example_test.go @@ -45,11 +45,11 @@ func ExampleNewClient_setup() { Dialer: mqtt.NewDialer("tcp", "localhost:1883"), Store: mqtt.NewVolatileStore("demo-client"), WireTimeout: time.Second, - BufSize: 8192, }) // launch read-routine go func() { + var big *mqtt.BigMessage for { message, channel, err := client.ReadSlices() switch { @@ -63,6 +63,9 @@ func ExampleNewClient_setup() { case mqtt.IsDeny(err): log.Fatal(err) // faulty configuration + case errors.As(err, &big): + log.Printf("%d byte content skipped", big.Size) + case mqtt.IsConnectionRefused(err): log.Print(err) // ErrDown for a while diff --git a/integration/integration_test.go b/integration/integration_test.go index afb72c2..7aaa7ff 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -49,7 +49,6 @@ func race(t *testing.T, host string, deliveryLevel int) { client := mqtt.NewClient(&mqtt.Config{ Dialer: mqtt.NewDialer("tcp", net.JoinHostPort(host, "1883")), WireTimeout: time.Second, - BufSize: 1024, Store: mqtt.NewVolatileStore(t.Name()), CleanSession: true, AtLeastOnceMax: testN, diff --git a/request_test.go b/request_test.go index 8d1dab0..7dcf857 100644 --- a/request_test.go +++ b/request_test.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" "net" + "strings" "testing" "time" ) @@ -70,6 +71,10 @@ func newClientPipe(t *testing.T, want ...reception) (*Client, net.Conn) { for { message, topic, err := client.ReadSlices() + if big := (*BigMessage)(nil); errors.As(err, &big) { + topic = []byte(big.Topic) + message, err = big.ReadAll() + } switch { case err == nil: switch { @@ -247,6 +252,21 @@ func TestReceivePublishExactlyOnce(t *testing.T) { wantPacketHex(t, conn, "7002abcd") // PUBCOMP } +func TestReceivePublishAtLeastOnceBig(t *testing.T) { + const bigN = 256 * 1024 + if bigN <= readBufSize { + t.Fatal("test sample does not exceed the read buffer") + } + + _, conn := newClientPipe(t, reception{Message: strings.Repeat("A", bigN), Topic: "bam"}) + + sendPacketHex(t, conn, hex.EncodeToString([]byte{ + 0x32, 0x87, 0x80, 0x10, + 0, 3, 'b', 'a', 'm', + 0xab, 0xcd})+strings.Repeat("41", bigN)) + wantPacketHex(t, conn, "4002abcd") // PUBACK +} + func testAckErrors(t *testing.T, ack <-chan error, want ...error) { timeout := time.NewTimer(time.Second) defer timeout.Stop()