diff --git a/Makefile b/Makefile index a403f2b..6b6c4b1 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,8 @@ build: get ${GO} get -u github.com/ssh-vault/crypto ${GO} get -u github.com/ssh-vault/crypto/aead ${GO} get -u github.com/ssh-vault/crypto/oaep + ${GO} get -u github.com/keybase/go-keychain + ${GO} get -u github.com/kr/pty ${GO} build -ldflags "-X main.version=${VERSION}" -o ${BIN_NAME} cmd/ssh-vault/main.go; clean: diff --git a/get_password.go b/get_password.go new file mode 100644 index 0000000..c265a77 --- /dev/null +++ b/get_password.go @@ -0,0 +1,10 @@ +// +build !darwin + +// For platforms without managed ssh private key passwords, +// fallback to prompting the user. + +package sshvault + +func (v *vault) GetPassword() ([]byte, error) { + return v.GetPasswordPrompt() +} diff --git a/get_password_darwin.go b/get_password_darwin.go new file mode 100644 index 0000000..e439bbf --- /dev/null +++ b/get_password_darwin.go @@ -0,0 +1,38 @@ +// +build darwin + +// Apple's OpenSSH fork uses Keychain for private key passphrases. +// They're indexed by the absolute file path to the private key, +// e.g. ~/.ssh/id_rsa +// +// If the passphrase isn't in keychain, prompt the user. + +package sshvault + +import ( + "fmt" + "path/filepath" + + "github.com/keybase/go-keychain" +) + +func (v *vault) GetPassword() ([]byte, error) { + var keyPassword []byte + + key_path, err := filepath.Abs(v.key) + if err != nil { + return nil, fmt.Errorf("Error finding private key: %s", err) + } + + keyPassword, err = keychain.GetGenericPassword("SSH", key_path, "", "") + if err == nil { + return keyPassword, nil + } + + // Darn, Keychain doesn't have the password for that file. Prompt the user. + keyPassword, err = v.GetPasswordPrompt() + if err != nil { + return nil, err + } + + return keyPassword, nil +} diff --git a/get_password_darwin_test.go b/get_password_darwin_test.go new file mode 100644 index 0000000..f60980b --- /dev/null +++ b/get_password_darwin_test.go @@ -0,0 +1,97 @@ +// +build darwin + +package sshvault + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "syscall" + "testing" + + "github.com/keybase/go-keychain" + "github.com/kr/pty" +) + +func InjectKeychainPassword(path, pw string) error { + item := keychain.NewItem() + item.SetSecClass(keychain.SecClassGenericPassword) + item.SetLabel(fmt.Sprintf("SSH: %s", path)) + item.SetService("SSH") + item.SetAccount(path) + item.SetData([]byte(pw)) + item.SetSynchronizable(keychain.SynchronizableNo) + + return keychain.AddItem(item) +} + +func DeleteKeychainPassword(path string) error { + item := keychain.NewItem() + item.SetSecClass(keychain.SecClassGenericPassword) + item.SetService("SSH") + item.SetAccount(path) + + return keychain.DeleteItem(item) +} + +func TestKeychain(t *testing.T) { + key_pw := "argle-bargle" + key_bad_pw := "totally-bogus\n" + + dir, err := ioutil.TempDir("", "vault") + if err != nil { + t.Error(err) + } + defer os.RemoveAll(dir) // clean up + + tmpfile := filepath.Join(dir, "vault") + + vault, err := New("test_data/id_rsa.pub", "", "create", tmpfile) + if err != nil { + t.Error(err) + } + key_path, err := filepath.Abs(vault.key) + if err != nil { + t.Errorf("Error finding private key: %s", err) + } + err = InjectKeychainPassword(key_path, key_pw) + if err != nil { + t.Errorf("Error setting up keychain for testing: %s", err) + } + defer DeleteKeychainPassword(key_path) // clean up + + _, tty, err := pty.Open() + if err != nil { + t.Errorf("Unable to open pty: %s", err) + } + + // File Descriptor magic. GetPasswordPrompt() reads the password + // from stdin. For the test, we save stdin to a spare FD, + // point stdin at the file, run the system under test, and + // finally restore the original stdin + old_stdin, _ := syscall.Dup(int(syscall.Stdin)) + old_stdout, _ := syscall.Dup(int(syscall.Stdout)) + syscall.Dup2(int(tty.Fd()), int(syscall.Stdin)) + syscall.Dup2(int(tty.Fd()), int(syscall.Stdout)) + + go PtyWriteback(pty, key_bad_pw) + + key_pw_test, err := vault.GetPassword() + + syscall.Dup2(old_stdin, int(syscall.Stdin)) + syscall.Dup2(old_stdout, int(syscall.Stdout)) + + if err != nil { + t.Error(err) + } + if strings.Trim(string(key_pw_test), "\n") == strings.Trim(key_bad_pw, "\n") { + t.Errorf("PTY-based password prompt used, not keychain!") + } + + if strings.Trim(string(key_pw_test), "\n") != strings.Trim(key_pw, "\n") { + t.Errorf("keychain error: %s expected %s, got %s\n", key_path, key_pw, key_pw_test) + } + +} diff --git a/get_password_prompt.go b/get_password_prompt.go new file mode 100644 index 0000000..bba6186 --- /dev/null +++ b/get_password_prompt.go @@ -0,0 +1,18 @@ +package sshvault + +import ( + "fmt" + "syscall" + + "golang.org/x/crypto/ssh/terminal" +) + +func (v *vault) GetPasswordPrompt() ([]byte, error) { + fmt.Printf("Enter key password (%s): ", v.key) + keyPassword, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + return nil, err + } + + return keyPassword, nil +} diff --git a/vault_test.go b/vault_test.go index a416ac8..0ffb708 100644 --- a/vault_test.go +++ b/vault_test.go @@ -8,12 +8,23 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" + "syscall" "testing" + "time" + "github.com/kr/pty" "github.com/ssh-vault/crypto" "github.com/ssh-vault/crypto/aead" ) +// zomg this is a race condition +func PtyWriteback(pty *os.File, msg string) { + time.Sleep(500 * time.Millisecond) + defer pty.Sync() + pty.Write([]byte(msg)) +} + // These are done in one function to avoid declaring global variables in a test // file. func TestVaultFunctions(t *testing.T) { @@ -30,6 +41,35 @@ func TestVaultFunctions(t *testing.T) { t.Error(err) } + key_pw := string("argle-bargle\n") + pty, tty, err := pty.Open() + if err != nil { + t.Errorf("Unable to open pty: %s", err) + } + + // File Descriptor magic. GetPasswordPrompt() reads the password + // from stdin. For the test, we save stdin to a spare FD, + // point stdin at the file, run the system under test, and + // finally restore the original stdin + old_stdin, _ := syscall.Dup(int(syscall.Stdin)) + old_stdout, _ := syscall.Dup(int(syscall.Stdout)) + syscall.Dup2(int(tty.Fd()), int(syscall.Stdin)) + syscall.Dup2(int(tty.Fd()), int(syscall.Stdout)) + + go PtyWriteback(pty, key_pw) + + key_pw_test, err := vault.GetPasswordPrompt() + + syscall.Dup2(old_stdin, int(syscall.Stdin)) + syscall.Dup2(old_stdout, int(syscall.Stdout)) + + if err != nil { + t.Error(err) + } + if string(strings.Trim(key_pw, "\n")) != string(key_pw_test) { + t.Errorf("password prompt: expected %s, got %s\n", key_pw, key_pw_test) + } + if err = vault.PKCS8(); err != nil { t.Error(err) } diff --git a/view.go b/view.go index ecffd94..3a3c58e 100644 --- a/view.go +++ b/view.go @@ -10,12 +10,9 @@ import ( "io/ioutil" "os" "strings" - "syscall" "github.com/ssh-vault/crypto/aead" "github.com/ssh-vault/crypto/oaep" - - "golang.org/x/crypto/ssh/terminal" ) // View decrypts data and print it to stdout @@ -64,12 +61,11 @@ func (v *vault) View() ([]byte, error) { } if x509.IsEncryptedPEMBlock(block) { - fmt.Print("Enter key password: ") - keyPassword, err := terminal.ReadPassword(int(syscall.Stdin)) + keyPassword, err := v.GetPassword() if err != nil { - return nil, err + return nil, fmt.Errorf("Unable to get private key password, Decryption failed.") } - fmt.Println() + block.Bytes, err = x509.DecryptPEMBlock(block, keyPassword) if err != nil { return nil, fmt.Errorf("Password incorrect, Decryption failed.")