From fbc3958c3f575c30aec6c805bc4649b941f8fdb6 Mon Sep 17 00:00:00 2001 From: "Pascal S. de Kloe" Date: Fri, 12 Mar 2021 14:24:29 +0100 Subject: [PATCH] Better mqttc exit codes + cleanup. --- cmd/mqttc/main.go | 114 +++++++++++++++++++++++++++++----------------- 1 file changed, 73 insertions(+), 41 deletions(-) diff --git a/cmd/mqttc/main.go b/cmd/mqttc/main.go index 2cbcdcc..11403ef 100644 --- a/cmd/mqttc/main.go +++ b/cmd/mqttc/main.go @@ -111,19 +111,19 @@ func parseConfig() (clientID string, config *mqtt.Config) { var exitStatus = make(chan int, 1) -func exitCode(code int) { +func setExitStatus(code int) { select { case exitStatus <- code: default: // exit status already defined } } -func exit(code int) { - exitCode(code) - err := client.Close() - if err != nil { - log.Print(err) +func getExitStatus() (code int) { + select { + case code = <-exitStatus: + default: // stays zero } + return } var client *mqtt.Client @@ -152,7 +152,10 @@ func main() { if err != nil { log.Fatal(err) } - err = client.Publish(nil, message, *publishFlag) + + ctx, cancel := context.WithTimeout(context.Background(), *timeoutFlag) + defer cancel() + err = client.Publish(ctx.Done(), message, *publishFlag) switch { case err == nil: if *verboseFlag { @@ -160,9 +163,17 @@ func main() { } case errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): return + case errors.Is(err, mqtt.ErrCanceled), errors.Is(err, mqtt.ErrAbandoned): + log.Printf("%s: publish timeout (%s)", name, err) + setExitStatus(1) + if err := client.Close(); err != nil { + log.Print(err) + } + return default: log.Print(err) - exit(1) + setExitStatus(1) + _ = client.Close() return } } @@ -173,20 +184,28 @@ func main() { defer cancel() err := client.SubscribeLimitAtMostOnce(ctx.Done(), subscribeFlags...) switch { - case err == nil, errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): + case err == nil: + if *verboseFlag { + log.Printf("%s: subscribed to %d topic filters", name, len(subscribeFlags)) + } + case errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): return case errors.Is(err, mqtt.ErrCanceled), errors.Is(err, mqtt.ErrAbandoned): - log.Print(name, ": subscribe timeout") - - fallthrough + log.Printf("%s: subscribe timeout (%s)", name, err) + setExitStatus(1) + if err := client.Close(); err != nil { + log.Print(err) + } + return default: log.Print(err) - exitCode(1) + setExitStatus(1) + _ = client.Close() return } } - if *publishFlag == "" { + if *publishFlag == "" && len(subscribeFlags) == 0 { // ping exchange ctx, cancel := context.WithTimeout(context.Background(), *timeoutFlag) defer cancel() @@ -197,50 +216,54 @@ func main() { case errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): return case errors.Is(err, mqtt.ErrCanceled), errors.Is(err, mqtt.ErrAbandoned): - log.Print(name, ": ping timeout") - - fallthrough + log.Printf("%s: ping timeout (%s)", name, err) + setExitStatus(1) + if err := client.Close(); err != nil { + log.Print(err) + } + return default: log.Print(err) - exit(1) + setExitStatus(1) + _ = client.Close() return } } - // graceful shutdown - ctx, cancel := context.WithTimeout(context.Background(), *timeoutFlag) - defer cancel() - err := client.Disconnect(ctx.Done()) - switch { - case err == nil: - exitCode(0) - case errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): - break - default: - log.Print(err) - exitCode(1) + if len(subscribeFlags) == 0 { + // graceful shutdown + ctx, cancel := context.WithTimeout(context.Background(), *timeoutFlag) + defer cancel() + err := client.Disconnect(ctx.Done()) + switch { + case err == nil, errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): + break + default: + log.Print(err) + setExitStatus(1) + return + } } }() + defer os.Exit(getExitStatus()) + // read routine + var big *mqtt.BigMessage for { - var big *mqtt.BigMessage message, topic, err := client.ReadSlices() switch { case err == nil: printMessage(message, topic) case errors.Is(err, mqtt.ErrClosed): - os.Exit(<-exitStatus) - - case mqtt.IsDeny(err): // illegal configuration - log.Fatal(err) + return case errors.As(err, &big): message, err := big.ReadAll() if err != nil { log.Print(err) - exit(1) + setExitStatus(1) return } printMessage(message, big.Topic) @@ -261,7 +284,7 @@ func main() { os.Exit(9) } - exit(1) + setExitStatus(1) return } } @@ -287,16 +310,25 @@ func applySignals() { switch sig { case syscall.SIGINT: log.Print(name, ": SIGINT received; closing connection…") - exit(130) + setExitStatus(130) + switch err := client.Close(); { + case err == nil, errors.Is(err, mqtt.ErrDown): + break + default: + log.Print(err) + } case syscall.SIGTERM: log.Print(name, ": SIGTERM received; sending disconnect…") - exitCode(143) + setExitStatus(143) ctx, cancel := context.WithTimeout(context.Background(), *timeoutFlag) defer cancel() - err := client.Disconnect(ctx.Done()) - if err != nil && !errors.Is(err, mqtt.ErrClosed) { + switch err := client.Disconnect(ctx.Done()); err != nil { + case err == nil, errors.Is(err, mqtt.ErrClosed), errors.Is(err, mqtt.ErrDown): + break + default: log.Print(err) + setExitStatus(1) } } }