Added support for setting a name in SSH key pair.

Also refactored how new SSH key pairs are created, and how the
tests are structured.
This commit is contained in:
Stephen Fox 2019-02-04 12:07:32 -05:00
parent 5893134c61
commit b1b67ecffa
2 changed files with 203 additions and 90 deletions

View File

@ -13,6 +13,7 @@ import (
"encoding/pem"
"errors"
"strconv"
"strings"
"golang.org/x/crypto/ssh"
)
@ -67,6 +68,10 @@ type sshKeyPairBuilder interface {
// SetBits sets the key pair's bits of entropy.
SetBits(int) sshKeyPairBuilder
// SetName sets the name of the key pair. This is primarily used
// to identify the public key in the authorized_keys file.
SetName(string) sshKeyPairBuilder
// Build returns a SSH key pair.
//
// The following defaults are used if not specified:
@ -74,6 +79,7 @@ type sshKeyPairBuilder interface {
// Default bits of entropy:
// - RSA: 4096
// - ECDSA: 521
// Default name: (empty string)
Build() (sshKeyPair, error)
}
@ -83,6 +89,9 @@ type defaultSshKeyPairBuilder struct {
// bits is the resulting key pair's bits of entropy.
bits int
// name is the resulting key pair's name.
name string
}
func (o *defaultSshKeyPairBuilder) SetType(kind sshKeyPairType) sshKeyPairBuilder {
@ -95,15 +104,20 @@ func (o *defaultSshKeyPairBuilder) SetBits(bits int) sshKeyPairBuilder {
return o
}
func (o *defaultSshKeyPairBuilder) SetName(name string) sshKeyPairBuilder {
o.name = name
return o
}
func (o *defaultSshKeyPairBuilder) Build() (sshKeyPair, error) {
switch o.kind {
case rsaSsh:
return newRsaSshKeyPair(o.bits)
return o.newRsaSshKeyPair()
case ecdsaSsh:
// Default case.
}
return newEcdsaSshKeyPair(o.bits)
return o.newEcdsaSshKeyPair()
}
// sshKeyPair represents a SSH key pair.
@ -114,6 +128,10 @@ type sshKeyPair interface {
// Bits returns the bits of entropy.
Bits() int
// Name returns the key pair's name. An empty string is
// returned is no name was specified.
Name() string
// Description returns a brief description of the key pair that
// is suitable for log messages or printing.
Description() string
@ -136,6 +154,9 @@ type defaultSshKeyPair struct {
// bits is the key pair's bits of entropy.
bits int
// name is the key pair's name.
name string
// privateKeyDerBytes is the private key's bytes
// in ASN.1 DER format
privateKeyDerBytes []byte
@ -152,6 +173,10 @@ func (o defaultSshKeyPair) Bits() int {
return o.bits
}
func (o defaultSshKeyPair) Name() string {
return o.name
}
func (o defaultSshKeyPair) Description() string {
return o.kind.String() + " " + strconv.Itoa(o.bits)
}
@ -176,6 +201,14 @@ func (o defaultSshKeyPair) PrivateKeyPemBlock() []byte {
func (o defaultSshKeyPair) PublicKeyAuthorizedKeysFormat(nl newLineOption) []byte {
result := ssh.MarshalAuthorizedKey(o.publicKey)
if len(strings.TrimSpace(o.name)) > 0 {
// Awful, but the go ssh library automatically appends
// a unix new line.
result = bytes.TrimSuffix(result, unixNewLine.Bytes())
result = append(result, ' ')
result = append(result, o.name...)
}
switch nl {
case noNewLine:
result = bytes.TrimSuffix(result, unixNewLine.Bytes())
@ -197,12 +230,12 @@ func (o defaultSshKeyPair) PublicKeyAuthorizedKeysFormat(nl newLineOption) []byt
// newEcdsaSshKeyPair returns a new ECDSA SSH key pair for the given bits
// of entropy.
func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) {
func (o *defaultSshKeyPairBuilder) newEcdsaSshKeyPair() (sshKeyPair, error) {
var curve elliptic.Curve
switch bits {
switch o.bits {
case 0:
bits = 521
o.bits = 521
fallthrough
case 521:
curve = elliptic.P521()
@ -213,10 +246,10 @@ func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) {
case 224:
// Not supported by "golang.org/x/crypto/ssh".
return &defaultSshKeyPair{}, errors.New("golang.org/x/crypto/ssh does not support " +
strconv.Itoa(bits) + " bits")
strconv.Itoa(o.bits) + " bits")
default:
return &defaultSshKeyPair{}, errors.New("crypto/elliptic does not support " +
strconv.Itoa(bits) + " bits")
strconv.Itoa(o.bits) + " bits")
}
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
@ -236,7 +269,8 @@ func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) {
return &defaultSshKeyPair{
kind: ecdsaSsh,
bits: bits,
bits: o.bits,
name: o.name,
privateKeyDerBytes: raw,
publicKey: sshPublicKey,
}, nil
@ -244,12 +278,12 @@ func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) {
// newRsaSshKeyPair returns a new RSA SSH key pair for the given bits
// of entropy.
func newRsaSshKeyPair(bits int) (sshKeyPair, error) {
if bits == 0 {
bits = defaultRsaBits
func (o *defaultSshKeyPairBuilder) newRsaSshKeyPair() (sshKeyPair, error) {
if o.bits == 0 {
o.bits = defaultRsaBits
}
privateKey, err := rsa.GenerateKey(rand.Reader, bits)
privateKey, err := rsa.GenerateKey(rand.Reader, o.bits)
if err != nil {
return &defaultSshKeyPair{}, err
}
@ -261,7 +295,8 @@ func newRsaSshKeyPair(bits int) (sshKeyPair, error) {
return &defaultSshKeyPair{
kind: rsaSsh,
bits: bits,
bits: o.bits,
name: o.name,
privateKeyDerBytes: x509.MarshalPKCS1PrivateKey(privateKey),
publicKey: sshPublicKey,
}, nil

View File

@ -1,11 +1,13 @@
package common
import (
"bytes"
"crypto/rand"
"errors"
"strconv"
"testing"
"github.com/hashicorp/packer/common/uuid"
"golang.org/x/crypto/ssh"
)
@ -14,6 +16,8 @@ type expected struct {
kind sshKeyPairType
bits int
desc string
name string
data []byte
}
func (o expected) matches(kp sshKeyPair) error {
@ -29,28 +33,37 @@ func (o expected) matches(kp sshKeyPair) error {
return errors.New("expected description's value cannot be empty")
}
if len(o.data) == 0 {
return errors.New("expected random data value cannot be nothing")
}
if kp.Type() != o.kind {
return errors.New("expected key pair type to be " +
o.kind.String() + " - got '" + kp.Type().String() + "'")
return errors.New("key pair type should be " + o.kind.String() +
" - got '" + kp.Type().String() + "'")
}
if kp.Bits() != o.bits {
return errors.New("expected key pair to be " +
strconv.Itoa(o.bits) + " bits - got " + strconv.Itoa(kp.Bits()))
return errors.New("key pair bits should be " + strconv.Itoa(o.bits) +
" - got " + strconv.Itoa(kp.Bits()))
}
if len(o.name) > 0 && kp.Name() != o.name {
return errors.New("key pair name should be '" + o.name +
"' - got '" + kp.Name() + "'")
}
expDescription := kp.Type().String() + " " + strconv.Itoa(o.bits)
if kp.Description() != expDescription {
return errors.New("expected key pair description to be '" +
return errors.New("key pair description should be '" +
expDescription + "' - got '" + kp.Description() + "'")
}
err := verifyPublickeyAuthorizedKeysFormat(kp)
err := o.verifyPublicKeyAuthorizedKeysFormat(kp)
if err != nil {
return err
}
err = verifySshKeyPair(kp)
err = o.verifySshKeyPair(kp)
if err != nil {
return err
}
@ -58,76 +71,7 @@ func (o expected) matches(kp sshKeyPair) error {
return nil
}
func TestDefaultSshKeyPairBuilder_Build_Default(t *testing.T) {
kp, err := newSshKeyPairBuilder().Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: ecdsaSsh,
bits: 521,
desc: "ecdsa 521",
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_EcdsaDefault(t *testing.T) {
kp, err := newSshKeyPairBuilder().SetType(ecdsaSsh).Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: ecdsaSsh,
bits: 521,
desc: "ecdsa 521",
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_RsaDefault(t *testing.T) {
kp, err := newSshKeyPairBuilder().SetType(rsaSsh).Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: rsaSsh,
bits: 4096,
desc: "rsa 4096",
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func verifySshKeyPair(kp sshKeyPair) error {
signer, err := ssh.ParsePrivateKey(kp.PrivateKeyPemBlock())
if err != nil {
return errors.New("failed to parse private key during verification - " + err.Error())
}
data := []byte{'b', 'r', '4', 'n', '3'}
signature, err := signer.Sign(rand.Reader, data)
if err != nil {
return errors.New("failed to sign test data during verification - " + err.Error())
}
err = signer.PublicKey().Verify(data, signature)
if err != nil {
return errors.New("failed to verify test data - " + err.Error())
}
return nil
}
func verifyPublickeyAuthorizedKeysFormat(kp sshKeyPair) error {
func (o expected) verifyPublicKeyAuthorizedKeysFormat(kp sshKeyPair) error {
newLines := []newLineOption{
unixNewLine,
noNewLine,
@ -158,7 +102,141 @@ func verifyPublickeyAuthorizedKeysFormat(kp sshKeyPair) error {
return errors.New("public key in authorized keys format does not have windows new line when windows was specified")
}
}
if len(o.name) > 0 {
if len(publicKeyAk) < len(o.name) {
return errors.New("public key in authorized keys format is shorter than the key pair's name")
}
suffix := []byte{' '}
suffix = append(suffix, o.name...)
suffix = append(suffix, nl.Bytes()...)
if !bytes.HasSuffix(publicKeyAk, suffix) {
return errors.New("public key in authorized keys format with name does not have name in suffix - got '" +
string(publicKeyAk) + "'")
}
}
}
return nil
}
func (o expected) verifySshKeyPair(kp sshKeyPair) error {
signer, err := ssh.ParsePrivateKey(kp.PrivateKeyPemBlock())
if err != nil {
return errors.New("failed to parse private key during verification - " + err.Error())
}
signature, err := signer.Sign(rand.Reader, o.data)
if err != nil {
return errors.New("failed to sign test data during verification - " + err.Error())
}
err = signer.PublicKey().Verify(o.data, signature)
if err != nil {
return errors.New("failed to verify test data - " + err.Error())
}
return nil
}
func TestDefaultSshKeyPairBuilder_Build_Default(t *testing.T) {
kp, err := newSshKeyPairBuilder().Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: ecdsaSsh,
bits: 521,
desc: "ecdsa 521",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_EcdsaDefault(t *testing.T) {
kp, err := newSshKeyPairBuilder().
SetType(ecdsaSsh).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: ecdsaSsh,
bits: 521,
desc: "ecdsa 521",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_RsaDefault(t *testing.T) {
kp, err := newSshKeyPairBuilder().
SetType(rsaSsh).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: rsaSsh,
bits: 4096,
desc: "rsa 4096",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_NamedEcdsa(t *testing.T) {
name := uuid.TimeOrderedUUID()
kp, err := newSshKeyPairBuilder().
SetType(ecdsaSsh).
SetName(name).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: ecdsaSsh,
bits: 521,
desc: "ecdsa 521",
data: []byte(uuid.TimeOrderedUUID()),
name: name,
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_NamedRsa(t *testing.T) {
name := uuid.TimeOrderedUUID()
kp, err := newSshKeyPairBuilder().
SetType(rsaSsh).
SetName(name).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: rsaSsh,
bits: 4096,
desc: "rsa 4096",
data: []byte(uuid.TimeOrderedUUID()),
name: name,
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}