diff --git a/example_test.go b/example_test.go index 412d5c6..7bef980 100644 --- a/example_test.go +++ b/example_test.go @@ -3,15 +3,14 @@ package mqtt_test import ( "context" "errors" - "io/ioutil" - "log" - "net" + "fmt" "os" "os/signal" "syscall" "time" "github.com/pascaldekloe/mqtt" + "github.com/pascaldekloe/mqtt/mqtttest" ) // Publish is a method from mqtt.Client. @@ -24,27 +23,16 @@ var PublishAtLeastOnce func(message []byte, topic string) (ack <-chan error, err var Subscribe func(quit <-chan struct{}, topicFilters ...string) error func init() { - // The log lines serve as example explanation only. - log.SetOutput(ioutil.Discard) - - c := mqtt.NewClient(new(mqtt.Config), func(context.Context) (net.Conn, error) { - return nil, errors.New("won't dial for demo client") - }) - err := c.VolatileSession("demo-client") - if err != nil { - panic(err) - } - c.Close() - - PublishAtLeastOnce = c.PublishAtLeastOnce - Subscribe = c.Subscribe + PublishAtLeastOnce = mqtttest.NewPublishAckStub(mqtt.ErrClosed) + Subscribe = mqtttest.NewSubscribeStub(mqtt.ErrClosed) } // It is good practice to install the client from main. func ExampleNewClient_setup() { client := mqtt.NewClient(&mqtt.Config{WireTimeout: time.Second}, mqtt.NewDialer("tcp", "localhost:1883")) if err := client.VolatileSession("demo-client"); err != nil { - log.Fatal(err) + fmt.Print("exit on ", err) + os.Exit(2) } // launch read-routine @@ -55,24 +43,25 @@ func ExampleNewClient_setup() { switch { case err == nil: // do something with inbound message - log.Printf("📥 %q: %q", channel, message) + fmt.Printf("📥 %q: %q", channel, message) case errors.Is(err, mqtt.ErrClosed): return // terminated case mqtt.IsDeny(err): - log.Fatal(err) // faulty configuration + fmt.Print("unusable configuration: ", err) + os.Exit(2) case errors.As(err, &big): - log.Printf("%d byte content skipped", big.Size) + fmt.Printf("%d byte content skipped", big.Size) case mqtt.IsConnectionRefused(err): - log.Print(err) + fmt.Print(err) // ErrDown for a while time.Sleep(5*time.Minute - time.Second) default: - log.Print("MQTT unavailable: ", err) + fmt.Print("MQTT unavailable: ", err) // ErrDown for short backoff time.Sleep(2 * time.Second) } @@ -90,17 +79,17 @@ func ExampleNewClient_setup() { for sig := range signals { switch sig { case syscall.SIGINT: - log.Print("MQTT close on SIGINT…") + fmt.Print("MQTT close on SIGINT…") err := client.Close() if err != nil { - log.Print(err) + fmt.Print(err) } case syscall.SIGTERM: - log.Print("MQTT disconnect on SIGTERM…") + fmt.Print("MQTT disconnect on SIGTERM…") err := client.Disconnect(nil) if err != nil { - log.Print(err) + fmt.Print(err) } } } @@ -115,36 +104,37 @@ func ExampleClient_PublishAtLeastOnce_hasty() { ack, err := PublishAtLeastOnce([]byte("🍸🆘"), "demo/alert") switch { case err == nil: - log.Print("alert submitted") + fmt.Print("alert submitted") case mqtt.IsDeny(err), errors.Is(err, mqtt.ErrClosed): - log.Print("🚨 alert not send: ", err) + fmt.Print("🚨 alert not send: ", err) return case errors.Is(err, mqtt.ErrMax), errors.Is(err, mqtt.ErrDown): - log.Print("⚠️ alert delay: ", err) + fmt.Print("⚠️ alert delay: ", err) time.Sleep(time.Second / 4) continue default: - log.Print("⚠️ alert delay on persistence malfunction: ", err) + fmt.Print("⚠️ alert delay on persistence malfunction: ", err) time.Sleep(time.Second) continue } for err := range ack { if errors.Is(err, mqtt.ErrClosed) { - log.Print("🚨 alert suspended: ", err) + fmt.Print("🚨 alert suspended: ", err) // Submission will continue when the Client // is restarted with the same Store again. return } - log.Print("⚠️ alert delay on connection malfunction: ", err) + fmt.Print("⚠️ alert delay on connection malfunction: ", err) } - log.Print("alert confirmed") + fmt.Print("alert confirmed") break } // Output: + // 🚨 alert not send: mqtt: client closed } // Error scenario and how to act uppon them. @@ -157,15 +147,15 @@ func ExampleClient_Subscribe_sticky() { err := Subscribe(ctx.Done(), topicFilter) switch { case err == nil: - log.Printf("subscribed to %q", topicFilter) + fmt.Printf("subscribed to %q", topicFilter) return case mqtt.IsDeny(err), errors.Is(err, mqtt.ErrClosed): - log.Print("no subscribe: ", err) + fmt.Print("no subscribe: ", err) return case errors.Is(err, mqtt.ErrCanceled), errors.Is(err, mqtt.ErrAbandoned): - log.Print("subscribe timeout: ", err) + fmt.Print("subscribe timeout: ", err) return case errors.Is(err, mqtt.ErrMax), errors.Is(err, mqtt.ErrDown): @@ -173,9 +163,10 @@ func ExampleClient_Subscribe_sticky() { default: backoff := 4 * time.Second - log.Printf("subscribe retry in %s on: %s", backoff, err) + fmt.Printf("subscribe retry in %s on: %s", backoff, err) time.Sleep(backoff) } } // Output: + // no subscribe: mqtt: client closed } diff --git a/mqtttest/mqtttest.go b/mqtttest/mqtttest.go index c2f003c..e1a47cc 100644 --- a/mqtttest/mqtttest.go +++ b/mqtttest/mqtttest.go @@ -3,11 +3,22 @@ package mqtttest import ( "bytes" + "errors" + "sync/atomic" "testing" + "time" "github.com/pascaldekloe/mqtt" ) +// NewPublishStub returns a new stub for mqtt.Client Publish with a fixed return +// value. +func NewPublishStub(returnFix error) func(message []byte, topic string) error { + return func(message []byte, topic string) error { + return returnFix + } +} + // Transfer defines a message exchange. type Transfer struct { Message []byte // payload @@ -15,24 +26,26 @@ type Transfer struct { Err error // result } -// PublishMock returns a new mock for mqtt.Client Publish, which compares the +// NewPublishMock returns a new mock for mqtt.Client Publish, which compares the // invocation with want in order of appearance. -func PublishMock(t *testing.T, want ...Transfer) func(message []byte, topic string) error { - var i int +func NewPublishMock(t testing.TB, want ...Transfer) func(message []byte, topic string) error { + t.Helper() + + var wantIndex uint64 t.Cleanup(func() { - if n := len(want) - i; n > 0 { + if n := uint64(len(want)) - atomic.LoadUint64(&wantIndex); n > 0 { t.Errorf("want %d more MQTT publishes", n) } }) return func(message []byte, topic string) error { - if i >= len(want) { + i := atomic.AddUint64(&wantIndex, 1) - 1 + if i >= uint64(len(want)) { t.Errorf("unwanted MQTT publish of %#x to %q", message, topic) return nil } transfer := want[i] - i++ if !bytes.Equal(message, transfer.Message) && topic != transfer.Topic { t.Errorf("got MQTT publish of %#x to %q, want %#x to %q", message, topic, transfer.Message, transfer.Topic) @@ -41,27 +54,133 @@ func PublishMock(t *testing.T, want ...Transfer) func(message []byte, topic stri } } -// Filter defines a (un)subscription exchange. +// AckBlock prevents ack <-chan error submission. +type AckBlock struct { + Delay time.Duration // zero defaults to indefinite +} + +// Error implements the standard error interface. +func (b AckBlock) Error() string { + return "mqtttest: AckBlock used as an error" +} + +// NewPublishAckStub returns a stub for mqtt.Client PublishAtLeastOnce or +// PublishExactlyOnce with a fixed return value. +// +// The ackFix errors are applied to the ack return, with an option for AckBlock +// entries. An mqtt.ErrClosed in the ackFix keeps the ack channel open (without +// an extra AckBlock entry. +func NewPublishAckStub(errFix error, ackFix ...error) func(message []byte, topic string) (ack <-chan error, err error) { + if errFix != nil && len(ackFix) != 0 { + panic("ackFix entries with non-nil errFix") + } + var block AckBlock + for i, err := range ackFix { + switch { + case err == nil: + panic("nil entry in ackFix") + case errors.Is(err, mqtt.ErrClosed): + if i+1 < len(ackFix) { + panic("followup of mqtt.ErrClosed ackFix entry") + } + case errors.As(err, &block): + if block.Delay == 0 && i+1 < len(ackFix) { + panic("followup of indefinite AckBlock ackFix entry") + } + } + } + + return func(message []byte, topic string) (ack <-chan error, err error) { + if errFix != nil { + return nil, errFix + } + + ch := make(chan error, len(ackFix)) + go func() { + var block AckBlock + for _, err := range ackFix { + switch { + default: + ch <- err + case errors.Is(err, mqtt.ErrClosed): + ch <- err + return // without close + case errors.As(err, &block): + if block.Delay == 0 { + return // without close + } + time.Sleep(block.Delay) + } + } + close(ch) + }() + return ch, nil + } +} + +// NewSubscribeStub returns a stub for mqtt.Client Subscribe with a fixed return +// value. +func NewSubscribeStub(returnFix error) func(quit <-chan struct{}, topicFilters ...string) error { + return newSubscribeStub("subscribe", returnFix) +} + +// NewUnsubscribeStub returns a stub for mqtt.Client Unsubscribe with a fixed +// return value. +func NewUnsubscribeStub(returnFix error) func(quit <-chan struct{}, topicFilters ...string) error { + return newSubscribeStub("unsubscribe", returnFix) +} + +func newSubscribeStub(name string, returnFix error) func(quit <-chan struct{}, topicFilters ...string) error { + return func(quit <-chan struct{}, topicFilters ...string) error { + if len(topicFilters) == 0 { + // TODO(pascaldekloe): move validation to internal + // package and then return appropriate errors here. + panic("MQTT " + name + " without topic filters") + } + select { + case <-quit: + return mqtt.ErrCanceled + default: + break + } + return returnFix + } +} + +// Filter defines a subscription exchange. type Filter struct { Topics []string // order is ignored Err error // result } -// SubscribeMock returns a new mock for mqtt.Client Subscribe, -// SubscribeLimitAtMostOnce and SubscribeLimitAtLeastOnce, which compares the -// invocation with want in order of appearece. -func SubscribeMock(t *testing.T, want ...Filter) func(quit <-chan struct{}, topicFilters ...string) error { - var i int +// NewSubscribeMock returns a new mock for mqtt.Client Subscribe, which compares +// the invocation with want in order of appearece. +func SubscribeMock(t testing.TB, want ...Filter) func(quit <-chan struct{}, topicFilters ...string) error { + t.Helper() + return newSubscribeMock("subscribe", t, want...) +} + +// NewUnsubscribeMock returns a new mock for mqtt.Client Unsubscribe, which +// compares the invocation with want in order of appearece. +func NewUnsubscribeMock(t testing.TB, want ...Filter) func(quit <-chan struct{}, topicFilters ...string) error { + t.Helper() + return newSubscribeMock("unsubscribe", t, want...) +} + +func newSubscribeMock(name string, t testing.TB, want ...Filter) func(quit <-chan struct{}, topicFilters ...string) error { + t.Helper() + + var wantIndex uint64 t.Cleanup(func() { - if i < len(want) { - t.Errorf("want %d more MQTT subscribes", len(want)-i) + if n := uint64(len(want)) - atomic.LoadUint64(&wantIndex); n > 0 { + t.Errorf("want %d more MQTT %ss", n, name) } }) return func(quit <-chan struct{}, topicFilters ...string) error { if len(topicFilters) == 0 { - t.Fatal("MQTT subscribe without topic filters") + t.Fatalf("MQTT %s without topic filters", name) } select { case <-quit: @@ -70,11 +189,11 @@ func SubscribeMock(t *testing.T, want ...Filter) func(quit <-chan struct{}, topi break } - if i >= len(want) { - t.Errorf("unwanted MQTT subscribe of %q", topicFilters) + i := atomic.AddUint64(&wantIndex, 1) - 1 + if i >= uint64(len(want)) { + t.Errorf("unwanted MQTT %s of %q", name, topicFilters) } filter := want[i] - i++ todo := make(map[string]struct{}, len(filter.Topics)) for _, topic := range filter.Topics { @@ -89,14 +208,14 @@ func SubscribeMock(t *testing.T, want ...Filter) func(quit <-chan struct{}, topi } } if len(wrong) != 0 { - t.Errorf("unwanted MQTT subscribe of %q (out of %q)", wrong, filter.Topics) + t.Errorf("unwanted MQTT %s of %q (out of %q)", name, wrong, filter.Topics) } if len(todo) != 0 { var miss []string for filter := range todo { miss = append(miss, filter) } - t.Errorf("no MQTT subscribe of %q (out of %q)", miss, filter.Topics) + t.Errorf("no MQTT %s of %q (out of %q)", name, miss, filter.Topics) } return filter.Err