From 5393197e587ca844a6768f95ca1b61ef93d2972c Mon Sep 17 00:00:00 2001 From: "Pascal S. de Kloe" Date: Wed, 17 Mar 2021 13:26:53 +0100 Subject: [PATCH] Await mqttc exit code, as all exits are defined. --- cmd/mqttc/main.go | 98 +++++++++++++++++++---------------------------- 1 file changed, 40 insertions(+), 58 deletions(-) diff --git a/cmd/mqttc/main.go b/cmd/mqttc/main.go index d57770f..617ba5b 100644 --- a/cmd/mqttc/main.go +++ b/cmd/mqttc/main.go @@ -113,19 +113,18 @@ func parseConfig() (clientID string, config *mqtt.Config) { var exitStatus = make(chan int, 1) -func setExitStatus(code int) { +func failMQTT(client *mqtt.Client, err error) { + log.Print(err) + select { - case exitStatus <- code: + case exitStatus <- 1: default: // exit status already defined } -} -func getExitStatus() (code int) { - select { - case code = <-exitStatus: - default: // stays zero + err = client.Close() + if err != nil { + log.Print(err) } - return } func main() { @@ -166,16 +165,10 @@ func main() { case errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): return case errors.Is(err, mqtt.ErrCanceled), errors.Is(err, mqtt.ErrAbandoned): - log.Printf("%s: publish timeout (%s)", name, err) - setExitStatus(1) - if err := client.Close(); err != nil { - log.Print(err) - } + failMQTT(client, fmt.Errorf("%s: publish timeout; %s", name, err)) return default: - log.Print(err) - setExitStatus(1) - _ = client.Close() + failMQTT(client, err) return } } @@ -193,16 +186,10 @@ func main() { case errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): return case errors.Is(err, mqtt.ErrCanceled), errors.Is(err, mqtt.ErrAbandoned): - log.Printf("%s: subscribe timeout (%s)", name, err) - setExitStatus(1) - if err := client.Close(); err != nil { - log.Print(err) - } + failMQTT(client, fmt.Errorf("%s: subscribe timeout; %s", name, err)) return default: - log.Print(err) - setExitStatus(1) - _ = client.Close() + failMQTT(client, err) return } } @@ -218,16 +205,10 @@ func main() { case errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): return case errors.Is(err, mqtt.ErrCanceled), errors.Is(err, mqtt.ErrAbandoned): - log.Printf("%s: ping timeout (%s)", name, err) - setExitStatus(1) - if err := client.Close(); err != nil { - log.Print(err) - } + failMQTT(client, fmt.Errorf("%s: ping timeout; %s", name, err)) return default: - log.Print(err) - setExitStatus(1) - _ = client.Close() + failMQTT(client, err) return } } @@ -238,19 +219,20 @@ func main() { defer cancel() err := client.Disconnect(ctx.Done()) switch { - case err == nil, errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): + case err == nil: + exitStatus <- 0 + case errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): + // exit status defined by cause break default: log.Print(err) - setExitStatus(1) - return + exitStatus <- 1 } + return } }() - defer os.Exit(getExitStatus()) - - // read routine + // Read routine runs until mqtt.Client Close or Disconnect. var big *mqtt.BigMessage for { message, topic, err := client.ReadSlices() @@ -259,19 +241,18 @@ func main() { printMessage(message, topic) case errors.Is(err, mqtt.ErrClosed): - return + os.Exit(<-exitStatus) case errors.As(err, &big): message, err := big.ReadAll() if err != nil { - log.Print(err) - setExitStatus(1) - return + failMQTT(client, err) + } else { + printMessage(message, big.Topic) } - printMessage(message, big.Topic) default: - log.Print(err) + failMQTT(client, err) switch { case errors.Is(err, mqtt.ErrProtocolLevel): @@ -285,9 +266,6 @@ func main() { case errors.Is(err, mqtt.ErrAuth): os.Exit(9) } - - setExitStatus(1) - return } } } @@ -311,26 +289,30 @@ func applySignals(client *mqtt.Client) { for sig := range signals { switch sig { case syscall.SIGINT: - log.Print(name, ": SIGINT received; closing connection…") - setExitStatus(130) - switch err := client.Close(); { - case err == nil, errors.Is(err, mqtt.ErrDown): - break - default: + log.Print(name, ": SIGINT received") + select { + case exitStatus <- 130: + default: // exit status already defined + } + err := client.Close() + if err != nil { log.Print(err) } case syscall.SIGTERM: - log.Print(name, ": SIGTERM received; sending disconnect…") - setExitStatus(143) + log.Print(name, ": SIGTERM received") ctx, cancel := context.WithTimeout(context.Background(), *timeoutFlag) defer cancel() - switch err := client.Disconnect(ctx.Done()); err != nil { - case err == nil, errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): + err := client.Disconnect(ctx.Done()) + switch { + case err == nil: + exitStatus <- 143 + case errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): + // exit status defined by cause break default: log.Print(err) - setExitStatus(1) + exitStatus <- 1 } } }