mirror of
https://github.com/go-mqtt/mqtt.git
synced 2025-08-08 22:42:05 +03:00
262 lines
6.0 KiB
Go
262 lines
6.0 KiB
Go
package integration
|
||
|
||
import (
|
||
"context"
|
||
"encoding/binary"
|
||
"errors"
|
||
"net"
|
||
"os"
|
||
"strings"
|
||
"sync"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/go-mqtt/mqtt"
|
||
)
|
||
|
||
const (
	// batchSize is a reasonable number of messages which should not cause
	// any of them to be dropped (by the broker) when sent sequentially.
	batchSize = 99

	// batchTimeout is the maximum time allowed for the receipt of a whole
	// batch of messages.
	batchTimeout = time.Minute
)
|
||
|
||
// hosts returns the test targets listed, whitespace-separated, in the
// MQTT_HOSTS environment variable. The test is skipped when the variable is
// unset, and also when it lists no hosts at all. Without that second check,
// callers that range over the result would run zero subtests and pass
// vacuously.
func hosts(tb testing.TB) []string {
	s, ok := os.LookupEnv("MQTT_HOSTS")
	if !ok {
		tb.Skip("no test targets without MQTT_HOSTS environment variable")
	}
	fields := strings.Fields(s)
	if len(fields) == 0 {
		tb.Skip("no test targets in MQTT_HOSTS environment variable")
	}
	return fields
}
|
||
|
||
// NewTestClient returns an instance for testing.
|
||
func newTestClient(t *testing.T, host string, config *mqtt.Config) (client *mqtt.Client, messages <-chan uint64) {
|
||
config.Dialer = mqtt.NewDialer("tcp", net.JoinHostPort(host, "1883"))
|
||
config.PauseTimeout = 2 * time.Second
|
||
config.CleanSession = true
|
||
client, err := mqtt.VolatileSession(t.Name(), config)
|
||
if err != nil {
|
||
t.Fatal("volatile session error:", err)
|
||
}
|
||
|
||
// messages contain their respective sequence number
|
||
ch := make(chan uint64, 16)
|
||
t.Cleanup(func() {
|
||
err := client.Close()
|
||
if err != nil {
|
||
t.Error("client close error:", err)
|
||
}
|
||
// await read-routine exit
|
||
seqNo, ok := <-ch
|
||
if ok {
|
||
t.Errorf("got message # %d after close", seqNo)
|
||
}
|
||
})
|
||
// launch read routine
|
||
go func() {
|
||
defer close(ch)
|
||
for {
|
||
message, topic, err := client.ReadSlices()
|
||
switch {
|
||
case err == nil:
|
||
if len(message) != 8 {
|
||
t.Errorf("unexpected message %#x on topic %q", message, topic)
|
||
} else {
|
||
ch <- binary.LittleEndian.Uint64(message)
|
||
}
|
||
|
||
case errors.Is(err, mqtt.ErrClosed):
|
||
return
|
||
|
||
default:
|
||
t.Log(err)
|
||
time.Sleep(time.Second / 2)
|
||
}
|
||
}
|
||
}()
|
||
|
||
for {
|
||
err := client.Subscribe(nil, t.Name())
|
||
switch {
|
||
case err == nil:
|
||
return client, ch
|
||
case errors.Is(err, mqtt.ErrDown):
|
||
time.Sleep(10 * time.Millisecond)
|
||
continue
|
||
default:
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestRace(t *testing.T) {
|
||
for _, host := range hosts(t) {
|
||
t.Run(host, func(t *testing.T) {
|
||
t.Run("at-most-once", func(t *testing.T) {
|
||
client, messages := newTestClient(t, host, new(mqtt.Config))
|
||
raceAtLevel(t, client, messages, 0)
|
||
})
|
||
t.Run("at-least-once", func(t *testing.T) {
|
||
client, messages := newTestClient(t, host, &mqtt.Config{
|
||
AtLeastOnceMax: 9,
|
||
})
|
||
raceAtLevel(t, client, messages, 1)
|
||
})
|
||
t.Run("exactly-once", func(t *testing.T) {
|
||
client, messages := newTestClient(t, host, &mqtt.Config{
|
||
ExactlyOnceMax: 9,
|
||
})
|
||
raceAtLevel(t, client, messages, 2)
|
||
})
|
||
})
|
||
}
|
||
}
|
||
|
||
func raceAtLevel(t *testing.T, client *mqtt.Client, messages <-chan uint64, deliveryLevel int) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||
defer cancel()
|
||
testTopic := t.Name()
|
||
|
||
launch := make(chan struct{})
|
||
|
||
// install contenders
|
||
var wg sync.WaitGroup
|
||
defer wg.Wait()
|
||
wg.Add(batchSize)
|
||
for n := uint64(1); n <= batchSize; n++ {
|
||
go func(seqNo uint64) {
|
||
defer wg.Done()
|
||
|
||
var message [8]byte
|
||
binary.LittleEndian.PutUint64(message[:], seqNo)
|
||
|
||
<-launch
|
||
|
||
var exchange <-chan error
|
||
Publish:
|
||
for {
|
||
var err error
|
||
switch deliveryLevel {
|
||
case 0:
|
||
err = client.Publish(ctx.Done(), message[:], testTopic)
|
||
case 1:
|
||
exchange, err = client.PublishAtLeastOnce(message[:], testTopic)
|
||
case 2:
|
||
exchange, err = client.PublishExactlyOnce(message[:], testTopic)
|
||
}
|
||
switch {
|
||
case err == nil:
|
||
break Publish
|
||
case errors.Is(err, mqtt.ErrMax):
|
||
time.Sleep(200 * time.Microsecond)
|
||
case errors.Is(err, mqtt.ErrClosed):
|
||
return
|
||
default:
|
||
t.Errorf("publish #%d error: %s", seqNo, err)
|
||
return
|
||
}
|
||
}
|
||
|
||
if deliveryLevel != 0 {
|
||
err, ok := <-exchange
|
||
if ok {
|
||
t.Errorf("publish # %d exchange error: %s", seqNo, err)
|
||
}
|
||
}
|
||
}(n)
|
||
}
|
||
|
||
time.Sleep(50 * time.Millisecond)
|
||
close(launch)
|
||
|
||
timeout := time.After(batchTimeout)
|
||
for i := 0; i < batchSize; i++ {
|
||
select {
|
||
case _, ok := <-messages:
|
||
if !ok {
|
||
t.Fatalf("want %d more messages", batchSize-i)
|
||
}
|
||
case <-timeout:
|
||
t.Fatalf("timeout; want %d more messages", batchSize-i)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestRoundtrip(t *testing.T) {
|
||
for _, host := range hosts(t) {
|
||
t.Run(host, func(t *testing.T) {
|
||
const testN = 17_000 // causes an mqtt.publishIDMask overflow
|
||
t.Run("at-least-once", func(t *testing.T) {
|
||
client, messages := newTestClient(t, host, &mqtt.Config{
|
||
AtLeastOnceMax: 9,
|
||
})
|
||
for i := 0; i < testN; i += batchSize {
|
||
testRoundtripBatch(t, client, client.PublishAtLeastOnce, messages)
|
||
}
|
||
})
|
||
|
||
t.Run("exactly-once", func(t *testing.T) {
|
||
client, messages := newTestClient(t, host, &mqtt.Config{
|
||
ExactlyOnceMax: 9,
|
||
})
|
||
for i := 0; i < testN; i += batchSize {
|
||
testRoundtripBatch(t, client, client.PublishExactlyOnce, messages)
|
||
}
|
||
})
|
||
})
|
||
}
|
||
}
|
||
|
||
func testRoundtripBatch(t *testing.T, client *mqtt.Client,
|
||
publish func(message []byte, topic string) (exchange <-chan error, err error),
|
||
messages <-chan uint64) {
|
||
testTopic := t.Name()
|
||
|
||
var wg sync.WaitGroup
|
||
defer wg.Wait()
|
||
|
||
wg.Add(1)
|
||
go func() {
|
||
defer wg.Done()
|
||
|
||
message := make([]byte, 8)
|
||
for n := uint64(1); n <= batchSize; {
|
||
binary.LittleEndian.PutUint64(message, n)
|
||
exchange, err := publish(message, testTopic)
|
||
switch {
|
||
case err == nil:
|
||
n++
|
||
|
||
wg.Add(1)
|
||
go func(seqNo uint64, exchange <-chan error) {
|
||
defer wg.Done()
|
||
err, ok := <-exchange
|
||
if ok {
|
||
t.Errorf("publish # %d exchange error: %s", seqNo, err)
|
||
}
|
||
}(n, exchange)
|
||
case errors.Is(err, mqtt.ErrMax):
|
||
time.Sleep(time.Millisecond)
|
||
case errors.Is(err, mqtt.ErrClosed):
|
||
return
|
||
default:
|
||
t.Errorf("publish # %d error: %s", n, err)
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
|
||
timeout := time.After(batchTimeout)
|
||
for n := 1; n <= batchSize; n++ {
|
||
select {
|
||
case seqNo, ok := <-messages:
|
||
if !ok {
|
||
t.Fatalf("did not receive message # %d–%d", n, batchSize)
|
||
}
|
||
if seqNo != uint64(n) {
|
||
t.Errorf("want message # %d, got # %d", n, seqNo)
|
||
}
|
||
case <-timeout:
|
||
t.Fatalf("timeout before message # %d-%d", n, batchSize)
|
||
}
|
||
}
|
||
}
|