mirror of
https://github.com/ssh-vault/ssh-vault.git
synced 2025-07-29 18:01:12 +03:00
Merge pull request #12 from tmaher/macos-keychain
Use macOS Keychain for private key passphrase
This commit is contained in:
2
Makefile
2
Makefile
@ -18,6 +18,8 @@ build: get
|
||||
${GO} get -u github.com/ssh-vault/crypto
|
||||
${GO} get -u github.com/ssh-vault/crypto/aead
|
||||
${GO} get -u github.com/ssh-vault/crypto/oaep
|
||||
${GO} get -u github.com/keybase/go-keychain
|
||||
${GO} get -u github.com/kr/pty
|
||||
${GO} build -ldflags "-X main.version=${VERSION}" -o ${BIN_NAME} cmd/ssh-vault/main.go;
|
||||
|
||||
clean:
|
||||
|
10
get_password.go
Normal file
10
get_password.go
Normal file
@ -0,0 +1,10 @@
|
||||
// +build !darwin
|
||||
|
||||
// For platforms without managed ssh private key passwords,
|
||||
// fallback to prompting the user.
|
||||
|
||||
package sshvault
|
||||
|
||||
func (v *vault) GetPassword() ([]byte, error) {
|
||||
return v.GetPasswordPrompt()
|
||||
}
|
38
get_password_darwin.go
Normal file
38
get_password_darwin.go
Normal file
@ -0,0 +1,38 @@
|
||||
// +build darwin
|
||||
|
||||
// Apple's OpenSSH fork uses Keychain for private key passphrases.
|
||||
// They're indexed by the absolute file path to the private key,
|
||||
// e.g. ~/.ssh/id_rsa
|
||||
//
|
||||
// If the passphrase isn't in keychain, prompt the user.
|
||||
|
||||
package sshvault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/keybase/go-keychain"
|
||||
)
|
||||
|
||||
func (v *vault) GetPassword() ([]byte, error) {
|
||||
var keyPassword []byte
|
||||
|
||||
key_path, err := filepath.Abs(v.key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error finding private key: %s", err)
|
||||
}
|
||||
|
||||
keyPassword, err = keychain.GetGenericPassword("SSH", key_path, "", "")
|
||||
if err == nil {
|
||||
return keyPassword, nil
|
||||
}
|
||||
|
||||
// Darn, Keychain doesn't have the password for that file. Prompt the user.
|
||||
keyPassword, err = v.GetPasswordPrompt()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keyPassword, nil
|
||||
}
|
97
get_password_darwin_test.go
Normal file
97
get_password_darwin_test.go
Normal file
@ -0,0 +1,97 @@
|
||||
// +build darwin
|
||||
|
||||
package sshvault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/keybase/go-keychain"
|
||||
"github.com/kr/pty"
|
||||
)
|
||||
|
||||
func InjectKeychainPassword(path, pw string) error {
|
||||
item := keychain.NewItem()
|
||||
item.SetSecClass(keychain.SecClassGenericPassword)
|
||||
item.SetLabel(fmt.Sprintf("SSH: %s", path))
|
||||
item.SetService("SSH")
|
||||
item.SetAccount(path)
|
||||
item.SetData([]byte(pw))
|
||||
item.SetSynchronizable(keychain.SynchronizableNo)
|
||||
|
||||
return keychain.AddItem(item)
|
||||
}
|
||||
|
||||
func DeleteKeychainPassword(path string) error {
|
||||
item := keychain.NewItem()
|
||||
item.SetSecClass(keychain.SecClassGenericPassword)
|
||||
item.SetService("SSH")
|
||||
item.SetAccount(path)
|
||||
|
||||
return keychain.DeleteItem(item)
|
||||
}
|
||||
|
||||
func TestKeychain(t *testing.T) {
|
||||
key_pw := "argle-bargle"
|
||||
key_bad_pw := "totally-bogus\n"
|
||||
|
||||
dir, err := ioutil.TempDir("", "vault")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer os.RemoveAll(dir) // clean up
|
||||
|
||||
tmpfile := filepath.Join(dir, "vault")
|
||||
|
||||
vault, err := New("test_data/id_rsa.pub", "", "create", tmpfile)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
key_path, err := filepath.Abs(vault.key)
|
||||
if err != nil {
|
||||
t.Errorf("Error finding private key: %s", err)
|
||||
}
|
||||
err = InjectKeychainPassword(key_path, key_pw)
|
||||
if err != nil {
|
||||
t.Errorf("Error setting up keychain for testing: %s", err)
|
||||
}
|
||||
defer DeleteKeychainPassword(key_path) // clean up
|
||||
|
||||
_, tty, err := pty.Open()
|
||||
if err != nil {
|
||||
t.Errorf("Unable to open pty: %s", err)
|
||||
}
|
||||
|
||||
// File Descriptor magic. GetPasswordPrompt() reads the password
|
||||
// from stdin. For the test, we save stdin to a spare FD,
|
||||
// point stdin at the file, run the system under test, and
|
||||
// finally restore the original stdin
|
||||
old_stdin, _ := syscall.Dup(int(syscall.Stdin))
|
||||
old_stdout, _ := syscall.Dup(int(syscall.Stdout))
|
||||
syscall.Dup2(int(tty.Fd()), int(syscall.Stdin))
|
||||
syscall.Dup2(int(tty.Fd()), int(syscall.Stdout))
|
||||
|
||||
go PtyWriteback(pty, key_bad_pw)
|
||||
|
||||
key_pw_test, err := vault.GetPassword()
|
||||
|
||||
syscall.Dup2(old_stdin, int(syscall.Stdin))
|
||||
syscall.Dup2(old_stdout, int(syscall.Stdout))
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if strings.Trim(string(key_pw_test), "\n") == strings.Trim(key_bad_pw, "\n") {
|
||||
t.Errorf("PTY-based password prompt used, not keychain!")
|
||||
}
|
||||
|
||||
if strings.Trim(string(key_pw_test), "\n") != strings.Trim(key_pw, "\n") {
|
||||
t.Errorf("keychain error: %s expected %s, got %s\n", key_path, key_pw, key_pw_test)
|
||||
}
|
||||
|
||||
}
|
18
get_password_prompt.go
Normal file
18
get_password_prompt.go
Normal file
@ -0,0 +1,18 @@
|
||||
package sshvault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
)
|
||||
|
||||
func (v *vault) GetPasswordPrompt() ([]byte, error) {
|
||||
fmt.Printf("Enter key password (%s): ", v.key)
|
||||
keyPassword, err := terminal.ReadPassword(int(syscall.Stdin))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keyPassword, nil
|
||||
}
|
@ -8,12 +8,23 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kr/pty"
|
||||
"github.com/ssh-vault/crypto"
|
||||
"github.com/ssh-vault/crypto/aead"
|
||||
)
|
||||
|
||||
// zomg this is a race condition
|
||||
func PtyWriteback(pty *os.File, msg string) {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
defer pty.Sync()
|
||||
pty.Write([]byte(msg))
|
||||
}
|
||||
|
||||
// These are done in one function to avoid declaring global variables in a test
|
||||
// file.
|
||||
func TestVaultFunctions(t *testing.T) {
|
||||
@ -30,6 +41,35 @@ func TestVaultFunctions(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
key_pw := string("argle-bargle\n")
|
||||
pty, tty, err := pty.Open()
|
||||
if err != nil {
|
||||
t.Errorf("Unable to open pty: %s", err)
|
||||
}
|
||||
|
||||
// File Descriptor magic. GetPasswordPrompt() reads the password
|
||||
// from stdin. For the test, we save stdin to a spare FD,
|
||||
// point stdin at the file, run the system under test, and
|
||||
// finally restore the original stdin
|
||||
old_stdin, _ := syscall.Dup(int(syscall.Stdin))
|
||||
old_stdout, _ := syscall.Dup(int(syscall.Stdout))
|
||||
syscall.Dup2(int(tty.Fd()), int(syscall.Stdin))
|
||||
syscall.Dup2(int(tty.Fd()), int(syscall.Stdout))
|
||||
|
||||
go PtyWriteback(pty, key_pw)
|
||||
|
||||
key_pw_test, err := vault.GetPasswordPrompt()
|
||||
|
||||
syscall.Dup2(old_stdin, int(syscall.Stdin))
|
||||
syscall.Dup2(old_stdout, int(syscall.Stdout))
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if string(strings.Trim(key_pw, "\n")) != string(key_pw_test) {
|
||||
t.Errorf("password prompt: expected %s, got %s\n", key_pw, key_pw_test)
|
||||
}
|
||||
|
||||
if err = vault.PKCS8(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
10
view.go
10
view.go
@ -10,12 +10,9 @@ import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/ssh-vault/crypto/aead"
|
||||
"github.com/ssh-vault/crypto/oaep"
|
||||
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
)
|
||||
|
||||
// View decrypts data and print it to stdout
|
||||
@ -64,12 +61,11 @@ func (v *vault) View() ([]byte, error) {
|
||||
}
|
||||
|
||||
if x509.IsEncryptedPEMBlock(block) {
|
||||
fmt.Print("Enter key password: ")
|
||||
keyPassword, err := terminal.ReadPassword(int(syscall.Stdin))
|
||||
keyPassword, err := v.GetPassword()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Unable to get private key password, Decryption failed.")
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
block.Bytes, err = x509.DecryptPEMBlock(block, keyPassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Password incorrect, Decryption failed.")
|
||||
|
Reference in New Issue
Block a user