1
0
mirror of https://github.com/ssh-vault/ssh-vault.git synced 2025-07-29 18:01:12 +03:00

Merge pull request #12 from tmaher/macos-keychain

Use macOS Keychain for private key passphrase
This commit is contained in:
Nicolas Embriz
2016-12-20 11:24:26 +01:00
committed by GitHub
7 changed files with 208 additions and 7 deletions

View File

@ -18,6 +18,8 @@ build: get
${GO} get -u github.com/ssh-vault/crypto
${GO} get -u github.com/ssh-vault/crypto/aead
${GO} get -u github.com/ssh-vault/crypto/oaep
${GO} get -u github.com/keybase/go-keychain
${GO} get -u github.com/kr/pty
${GO} build -ldflags "-X main.version=${VERSION}" -o ${BIN_NAME} cmd/ssh-vault/main.go;
clean:

10
get_password.go Normal file
View File

@ -0,0 +1,10 @@
// +build !darwin
// For platforms without managed ssh private key passwords,
// fallback to prompting the user.
package sshvault
func (v *vault) GetPassword() ([]byte, error) {
return v.GetPasswordPrompt()
}

38
get_password_darwin.go Normal file
View File

@ -0,0 +1,38 @@
// +build darwin
// Apple's OpenSSH fork uses Keychain for private key passphrases.
// They're indexed by the absolute file path to the private key,
// e.g. ~/.ssh/id_rsa
//
// If the passphrase isn't in keychain, prompt the user.
package sshvault
import (
"fmt"
"path/filepath"
"github.com/keybase/go-keychain"
)
func (v *vault) GetPassword() ([]byte, error) {
var keyPassword []byte
key_path, err := filepath.Abs(v.key)
if err != nil {
return nil, fmt.Errorf("Error finding private key: %s", err)
}
keyPassword, err = keychain.GetGenericPassword("SSH", key_path, "", "")
if err == nil {
return keyPassword, nil
}
// Darn, Keychain doesn't have the password for that file. Prompt the user.
keyPassword, err = v.GetPasswordPrompt()
if err != nil {
return nil, err
}
return keyPassword, nil
}

View File

@ -0,0 +1,97 @@
// +build darwin
package sshvault
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"syscall"
"testing"
"github.com/keybase/go-keychain"
"github.com/kr/pty"
)
func InjectKeychainPassword(path, pw string) error {
item := keychain.NewItem()
item.SetSecClass(keychain.SecClassGenericPassword)
item.SetLabel(fmt.Sprintf("SSH: %s", path))
item.SetService("SSH")
item.SetAccount(path)
item.SetData([]byte(pw))
item.SetSynchronizable(keychain.SynchronizableNo)
return keychain.AddItem(item)
}
func DeleteKeychainPassword(path string) error {
item := keychain.NewItem()
item.SetSecClass(keychain.SecClassGenericPassword)
item.SetService("SSH")
item.SetAccount(path)
return keychain.DeleteItem(item)
}
func TestKeychain(t *testing.T) {
key_pw := "argle-bargle"
key_bad_pw := "totally-bogus\n"
dir, err := ioutil.TempDir("", "vault")
if err != nil {
t.Error(err)
}
defer os.RemoveAll(dir) // clean up
tmpfile := filepath.Join(dir, "vault")
vault, err := New("test_data/id_rsa.pub", "", "create", tmpfile)
if err != nil {
t.Error(err)
}
key_path, err := filepath.Abs(vault.key)
if err != nil {
t.Errorf("Error finding private key: %s", err)
}
err = InjectKeychainPassword(key_path, key_pw)
if err != nil {
t.Errorf("Error setting up keychain for testing: %s", err)
}
defer DeleteKeychainPassword(key_path) // clean up
_, tty, err := pty.Open()
if err != nil {
t.Errorf("Unable to open pty: %s", err)
}
// File Descriptor magic. GetPasswordPrompt() reads the password
// from stdin. For the test, we save stdin to a spare FD,
// point stdin at the file, run the system under test, and
// finally restore the original stdin
old_stdin, _ := syscall.Dup(int(syscall.Stdin))
old_stdout, _ := syscall.Dup(int(syscall.Stdout))
syscall.Dup2(int(tty.Fd()), int(syscall.Stdin))
syscall.Dup2(int(tty.Fd()), int(syscall.Stdout))
go PtyWriteback(pty, key_bad_pw)
key_pw_test, err := vault.GetPassword()
syscall.Dup2(old_stdin, int(syscall.Stdin))
syscall.Dup2(old_stdout, int(syscall.Stdout))
if err != nil {
t.Error(err)
}
if strings.Trim(string(key_pw_test), "\n") == strings.Trim(key_bad_pw, "\n") {
t.Errorf("PTY-based password prompt used, not keychain!")
}
if strings.Trim(string(key_pw_test), "\n") != strings.Trim(key_pw, "\n") {
t.Errorf("keychain error: %s expected %s, got %s\n", key_path, key_pw, key_pw_test)
}
}

18
get_password_prompt.go Normal file
View File

@ -0,0 +1,18 @@
package sshvault
import (
"fmt"
"syscall"
"golang.org/x/crypto/ssh/terminal"
)
func (v *vault) GetPasswordPrompt() ([]byte, error) {
fmt.Printf("Enter key password (%s): ", v.key)
keyPassword, err := terminal.ReadPassword(int(syscall.Stdin))
if err != nil {
return nil, err
}
return keyPassword, nil
}

View File

@ -8,12 +8,23 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"strings"
"syscall"
"testing"
"time"
"github.com/kr/pty"
"github.com/ssh-vault/crypto"
"github.com/ssh-vault/crypto/aead"
)
// zomg this is a race condition
func PtyWriteback(pty *os.File, msg string) {
time.Sleep(500 * time.Millisecond)
defer pty.Sync()
pty.Write([]byte(msg))
}
// These are done in one function to avoid declaring global variables in a test
// file.
func TestVaultFunctions(t *testing.T) {
@ -30,6 +41,35 @@ func TestVaultFunctions(t *testing.T) {
t.Error(err)
}
key_pw := string("argle-bargle\n")
pty, tty, err := pty.Open()
if err != nil {
t.Errorf("Unable to open pty: %s", err)
}
// File Descriptor magic. GetPasswordPrompt() reads the password
// from stdin. For the test, we save stdin to a spare FD,
// point stdin at the file, run the system under test, and
// finally restore the original stdin
old_stdin, _ := syscall.Dup(int(syscall.Stdin))
old_stdout, _ := syscall.Dup(int(syscall.Stdout))
syscall.Dup2(int(tty.Fd()), int(syscall.Stdin))
syscall.Dup2(int(tty.Fd()), int(syscall.Stdout))
go PtyWriteback(pty, key_pw)
key_pw_test, err := vault.GetPasswordPrompt()
syscall.Dup2(old_stdin, int(syscall.Stdin))
syscall.Dup2(old_stdout, int(syscall.Stdout))
if err != nil {
t.Error(err)
}
if string(strings.Trim(key_pw, "\n")) != string(key_pw_test) {
t.Errorf("password prompt: expected %s, got %s\n", key_pw, key_pw_test)
}
if err = vault.PKCS8(); err != nil {
t.Error(err)
}

10
view.go
View File

@ -10,12 +10,9 @@ import (
"io/ioutil"
"os"
"strings"
"syscall"
"github.com/ssh-vault/crypto/aead"
"github.com/ssh-vault/crypto/oaep"
"golang.org/x/crypto/ssh/terminal"
)
// View decrypts data and print it to stdout
@ -64,12 +61,11 @@ func (v *vault) View() ([]byte, error) {
}
if x509.IsEncryptedPEMBlock(block) {
fmt.Print("Enter key password: ")
keyPassword, err := terminal.ReadPassword(int(syscall.Stdin))
keyPassword, err := v.GetPassword()
if err != nil {
return nil, err
return nil, fmt.Errorf("Unable to get private key password, Decryption failed.")
}
fmt.Println()
block.Bytes, err = x509.DecryptPEMBlock(block, keyPassword)
if err != nil {
return nil, fmt.Errorf("Password incorrect, Decryption failed.")