1
0
mirror of https://github.com/go-mqtt/mqtt.git synced 2025-07-31 04:04:21 +03:00
Files
mqtt/mqtttest/mqtttest.go
Pascal S. de Kloe daf7cf26cc Rebranding.
2021-06-27 18:11:10 +02:00

280 lines
7.4 KiB
Go

// Package mqtttest provides utilities for MQTT testing.
package mqtttest
import (
"bytes"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/go-mqtt/mqtt"
)
// Transfer defines a message exchange.
type Transfer struct {
	Message []byte // payload
	Topic   string // destination
	Err     error  // result
}

// NewReadSlicesStub returns a new stub for mqtt.Client ReadSlices with a fixed
// return value.
func NewReadSlicesStub(fix Transfer) func() (message, topic []byte, err error) {
	return func() (message, topic []byte, err error) {
		// Hand out fresh copies on each call, such that callers which
		// modify the returned slices can not affect any later reads.
		// The message is never nil, not even for an empty fix.Message.
		message = make([]byte, len(fix.Message))
		copy(message, fix.Message)
		return message, []byte(fix.Topic), fix.Err
	}
}

// NewReadSlicesMock returns a new mock for mqtt.Client ReadSlices, which
// returns the Transfers in order of appearance. Invocations beyond the last
// Transfer fail the test, and so does a test which ends before all Transfers
// were read.
func NewReadSlicesMock(t testing.TB, want ...Transfer) func() (message, topic []byte, err error) {
	t.Helper()
	var wantIndex uint64 // number of invocations; may exceed len(want)
	t.Cleanup(func() {
		t.Helper()
		// Unwanted invocations push wantIndex beyond len(want). They
		// were reported already, so guard against uint64 wrap-around.
		if n := atomic.LoadUint64(&wantIndex); n < uint64(len(want)) {
			t.Errorf("want %d more MQTT ReadSlices", uint64(len(want))-n)
		}
	})
	return func() (message, topic []byte, err error) {
		t.Helper()
		i := atomic.AddUint64(&wantIndex, 1) - 1
		if i >= uint64(len(want)) {
			err = errors.New("unwanted MQTT ReadSlices")
			t.Error(err)
			return nil, nil, err
		}
		return NewReadSlicesStub(want[i])()
	}
}
// NewPublishMock returns a new mock for mqtt.Client Publish, which compares the
// invocation with want in order of appearance. A payload or a topic which does
// not match the Transfer due fails the test, as do invocations beyond the last
// Transfer, and a test which ends before all Transfers were published.
func NewPublishMock(t testing.TB, want ...Transfer) func(quit <-chan struct{}, message []byte, topic string) error {
	t.Helper()
	var wantIndex uint64 // number of invocations; may exceed len(want)
	t.Cleanup(func() {
		t.Helper()
		// Unwanted invocations push wantIndex beyond len(want). They
		// were reported already, so guard against uint64 wrap-around.
		if n := atomic.LoadUint64(&wantIndex); n < uint64(len(want)) {
			t.Errorf("want %d more MQTT publishes", uint64(len(want))-n)
		}
	})
	return func(quit <-chan struct{}, message []byte, topic string) error {
		t.Helper()
		select {
		case <-quit:
			return mqtt.ErrCanceled
		default:
		}
		i := atomic.AddUint64(&wantIndex, 1) - 1
		if i >= uint64(len(want)) {
			t.Errorf("unwanted MQTT publish of %#x to %q", message, topic)
			return nil
		}
		transfer := want[i]
		// Either a payload mismatch or a destination mismatch is an error.
		if !bytes.Equal(message, transfer.Message) || topic != transfer.Topic {
			t.Errorf("got MQTT publish of %#x to %q, want %#x to %q", message, topic, transfer.Message, transfer.Topic)
		}
		return transfer.Err
	}
}
// NewPublishStub returns a new stub for mqtt.Client Publish with a fixed return
// value. Each invocation returns fix, except when a receive from quit does not
// block (e.g., quit is closed), in which case it returns mqtt.ErrCanceled.
func NewPublishStub(fix error) func(quit <-chan struct{}, message []byte, topic string) error {
	return func(quit <-chan struct{}, _ []byte, _ string) error {
		// The payload and the topic are ignored by design.
		select {
		case <-quit:
			return mqtt.ErrCanceled
		default:
		}
		return fix
	}
}
// ExchangeBlock is a pseudo error for the exchangeFix entries of
// NewPublishExchangeStub. Rather than being sent on the exchange channel, it
// pauses submission for Delay, or halts submission for good with a zero Delay.
type ExchangeBlock struct {
	// Delay is the pause duration. The zero value blocks indefinitely.
	Delay time.Duration
}

// Error implements the standard error interface. The message signals misuse,
// as an ExchangeBlock is not meant to surface as an actual error.
func (ExchangeBlock) Error() string {
	const msg = "mqtttest: ExchangeBlock used as an error"
	return msg
}
// NewPublishExchangeStub returns a stub for mqtt.Client PublishAtLeastOnce or
// PublishExactlyOnce with a fixed return value.
//
// The exchangeFix errors are applied to the exchange return, with an option for
// ExchangeBlock entries. An mqtt.ErrClosed in the exchangeFix keeps the
// exchange channel open (without an extra ExchangeBlock entry).
func NewPublishExchangeStub(errFix error, exchangeFix ...error) func(message []byte, topic string) (exchange <-chan error, err error) {
if errFix != nil && len(exchangeFix) != 0 {
panic("exchangeFix entries with non-nil errFix")
}
var block ExchangeBlock
for i, err := range exchangeFix {
switch {
case err == nil:
panic("nil entry in exchangeFix")
case errors.Is(err, mqtt.ErrClosed):
if i+1 < len(exchangeFix) {
panic("followup on mqtt.ErrClosed exchangeFix entry")
}
case errors.As(err, &block):
if block.Delay == 0 && i+1 < len(exchangeFix) {
panic("followup on indefinite ExchangeBlock exchangeFix entry")
}
}
}
return func(message []byte, topic string) (exchange <-chan error, err error) {
if errFix != nil {
return nil, errFix
}
ch := make(chan error, len(exchangeFix))
go func() {
var block ExchangeBlock
for _, err := range exchangeFix {
switch {
default:
ch <- err
case errors.Is(err, mqtt.ErrClosed):
ch <- err
return // without close
case errors.As(err, &block):
if block.Delay == 0 {
return // without close
}
time.Sleep(block.Delay)
}
}
close(ch)
}()
return ch, nil
}
}
// NewSubscribeStub returns a stub for mqtt.Client Subscribe, which answers
// with the fixed return value fix.
func NewSubscribeStub(fix error) func(quit <-chan struct{}, topicFilters ...string) error {
	stub := newSubscribeStub("subscribe", fix)
	return stub
}
// NewUnsubscribeStub returns a stub for mqtt.Client Unsubscribe, which answers
// with the fixed return value fix.
func NewUnsubscribeStub(fix error) func(quit <-chan struct{}, topicFilters ...string) error {
	stub := newSubscribeStub("unsubscribe", fix)
	return stub
}
// newSubscribeStub implements both NewSubscribeStub and NewUnsubscribeStub.
// The name identifies the operation in panic messages. Invocations without
// topic filters panic. Otherwise fix is returned, unless a receive from quit
// does not block (e.g., quit is closed), which gives mqtt.ErrCanceled.
func newSubscribeStub(name string, fix error) func(quit <-chan struct{}, topicFilters ...string) error {
	return func(quit <-chan struct{}, topicFilters ...string) error {
		if len(topicFilters) == 0 {
			// TODO(pascaldekloe): move validation to internal
			// package and then return appropriate errors here.
			panic("MQTT " + name + " without topic filters")
		}
		select {
		case <-quit:
			return mqtt.ErrCanceled
		default:
			return fix
		}
	}
}
// Filter defines a subscription exchange: the topic filters expected in one
// (un)subscribe invocation, plus the result to return from it.
type Filter struct {
	Topics []string // expected topic filters; order is ignored
	Err    error    // result
}
// NewSubscribeMock returns a new mock for mqtt.Client Subscribe, which compares
// the invocation with want in order of appearance.
func NewSubscribeMock(t testing.TB, want ...Filter) func(quit <-chan struct{}, topicFilters ...string) error {
	t.Helper()
	return newSubscribeMock("subscribe", t, want...)
}
// NewUnsubscribeMock returns a new mock for mqtt.Client Unsubscribe, which
// compares the invocation with want in order of appearance.
func NewUnsubscribeMock(t testing.TB, want ...Filter) func(quit <-chan struct{}, topicFilters ...string) error {
	t.Helper()
	return newSubscribeMock("unsubscribe", t, want...)
}
// newSubscribeMock implements both NewSubscribeMock and NewUnsubscribeMock.
// The name identifies the operation in test failure messages. Each invocation
// is compared against the next Filter in want, ignoring topic order. Unwanted
// invocations fail the test and return nil, and so does a test which ends
// before all Filters were consumed.
func newSubscribeMock(name string, t testing.TB, want ...Filter) func(quit <-chan struct{}, topicFilters ...string) error {
	t.Helper()
	var wantIndex uint64 // number of invocations; may exceed len(want)
	t.Cleanup(func() {
		t.Helper()
		// Unwanted invocations push wantIndex beyond len(want). They
		// were reported already, so guard against uint64 wrap-around.
		if n := atomic.LoadUint64(&wantIndex); n < uint64(len(want)) {
			t.Errorf("want %d more MQTT %ss", uint64(len(want))-n, name)
		}
	})
	return func(quit <-chan struct{}, topicFilters ...string) error {
		t.Helper()
		if len(topicFilters) == 0 {
			// NOTE(review): Fatalf is only valid from the goroutine
			// running the test; confirm callers invoke the mock there.
			t.Fatalf("MQTT %s without topic filters", name)
		}
		select {
		case <-quit:
			return mqtt.ErrCanceled
		default:
		}
		i := atomic.AddUint64(&wantIndex, 1) - 1
		if i >= uint64(len(want)) {
			t.Errorf("unwanted MQTT %s of %q", name, topicFilters)
			return nil // no Filter to compare against
		}
		expect := want[i]

		// todo holds the expected topic filters not seen (yet).
		todo := make(map[string]struct{}, len(expect.Topics))
		for _, topic := range expect.Topics {
			todo[topic] = struct{}{}
		}
		var wrong []string
		for _, topic := range topicFilters {
			if _, ok := todo[topic]; ok {
				delete(todo, topic)
			} else {
				wrong = append(wrong, topic)
			}
		}
		if len(wrong) != 0 {
			t.Errorf("unwanted MQTT %s of %q (out of %q)", name, wrong, expect.Topics)
		}
		if len(todo) != 0 {
			// Collect in the order of expect.Topics, rather than
			// in random map order, for deterministic messages.
			miss := make([]string, 0, len(todo))
			for _, topic := range expect.Topics {
				if _, ok := todo[topic]; ok {
					miss = append(miss, topic)
					delete(todo, topic) // report duplicates once
				}
			}
			t.Errorf("no MQTT %s of %q (out of %q)", name, miss, expect.Topics)
		}
		return expect.Err
	}
}