Added support for setting a name in SSH key pair.

Also refactored how new SSH key pairs are created, and how the
tests are structured.
This commit is contained in:
Stephen Fox 2019-02-04 12:07:32 -05:00
parent 5893134c61
commit b1b67ecffa
2 changed files with 203 additions and 90 deletions

View File

@ -13,6 +13,7 @@ import (
"encoding/pem" "encoding/pem"
"errors" "errors"
"strconv" "strconv"
"strings"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -67,6 +68,10 @@ type sshKeyPairBuilder interface {
// SetBits sets the key pair's bits of entropy. // SetBits sets the key pair's bits of entropy.
SetBits(int) sshKeyPairBuilder SetBits(int) sshKeyPairBuilder
// SetName sets the name of the key pair. This is primarily used
// to identify the public key in the authorized_keys file.
SetName(string) sshKeyPairBuilder
// Build returns a SSH key pair. // Build returns a SSH key pair.
// //
// The following defaults are used if not specified: // The following defaults are used if not specified:
@ -74,6 +79,7 @@ type sshKeyPairBuilder interface {
// Default bits of entropy: // Default bits of entropy:
// - RSA: 4096 // - RSA: 4096
// - ECDSA: 521 // - ECDSA: 521
// Default name: (empty string)
Build() (sshKeyPair, error) Build() (sshKeyPair, error)
} }
@ -83,6 +89,9 @@ type defaultSshKeyPairBuilder struct {
// bits is the resulting key pair's bits of entropy. // bits is the resulting key pair's bits of entropy.
bits int bits int
// name is the resulting key pair's name.
name string
} }
func (o *defaultSshKeyPairBuilder) SetType(kind sshKeyPairType) sshKeyPairBuilder { func (o *defaultSshKeyPairBuilder) SetType(kind sshKeyPairType) sshKeyPairBuilder {
@ -95,15 +104,20 @@ func (o *defaultSshKeyPairBuilder) SetBits(bits int) sshKeyPairBuilder {
return o return o
} }
func (o *defaultSshKeyPairBuilder) SetName(name string) sshKeyPairBuilder {
o.name = name
return o
}
func (o *defaultSshKeyPairBuilder) Build() (sshKeyPair, error) { func (o *defaultSshKeyPairBuilder) Build() (sshKeyPair, error) {
switch o.kind { switch o.kind {
case rsaSsh: case rsaSsh:
return newRsaSshKeyPair(o.bits) return o.newRsaSshKeyPair()
case ecdsaSsh: case ecdsaSsh:
// Default case. // Default case.
} }
return newEcdsaSshKeyPair(o.bits) return o.newEcdsaSshKeyPair()
} }
// sshKeyPair represents a SSH key pair. // sshKeyPair represents a SSH key pair.
@ -114,6 +128,10 @@ type sshKeyPair interface {
// Bits returns the bits of entropy. // Bits returns the bits of entropy.
Bits() int Bits() int
// Name returns the key pair's name. An empty string is
// returned is no name was specified.
Name() string
// Description returns a brief description of the key pair that // Description returns a brief description of the key pair that
// is suitable for log messages or printing. // is suitable for log messages or printing.
Description() string Description() string
@ -136,6 +154,9 @@ type defaultSshKeyPair struct {
// bits is the key pair's bits of entropy. // bits is the key pair's bits of entropy.
bits int bits int
// name is the key pair's name.
name string
// privateKeyDerBytes is the private key's bytes // privateKeyDerBytes is the private key's bytes
// in ASN.1 DER format // in ASN.1 DER format
privateKeyDerBytes []byte privateKeyDerBytes []byte
@ -152,6 +173,10 @@ func (o defaultSshKeyPair) Bits() int {
return o.bits return o.bits
} }
func (o defaultSshKeyPair) Name() string {
return o.name
}
func (o defaultSshKeyPair) Description() string { func (o defaultSshKeyPair) Description() string {
return o.kind.String() + " " + strconv.Itoa(o.bits) return o.kind.String() + " " + strconv.Itoa(o.bits)
} }
@ -176,6 +201,14 @@ func (o defaultSshKeyPair) PrivateKeyPemBlock() []byte {
func (o defaultSshKeyPair) PublicKeyAuthorizedKeysFormat(nl newLineOption) []byte { func (o defaultSshKeyPair) PublicKeyAuthorizedKeysFormat(nl newLineOption) []byte {
result := ssh.MarshalAuthorizedKey(o.publicKey) result := ssh.MarshalAuthorizedKey(o.publicKey)
if len(strings.TrimSpace(o.name)) > 0 {
// Awful, but the go ssh library automatically appends
// a unix new line.
result = bytes.TrimSuffix(result, unixNewLine.Bytes())
result = append(result, ' ')
result = append(result, o.name...)
}
switch nl { switch nl {
case noNewLine: case noNewLine:
result = bytes.TrimSuffix(result, unixNewLine.Bytes()) result = bytes.TrimSuffix(result, unixNewLine.Bytes())
@ -197,12 +230,12 @@ func (o defaultSshKeyPair) PublicKeyAuthorizedKeysFormat(nl newLineOption) []byt
// newEcdsaSshKeyPair returns a new ECDSA SSH key pair for the given bits // newEcdsaSshKeyPair returns a new ECDSA SSH key pair for the given bits
// of entropy. // of entropy.
func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) { func (o *defaultSshKeyPairBuilder) newEcdsaSshKeyPair() (sshKeyPair, error) {
var curve elliptic.Curve var curve elliptic.Curve
switch bits { switch o.bits {
case 0: case 0:
bits = 521 o.bits = 521
fallthrough fallthrough
case 521: case 521:
curve = elliptic.P521() curve = elliptic.P521()
@ -213,10 +246,10 @@ func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) {
case 224: case 224:
// Not supported by "golang.org/x/crypto/ssh". // Not supported by "golang.org/x/crypto/ssh".
return &defaultSshKeyPair{}, errors.New("golang.org/x/crypto/ssh does not support " + return &defaultSshKeyPair{}, errors.New("golang.org/x/crypto/ssh does not support " +
strconv.Itoa(bits) + " bits") strconv.Itoa(o.bits) + " bits")
default: default:
return &defaultSshKeyPair{}, errors.New("crypto/elliptic does not support " + return &defaultSshKeyPair{}, errors.New("crypto/elliptic does not support " +
strconv.Itoa(bits) + " bits") strconv.Itoa(o.bits) + " bits")
} }
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader) privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
@ -236,7 +269,8 @@ func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) {
return &defaultSshKeyPair{ return &defaultSshKeyPair{
kind: ecdsaSsh, kind: ecdsaSsh,
bits: bits, bits: o.bits,
name: o.name,
privateKeyDerBytes: raw, privateKeyDerBytes: raw,
publicKey: sshPublicKey, publicKey: sshPublicKey,
}, nil }, nil
@ -244,12 +278,12 @@ func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) {
// newRsaSshKeyPair returns a new RSA SSH key pair for the given bits // newRsaSshKeyPair returns a new RSA SSH key pair for the given bits
// of entropy. // of entropy.
func newRsaSshKeyPair(bits int) (sshKeyPair, error) { func (o *defaultSshKeyPairBuilder) newRsaSshKeyPair() (sshKeyPair, error) {
if bits == 0 { if o.bits == 0 {
bits = defaultRsaBits o.bits = defaultRsaBits
} }
privateKey, err := rsa.GenerateKey(rand.Reader, bits) privateKey, err := rsa.GenerateKey(rand.Reader, o.bits)
if err != nil { if err != nil {
return &defaultSshKeyPair{}, err return &defaultSshKeyPair{}, err
} }
@ -261,7 +295,8 @@ func newRsaSshKeyPair(bits int) (sshKeyPair, error) {
return &defaultSshKeyPair{ return &defaultSshKeyPair{
kind: rsaSsh, kind: rsaSsh,
bits: bits, bits: o.bits,
name: o.name,
privateKeyDerBytes: x509.MarshalPKCS1PrivateKey(privateKey), privateKeyDerBytes: x509.MarshalPKCS1PrivateKey(privateKey),
publicKey: sshPublicKey, publicKey: sshPublicKey,
}, nil }, nil

View File

@ -1,11 +1,13 @@
package common package common
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"errors" "errors"
"strconv" "strconv"
"testing" "testing"
"github.com/hashicorp/packer/common/uuid"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -14,6 +16,8 @@ type expected struct {
kind sshKeyPairType kind sshKeyPairType
bits int bits int
desc string desc string
name string
data []byte
} }
func (o expected) matches(kp sshKeyPair) error { func (o expected) matches(kp sshKeyPair) error {
@ -29,28 +33,37 @@ func (o expected) matches(kp sshKeyPair) error {
return errors.New("expected description's value cannot be empty") return errors.New("expected description's value cannot be empty")
} }
if len(o.data) == 0 {
return errors.New("expected random data value cannot be nothing")
}
if kp.Type() != o.kind { if kp.Type() != o.kind {
return errors.New("expected key pair type to be " + return errors.New("key pair type should be " + o.kind.String() +
o.kind.String() + " - got '" + kp.Type().String() + "'") " - got '" + kp.Type().String() + "'")
} }
if kp.Bits() != o.bits { if kp.Bits() != o.bits {
return errors.New("expected key pair to be " + return errors.New("key pair bits should be " + strconv.Itoa(o.bits) +
strconv.Itoa(o.bits) + " bits - got " + strconv.Itoa(kp.Bits())) " - got " + strconv.Itoa(kp.Bits()))
}
if len(o.name) > 0 && kp.Name() != o.name {
return errors.New("key pair name should be '" + o.name +
"' - got '" + kp.Name() + "'")
} }
expDescription := kp.Type().String() + " " + strconv.Itoa(o.bits) expDescription := kp.Type().String() + " " + strconv.Itoa(o.bits)
if kp.Description() != expDescription { if kp.Description() != expDescription {
return errors.New("expected key pair description to be '" + return errors.New("key pair description should be '" +
expDescription + "' - got '" + kp.Description() + "'") expDescription + "' - got '" + kp.Description() + "'")
} }
err := verifyPublickeyAuthorizedKeysFormat(kp) err := o.verifyPublicKeyAuthorizedKeysFormat(kp)
if err != nil { if err != nil {
return err return err
} }
err = verifySshKeyPair(kp) err = o.verifySshKeyPair(kp)
if err != nil { if err != nil {
return err return err
} }
@ -58,76 +71,7 @@ func (o expected) matches(kp sshKeyPair) error {
return nil return nil
} }
func TestDefaultSshKeyPairBuilder_Build_Default(t *testing.T) { func (o expected) verifyPublicKeyAuthorizedKeysFormat(kp sshKeyPair) error {
kp, err := newSshKeyPairBuilder().Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: ecdsaSsh,
bits: 521,
desc: "ecdsa 521",
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_EcdsaDefault(t *testing.T) {
kp, err := newSshKeyPairBuilder().SetType(ecdsaSsh).Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: ecdsaSsh,
bits: 521,
desc: "ecdsa 521",
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_RsaDefault(t *testing.T) {
kp, err := newSshKeyPairBuilder().SetType(rsaSsh).Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: rsaSsh,
bits: 4096,
desc: "rsa 4096",
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func verifySshKeyPair(kp sshKeyPair) error {
signer, err := ssh.ParsePrivateKey(kp.PrivateKeyPemBlock())
if err != nil {
return errors.New("failed to parse private key during verification - " + err.Error())
}
data := []byte{'b', 'r', '4', 'n', '3'}
signature, err := signer.Sign(rand.Reader, data)
if err != nil {
return errors.New("failed to sign test data during verification - " + err.Error())
}
err = signer.PublicKey().Verify(data, signature)
if err != nil {
return errors.New("failed to verify test data - " + err.Error())
}
return nil
}
func verifyPublickeyAuthorizedKeysFormat(kp sshKeyPair) error {
newLines := []newLineOption{ newLines := []newLineOption{
unixNewLine, unixNewLine,
noNewLine, noNewLine,
@ -158,7 +102,141 @@ func verifyPublickeyAuthorizedKeysFormat(kp sshKeyPair) error {
return errors.New("public key in authorized keys format does not have windows new line when windows was specified") return errors.New("public key in authorized keys format does not have windows new line when windows was specified")
} }
} }
if len(o.name) > 0 {
if len(publicKeyAk) < len(o.name) {
return errors.New("public key in authorized keys format is shorter than the key pair's name")
}
suffix := []byte{' '}
suffix = append(suffix, o.name...)
suffix = append(suffix, nl.Bytes()...)
if !bytes.HasSuffix(publicKeyAk, suffix) {
return errors.New("public key in authorized keys format with name does not have name in suffix - got '" +
string(publicKeyAk) + "'")
}
}
} }
return nil return nil
} }
func (o expected) verifySshKeyPair(kp sshKeyPair) error {
signer, err := ssh.ParsePrivateKey(kp.PrivateKeyPemBlock())
if err != nil {
return errors.New("failed to parse private key during verification - " + err.Error())
}
signature, err := signer.Sign(rand.Reader, o.data)
if err != nil {
return errors.New("failed to sign test data during verification - " + err.Error())
}
err = signer.PublicKey().Verify(o.data, signature)
if err != nil {
return errors.New("failed to verify test data - " + err.Error())
}
return nil
}
func TestDefaultSshKeyPairBuilder_Build_Default(t *testing.T) {
kp, err := newSshKeyPairBuilder().Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: ecdsaSsh,
bits: 521,
desc: "ecdsa 521",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_EcdsaDefault(t *testing.T) {
kp, err := newSshKeyPairBuilder().
SetType(ecdsaSsh).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: ecdsaSsh,
bits: 521,
desc: "ecdsa 521",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_RsaDefault(t *testing.T) {
kp, err := newSshKeyPairBuilder().
SetType(rsaSsh).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: rsaSsh,
bits: 4096,
desc: "rsa 4096",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_NamedEcdsa(t *testing.T) {
name := uuid.TimeOrderedUUID()
kp, err := newSshKeyPairBuilder().
SetType(ecdsaSsh).
SetName(name).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: ecdsaSsh,
bits: 521,
desc: "ecdsa 521",
data: []byte(uuid.TimeOrderedUUID()),
name: name,
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultSshKeyPairBuilder_Build_NamedRsa(t *testing.T) {
name := uuid.TimeOrderedUUID()
kp, err := newSshKeyPairBuilder().
SetType(rsaSsh).
SetName(name).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: rsaSsh,
bits: 4096,
desc: "rsa 4096",
data: []byte(uuid.TimeOrderedUUID()),
name: name,
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}