diff --git a/builder/virtualbox/common/sshkeypair.go b/builder/virtualbox/common/sshkeypair.go index fe5ba449c..be4f26d30 100644 --- a/builder/virtualbox/common/sshkeypair.go +++ b/builder/virtualbox/common/sshkeypair.go @@ -13,6 +13,7 @@ import ( "encoding/pem" "errors" "strconv" + "strings" "golang.org/x/crypto/ssh" ) @@ -67,6 +68,10 @@ type sshKeyPairBuilder interface { // SetBits sets the key pair's bits of entropy. SetBits(int) sshKeyPairBuilder + // SetName sets the name of the key pair. This is primarily used + // to identify the public key in the authorized_keys file. + SetName(string) sshKeyPairBuilder + // Build returns a SSH key pair. // // The following defaults are used if not specified: @@ -74,6 +79,7 @@ type sshKeyPairBuilder interface { // Default bits of entropy: // - RSA: 4096 // - ECDSA: 521 + // Default name: (empty string) Build() (sshKeyPair, error) } @@ -83,6 +89,9 @@ type defaultSshKeyPairBuilder struct { // bits is the resulting key pair's bits of entropy. bits int + + // name is the resulting key pair's name. + name string } func (o *defaultSshKeyPairBuilder) SetType(kind sshKeyPairType) sshKeyPairBuilder { @@ -95,15 +104,20 @@ func (o *defaultSshKeyPairBuilder) SetBits(bits int) sshKeyPairBuilder { return o } +func (o *defaultSshKeyPairBuilder) SetName(name string) sshKeyPairBuilder { + o.name = name + return o +} + func (o *defaultSshKeyPairBuilder) Build() (sshKeyPair, error) { switch o.kind { case rsaSsh: - return newRsaSshKeyPair(o.bits) + return o.newRsaSshKeyPair() case ecdsaSsh: // Default case. } - return newEcdsaSshKeyPair(o.bits) + return o.newEcdsaSshKeyPair() } // sshKeyPair represents a SSH key pair. @@ -114,6 +128,10 @@ type sshKeyPair interface { // Bits returns the bits of entropy. Bits() int + // Name returns the key pair's name. An empty string is + // returned is no name was specified. + Name() string + // Description returns a brief description of the key pair that // is suitable for log messages or printing. Description() string @@ -136,6 +154,9 @@ type defaultSshKeyPair struct { // bits is the key pair's bits of entropy. bits int + // name is the key pair's name. + name string + // privateKeyDerBytes is the private key's bytes // in ASN.1 DER format privateKeyDerBytes []byte @@ -152,6 +173,10 @@ func (o defaultSshKeyPair) Bits() int { return o.bits } +func (o defaultSshKeyPair) Name() string { + return o.name +} + func (o defaultSshKeyPair) Description() string { return o.kind.String() + " " + strconv.Itoa(o.bits) } @@ -176,6 +201,14 @@ func (o defaultSshKeyPair) PrivateKeyPemBlock() []byte { func (o defaultSshKeyPair) PublicKeyAuthorizedKeysFormat(nl newLineOption) []byte { result := ssh.MarshalAuthorizedKey(o.publicKey) + if len(strings.TrimSpace(o.name)) > 0 { + // Awful, but the go ssh library automatically appends + // a unix new line. + result = bytes.TrimSuffix(result, unixNewLine.Bytes()) + result = append(result, ' ') + result = append(result, o.name...) + } + switch nl { case noNewLine: result = bytes.TrimSuffix(result, unixNewLine.Bytes()) @@ -197,12 +230,12 @@ func (o defaultSshKeyPair) PublicKeyAuthorizedKeysFormat(nl newLineOption) []byt // newEcdsaSshKeyPair returns a new ECDSA SSH key pair for the given bits // of entropy. -func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) { +func (o *defaultSshKeyPairBuilder) newEcdsaSshKeyPair() (sshKeyPair, error) { var curve elliptic.Curve - switch bits { + switch o.bits { case 0: - bits = 521 + o.bits = 521 fallthrough case 521: curve = elliptic.P521() @@ -213,10 +246,10 @@ func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) { case 224: // Not supported by "golang.org/x/crypto/ssh". return &defaultSshKeyPair{}, errors.New("golang.org/x/crypto/ssh does not support " + - strconv.Itoa(bits) + " bits") + strconv.Itoa(o.bits) + " bits") default: return &defaultSshKeyPair{}, errors.New("crypto/elliptic does not support " + - strconv.Itoa(bits) + " bits") + strconv.Itoa(o.bits) + " bits") } privateKey, err := ecdsa.GenerateKey(curve, rand.Reader) @@ -236,7 +269,8 @@ func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) { return &defaultSshKeyPair{ kind: ecdsaSsh, - bits: bits, + bits: o.bits, + name: o.name, privateKeyDerBytes: raw, publicKey: sshPublicKey, }, nil @@ -244,12 +278,12 @@ func newEcdsaSshKeyPair(bits int) (sshKeyPair, error) { // newRsaSshKeyPair returns a new RSA SSH key pair for the given bits // of entropy. -func newRsaSshKeyPair(bits int) (sshKeyPair, error) { - if bits == 0 { - bits = defaultRsaBits +func (o *defaultSshKeyPairBuilder) newRsaSshKeyPair() (sshKeyPair, error) { + if o.bits == 0 { + o.bits = defaultRsaBits } - privateKey, err := rsa.GenerateKey(rand.Reader, bits) + privateKey, err := rsa.GenerateKey(rand.Reader, o.bits) if err != nil { return &defaultSshKeyPair{}, err } @@ -261,7 +295,8 @@ func newRsaSshKeyPair(bits int) (sshKeyPair, error) { return &defaultSshKeyPair{ kind: rsaSsh, - bits: bits, + bits: o.bits, + name: o.name, privateKeyDerBytes: x509.MarshalPKCS1PrivateKey(privateKey), publicKey: sshPublicKey, }, nil diff --git a/builder/virtualbox/common/sshkeypair_test.go b/builder/virtualbox/common/sshkeypair_test.go index e976b655e..a9f7c8de7 100644 --- a/builder/virtualbox/common/sshkeypair_test.go +++ b/builder/virtualbox/common/sshkeypair_test.go @@ -1,11 +1,13 @@ package common import ( + "bytes" "crypto/rand" "errors" "strconv" "testing" + "github.com/hashicorp/packer/common/uuid" "golang.org/x/crypto/ssh" ) @@ -14,6 +16,8 @@ type expected struct { kind sshKeyPairType bits int desc string + name string + data []byte } func (o expected) matches(kp sshKeyPair) error { @@ -29,28 +33,37 @@ func (o expected) matches(kp sshKeyPair) error { return errors.New("expected description's value cannot be empty") } + if len(o.data) == 0 { + return errors.New("expected random data value cannot be nothing") + } + if kp.Type() != o.kind { - return errors.New("expected key pair type to be " + - o.kind.String() + " - got '" + kp.Type().String() + "'") + return errors.New("key pair type should be " + o.kind.String() + + " - got '" + kp.Type().String() + "'") } if kp.Bits() != o.bits { - return errors.New("expected key pair to be " + - strconv.Itoa(o.bits) + " bits - got " + strconv.Itoa(kp.Bits())) + return errors.New("key pair bits should be " + strconv.Itoa(o.bits) + + " - got " + strconv.Itoa(kp.Bits())) + } + + if len(o.name) > 0 && kp.Name() != o.name { + return errors.New("key pair name should be '" + o.name + + "' - got '" + kp.Name() + "'") } expDescription := kp.Type().String() + " " + strconv.Itoa(o.bits) if kp.Description() != expDescription { - return errors.New("expected key pair description to be '" + + return errors.New("key pair description should be '" + expDescription + "' - got '" + kp.Description() + "'") } - err := verifyPublickeyAuthorizedKeysFormat(kp) + err := o.verifyPublicKeyAuthorizedKeysFormat(kp) if err != nil { return err } - err = verifySshKeyPair(kp) + err = o.verifySshKeyPair(kp) if err != nil { return err } @@ -58,76 +71,7 @@ func (o expected) matches(kp sshKeyPair) error { return nil } -func TestDefaultSshKeyPairBuilder_Build_Default(t *testing.T) { - kp, err := newSshKeyPairBuilder().Build() - if err != nil { - t.Fatal(err.Error()) - } - - err = expected{ - kind: ecdsaSsh, - bits: 521, - desc: "ecdsa 521", - }.matches(kp) - if err != nil { - t.Fatal(err.Error()) - } -} - -func TestDefaultSshKeyPairBuilder_Build_EcdsaDefault(t *testing.T) { - kp, err := newSshKeyPairBuilder().SetType(ecdsaSsh).Build() - if err != nil { - t.Fatal(err.Error()) - } - - err = expected{ - kind: ecdsaSsh, - bits: 521, - desc: "ecdsa 521", - }.matches(kp) - if err != nil { - t.Fatal(err.Error()) - } -} - -func TestDefaultSshKeyPairBuilder_Build_RsaDefault(t *testing.T) { - kp, err := newSshKeyPairBuilder().SetType(rsaSsh).Build() - if err != nil { - t.Fatal(err.Error()) - } - - err = expected{ - kind: rsaSsh, - bits: 4096, - desc: "rsa 4096", - }.matches(kp) - if err != nil { - t.Fatal(err.Error()) - } -} - -func verifySshKeyPair(kp sshKeyPair) error { - signer, err := ssh.ParsePrivateKey(kp.PrivateKeyPemBlock()) - if err != nil { - return errors.New("failed to parse private key during verification - " + err.Error()) - } - - data := []byte{'b', 'r', '4', 'n', '3'} - - signature, err := signer.Sign(rand.Reader, data) - if err != nil { - return errors.New("failed to sign test data during verification - " + err.Error()) - } - - err = signer.PublicKey().Verify(data, signature) - if err != nil { - return errors.New("failed to verify test data - " + err.Error()) - } - - return nil -} - -func verifyPublickeyAuthorizedKeysFormat(kp sshKeyPair) error { +func (o expected) verifyPublicKeyAuthorizedKeysFormat(kp sshKeyPair) error { newLines := []newLineOption{ unixNewLine, noNewLine, @@ -158,7 +102,141 @@ func verifyPublickeyAuthorizedKeysFormat(kp sshKeyPair) error { return errors.New("public key in authorized keys format does not have windows new line when windows was specified") } } + + if len(o.name) > 0 { + if len(publicKeyAk) < len(o.name) { + return errors.New("public key in authorized keys format is shorter than the key pair's name") + } + + suffix := []byte{' '} + suffix = append(suffix, o.name...) + suffix = append(suffix, nl.Bytes()...) + if !bytes.HasSuffix(publicKeyAk, suffix) { + return errors.New("public key in authorized keys format with name does not have name in suffix - got '" + + string(publicKeyAk) + "'") + } + } } return nil } + +func (o expected) verifySshKeyPair(kp sshKeyPair) error { + signer, err := ssh.ParsePrivateKey(kp.PrivateKeyPemBlock()) + if err != nil { + return errors.New("failed to parse private key during verification - " + err.Error()) + } + + signature, err := signer.Sign(rand.Reader, o.data) + if err != nil { + return errors.New("failed to sign test data during verification - " + err.Error()) + } + + err = signer.PublicKey().Verify(o.data, signature) + if err != nil { + return errors.New("failed to verify test data - " + err.Error()) + } + + return nil +} + +func TestDefaultSshKeyPairBuilder_Build_Default(t *testing.T) { + kp, err := newSshKeyPairBuilder().Build() + if err != nil { + t.Fatal(err.Error()) + } + + err = expected{ + kind: ecdsaSsh, + bits: 521, + desc: "ecdsa 521", + data: []byte(uuid.TimeOrderedUUID()), + }.matches(kp) + if err != nil { + t.Fatal(err.Error()) + } +} + +func TestDefaultSshKeyPairBuilder_Build_EcdsaDefault(t *testing.T) { + kp, err := newSshKeyPairBuilder(). + SetType(ecdsaSsh). + Build() + if err != nil { + t.Fatal(err.Error()) + } + + err = expected{ + kind: ecdsaSsh, + bits: 521, + desc: "ecdsa 521", + data: []byte(uuid.TimeOrderedUUID()), + }.matches(kp) + if err != nil { + t.Fatal(err.Error()) + } +} + +func TestDefaultSshKeyPairBuilder_Build_RsaDefault(t *testing.T) { + kp, err := newSshKeyPairBuilder(). + SetType(rsaSsh). + Build() + if err != nil { + t.Fatal(err.Error()) + } + + err = expected{ + kind: rsaSsh, + bits: 4096, + desc: "rsa 4096", + data: []byte(uuid.TimeOrderedUUID()), + }.matches(kp) + if err != nil { + t.Fatal(err.Error()) + } +} + +func TestDefaultSshKeyPairBuilder_Build_NamedEcdsa(t *testing.T) { + name := uuid.TimeOrderedUUID() + + kp, err := newSshKeyPairBuilder(). + SetType(ecdsaSsh). + SetName(name). + Build() + if err != nil { + t.Fatal(err.Error()) + } + + err = expected{ + kind: ecdsaSsh, + bits: 521, + desc: "ecdsa 521", + data: []byte(uuid.TimeOrderedUUID()), + name: name, + }.matches(kp) + if err != nil { + t.Fatal(err.Error()) + } +} + +func TestDefaultSshKeyPairBuilder_Build_NamedRsa(t *testing.T) { + name := uuid.TimeOrderedUUID() + + kp, err := newSshKeyPairBuilder(). + SetType(rsaSsh). + SetName(name). + Build() + if err != nil { + t.Fatal(err.Error()) + } + + err = expected{ + kind: rsaSsh, + bits: 4096, + desc: "rsa 4096", + data: []byte(uuid.TimeOrderedUUID()), + name: name, + }.matches(kp) + if err != nil { + t.Fatal(err.Error()) + } +}