1
0
mirror of https://github.com/go-mqtt/mqtt.git synced 2025-08-08 22:42:05 +03:00

New -ca, -cert and -key options for mqttc(1).

This commit is contained in:
Pascal S. de Kloe
2021-07-02 18:48:10 +02:00
parent 9c94834a2b
commit e6d3d61b53
2 changed files with 99 additions and 11 deletions

View File

@@ -4,6 +4,8 @@ package main
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"flag"
"fmt"
@@ -46,10 +48,15 @@ var (
timeoutFlag = flag.Duration("timeout", 4*time.Second, "Network operation expiry.")
netFlag = flag.String("net", "tcp", "Select the network by `name`. Valid alternatives include tcp4,\ntcp6 and unix.")
tlsFlag = flag.Bool("tls", false, "Secure the connection with TLS.")
serverFlag = flag.String("server", "", "Use a specific server `name` with TLS")
userFlag = flag.String("user", "", "The user `name` may be used by the broker for authentication\nand/or authorization purposes.")
passFlag = flag.String("pass", "", "The `file` content is used as a password.")
tlsFlag = flag.Bool("tls", false, "Secure the connection with TLS.")
serverFlag = flag.String("server", "", "Use a specific server `name` with TLS")
caFlag = flag.String("ca", "", "Amend the trusted certificate authorities with a PEM `file`.")
certFlag = flag.String("cert", "", "Use a client certificate from a PEM `file` (with a corresponding\n"+bold+"-key"+clear+" option).")
keyFlag = flag.String("key", "", "Use a private key (matching the client certificate) from a PEM\n`file`.")
userFlag = flag.String("user", "", "The user `name` may be used by the broker for authentication\nand/or authorization purposes.")
passFlag = flag.String("pass", "", "The `file` content is used as a password.")
clientFlag = flag.String("client", generatedLabel, "Use a specific client `identifier`.")
@@ -62,7 +69,8 @@ var (
verboseFlag = flag.Bool("verbose", false, "Produces more output to "+italic+"standard error"+clear+" for debug purposes.")
)
func parseConfig() (clientID string, config *mqtt.Config) {
// Config collects the command arguments.
func Config() (clientID string, config *mqtt.Config) {
var addr string
switch args := flag.Args(); {
case len(args) == 0:
@@ -76,9 +84,83 @@ func parseConfig() (clientID string, config *mqtt.Config) {
log.Printf("%s: multiple address arguments %q", name, args)
os.Exit(2)
}
var TLS *tls.Config
if *tlsFlag {
TLS = new(tls.Config)
}
if *serverFlag != "" {
if TLS == nil {
log.Fatal(name, ": -server requires -tls option")
}
TLS.ServerName = *serverFlag
}
switch {
case *certFlag != "" && *keyFlag != "":
if TLS == nil {
log.Fatal(name, ": -cert requires -tls option")
}
certPEM, err := os.ReadFile(*certFlag)
if err != nil {
log.Fatal(err)
}
keyPEM, err := os.ReadFile(*keyFlag)
if err != nil {
log.Fatal(err)
}
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
log.Fatal(name, ": unusable -cert and -key content; ", err)
}
TLS.Certificates = append(TLS.Certificates, cert)
case *certFlag != "":
log.Fatal(name, ": -cert requires -key option")
case *keyFlag != "":
log.Fatal(name, ": -key requires -cert option")
}
if *caFlag != "" {
if TLS == nil {
log.Fatal(name, ": -ca requires -tls option")
}
if certs, err := x509.SystemCertPool(); err != nil {
log.Print(name, ": system certificates unavailable; ", err)
TLS.RootCAs = x509.NewCertPool()
} else {
TLS.RootCAs = certs
}
text, err := os.ReadFile(*caFlag)
if err != nil {
log.Fatal(err)
}
for n := 1; ; n++ {
var block *pem.Block
block, text = pem.Decode(text)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
log.Printf("%s: ignoring PEM block № %d of type %q in %s", name, n, block.Type, *caFlag)
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Printf("%s: ignoring PEM block № %d in %s; %s", name, n, *caFlag, err)
continue
}
TLS.RootCAs.AddCert(cert)
}
}
if _, _, err := net.SplitHostPort(addr); err != nil {
port := "1883"
if *tlsFlag {
if TLS != nil {
port = "8883"
}
addr = net.JoinHostPort(addr, port)
@@ -101,10 +183,8 @@ func parseConfig() (clientID string, config *mqtt.Config) {
config.Password = bytes
}
if *tlsFlag {
config.Dialer = mqtt.NewTLSDialer(*netFlag, addr, &tls.Config{
ServerName: *serverFlag,
})
if TLS != nil {
config.Dialer = mqtt.NewTLSDialer(*netFlag, addr, TLS)
} else {
config.Dialer = mqtt.NewDialer(*netFlag, addr)
}
@@ -135,7 +215,7 @@ func main() {
log.SetOutput(io.Discard)
}
clientID, config := parseConfig()
clientID, config := Config()
client, err := mqtt.VolatileSession(clientID, config)
if err != nil {
log.Fatal(err)