From d5046be0f5f99325aa4733ce4d30e3d4a9fcd5f1 Mon Sep 17 00:00:00 2001 From: "Pascal S. de Kloe" Date: Thu, 18 Feb 2021 23:06:09 +0100 Subject: [PATCH] FIX: mqtttest signature mismatch with Publish. --- mqtttest/mqtttest.go | 28 +++++++++++++++++++++++----- mqtttest/mqtttest_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) create mode 100644 mqtttest/mqtttest_test.go diff --git a/mqtttest/mqtttest.go b/mqtttest/mqtttest.go index 51b1fa8..7c35b06 100644 --- a/mqtttest/mqtttest.go +++ b/mqtttest/mqtttest.go @@ -13,9 +13,14 @@ import ( // NewPublishStub returns a new stub for mqtt.Client Publish with a fixed return // value. -func NewPublishStub(returnFix error) func(message []byte, topic string) error { - return func(message []byte, topic string) error { - return returnFix +func NewPublishStub(returnFix error) func(quit <-chan struct{}, message []byte, topic string) error { + return func(quit <-chan struct{}, message []byte, topic string) error { + select { + case <-quit: + return mqtt.ErrCanceled + default: + return returnFix + } } } @@ -33,6 +38,12 @@ func NewReadSlicesMock(t testing.TB, want ...Transfer) func() (message, topic [] var wantIndex uint64 + t.Cleanup(func() { + if n := uint64(len(want)) - atomic.LoadUint64(&wantIndex); n > 0 { + t.Errorf("want %d more MQTT ReadSlices", n) + } + }) + return func() (message, topic []byte, err error) { i := atomic.AddUint64(&wantIndex, 1) - 1 if i >= uint64(len(want)) { @@ -51,7 +62,7 @@ func NewReadSlicesMock(t testing.TB, want ...Transfer) func() (message, topic [] // NewPublishMock returns a new mock for mqtt.Client Publish, which compares the // invocation with want in order of appearance. -func NewPublishMock(t testing.TB, want ...Transfer) func(message []byte, topic string) error { +func NewPublishMock(t testing.TB, want ...Transfer) func(quit <-chan struct{}, message []byte, topic string) error { t.Helper() var wantIndex uint64 @@ -62,7 +73,14 @@ func NewPublishMock(t testing.TB, want ...Transfer) func(message []byte, topic s } }) - return func(message []byte, topic string) error { + return func(quit <-chan struct{}, message []byte, topic string) error { + select { + case <-quit: + return mqtt.ErrCanceled + default: + break + } + i := atomic.AddUint64(&wantIndex, 1) - 1 if i >= uint64(len(want)) { t.Errorf("unwanted MQTT publish of %#x to %q", message, topic) diff --git a/mqtttest/mqtttest_test.go b/mqtttest/mqtttest_test.go new file mode 100644 index 0000000..bdaf811 --- /dev/null +++ b/mqtttest/mqtttest_test.go @@ -0,0 +1,37 @@ +package mqtttest_test + +import ( + "testing" + + "github.com/pascaldekloe/mqtt" + "github.com/pascaldekloe/mqtt/mqtttest" +) + +// Signatures +var ( + client mqtt.Client + subscribe = client.Subscribe + unsubscribe = client.Unsubscribe + publish = client.Publish + publishAck = client.PublishAtLeastOnce + readSlices = client.ReadSlices +) + +// Won't compile on failure. +func TestSignatureMatch(t *testing.T) { + var c mqtt.Client + // check dupe assumptions + subscribe = c.SubscribeLimitAtMostOnce + subscribe = c.SubscribeLimitAtLeastOnce + publishAck = c.PublishExactlyOnce + + // check fits + readSlices = mqtttest.NewReadSlicesMock(t) + publish = mqtttest.NewPublishMock(t) + publish = mqtttest.NewPublishStub(nil) + publishAck = mqtttest.NewPublishAckStub(nil) + subscribe = mqtttest.NewSubscribeMock(t) + subscribe = mqtttest.NewSubscribeStub(nil) + unsubscribe = mqtttest.NewUnsubscribeMock(t) + unsubscribe = mqtttest.NewUnsubscribeStub(nil) +}