Merge pull request #4 from stephen-fox/refactor-ssh-key-pair-logic

Initial take on code review feedback from @azr.
This commit is contained in:
Chris Marget 2019-02-27 17:05:44 -05:00 committed by GitHub
commit 0f1bde760c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 383 additions and 744 deletions

View File

@ -36,19 +36,19 @@ func (s *StepSshKeyPair) Run(_ context.Context, state multistep.StateBag) multis
return multistep.ActionHalt return multistep.ActionHalt
} }
kp, err := ssh.NewKeyPairBuilder(). kp, err := ssh.KeyPairFromPrivateKey(ssh.FromPrivateKeyConfig{
SetPrivateKey(privateKeyBytes). RawPrivateKeyPemBlock: privateKeyBytes,
SetName(fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID())). Name: fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID()),
Build() })
if err != nil { if err != nil {
state.Put("error", err) state.Put("error", err)
return multistep.ActionHalt return multistep.ActionHalt
} }
s.Comm.SSHPrivateKey = privateKeyBytes s.Comm.SSHPrivateKey = privateKeyBytes
s.Comm.SSHKeyPairName = kp.Name() s.Comm.SSHKeyPairName = kp.Name
s.Comm.SSHTemporaryKeyPairName = kp.Name() s.Comm.SSHTemporaryKeyPairName = kp.Name
s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine(ssh.UnixNewLine) s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine
return multistep.ActionContinue return multistep.ActionContinue
} }
@ -60,21 +60,21 @@ func (s *StepSshKeyPair) Run(_ context.Context, state multistep.StateBag) multis
ui.Say("Creating ephemeral key pair for SSH communicator...") ui.Say("Creating ephemeral key pair for SSH communicator...")
kp, err := ssh.NewKeyPairBuilder(). kp, err := ssh.NewKeyPair(ssh.CreateKeyPairConfig{
SetName(fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID())). Name: fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID()),
Build() })
if err != nil { if err != nil {
state.Put("error", fmt.Errorf("Error creating temporary keypair: %s", err)) state.Put("error", fmt.Errorf("Error creating temporary keypair: %s", err))
return multistep.ActionHalt return multistep.ActionHalt
} }
s.Comm.SSHKeyPairName = kp.Name() s.Comm.SSHKeyPairName = kp.Name
s.Comm.SSHTemporaryKeyPairName = kp.Name() s.Comm.SSHTemporaryKeyPairName = kp.Name
s.Comm.SSHPrivateKey = kp.PrivateKeyPemBlock() s.Comm.SSHPrivateKey = kp.PrivateKeyPemBlock
s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine(ssh.UnixNewLine) s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine
s.Comm.SSHClearAuthorizedKeys = true s.Comm.SSHClearAuthorizedKeys = true
ui.Say(fmt.Sprintf("Created ephemeral SSH key pair of type %s", kp.Description())) ui.Say("Created ephemeral SSH key pair for communicator")
// If we're in debug mode, output the private key to the working // If we're in debug mode, output the private key to the working
// directory. // directory.
@ -90,7 +90,7 @@ func (s *StepSshKeyPair) Run(_ context.Context, state multistep.StateBag) multis
defer f.Close() defer f.Close()
// Write the key out // Write the key out
if _, err := f.Write(kp.PrivateKeyPemBlock()); err != nil { if _, err := f.Write(kp.PrivateKeyPemBlock); err != nil {
state.Put("error", fmt.Errorf("Error saving debug key: %s", err)) state.Put("error", fmt.Errorf("Error saving debug key: %s", err))
return multistep.ActionHalt return multistep.ActionHalt
} }

View File

@ -1,38 +0,0 @@
package ssh
import (
"crypto/dsa"
gossh "golang.org/x/crypto/ssh"
)
type dsaKeyPair struct {
privateKey *dsa.PrivateKey
publicKey gossh.PublicKey
name string
privatePemBlock []byte
}
func (o dsaKeyPair) Type() KeyPairType {
return Dsa
}
func (o dsaKeyPair) Bits() int {
return 1024
}
func (o dsaKeyPair) Name() string {
return o.name
}
func (o dsaKeyPair) Description() string {
return description(o)
}
func (o dsaKeyPair) PrivateKeyPemBlock() []byte {
return o.privatePemBlock
}
func (o dsaKeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte {
return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl)
}

View File

@ -1,38 +0,0 @@
package ssh
import (
"crypto/ecdsa"
gossh "golang.org/x/crypto/ssh"
)
type ecdsaKeyPair struct {
privateKey *ecdsa.PrivateKey
publicKey gossh.PublicKey
name string
privatePemBlock []byte
}
func (o ecdsaKeyPair) Type() KeyPairType {
return Ecdsa
}
func (o ecdsaKeyPair) Bits() int {
return o.privateKey.Curve.Params().BitSize
}
func (o ecdsaKeyPair) Name() string {
return o.name
}
func (o ecdsaKeyPair) Description() string {
return description(o)
}
func (o ecdsaKeyPair) PrivateKeyPemBlock() []byte {
return o.privatePemBlock
}
func (o ecdsaKeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte {
return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl)
}

View File

@ -1,38 +0,0 @@
package ssh
import (
"golang.org/x/crypto/ed25519"
gossh "golang.org/x/crypto/ssh"
)
type ed25519KeyPair struct {
privateKey *ed25519.PrivateKey
publicKey gossh.PublicKey
name string
privatePemBlock []byte
}
func (o ed25519KeyPair) Type() KeyPairType {
return Ed25519
}
func (o ed25519KeyPair) Bits() int {
return 256
}
func (o ed25519KeyPair) Name() string {
return o.name
}
func (o ed25519KeyPair) Description() string {
return description(o)
}
func (o ed25519KeyPair) PrivateKeyPemBlock() []byte {
return o.privatePemBlock
}
func (o ed25519KeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte {
return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl)
}

View File

@ -9,9 +9,7 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"strconv"
"strings" "strings"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@ -38,175 +36,130 @@ func (o KeyPairType) String() string {
return string(o) return string(o)
} }
const ( // CreateKeyPairConfig describes how an SSH key pair should be created.
// UnixNewLine is a unix new line. type CreateKeyPairConfig struct {
UnixNewLine NewLineOption = "\n" // Type describes the key pair's type.
Type KeyPairType
// WindowsNewLine is a Windows new line. // Bits represents the key pair's bits of entropy. E.g., 4096 for
WindowsNewLine NewLineOption = "\r\n" // a 4096 bit RSA key pair, or 521 for a ECDSA key pair with a
// 521-bit curve.
Bits int
// NoNewLine will not append a new line. // Name is the resulting key pair's name. This is used to identify
NoNewLine NewLineOption = "" // the key pair in the SSH server's 'authorized_keys'.
) Name string
// NewLineOption specifies the type of new line to append to a string.
// See the 'const' block for choices.
type NewLineOption string
func (o NewLineOption) String() string {
return string(o)
} }
func (o NewLineOption) Bytes() []byte { // FromPrivateKeyConfig describes how an SSH key pair should be loaded from an
return []byte(o) // existing private key.
type FromPrivateKeyConfig struct {
// RawPrivateKeyPemBlock is the raw private key that the key pair
// should be loaded from.
RawPrivateKeyPemBlock []byte
// Name is the resulting key pair's name. This is used to identify
// the key pair in the SSH server's 'authorized_keys'.
Name string
} }
// KeyPairBuilder builds SSH key pairs. // KeyPair represents an SSH key pair.
// It can generate new keys of type RSA and ECDSA. // TODO: Maybe a field for a description? Maybe save the type?
// It can parse user supplied keys of type DSA, RSA, ECDSA, type KeyPair struct {
// and ED25519. // PrivateKeyPemBlock represents the key pair's private key in
type KeyPairBuilder interface { // ASN.1 Distinguished Encoding Rules (DER) format in a
// SetType sets the key pair type. // Privacy-Enhanced Mail (PEM) block.
SetType(KeyPairType) KeyPairBuilder PrivateKeyPemBlock []byte
// SetBits sets the key pair's bits of entropy. // PublicKeyAuthorizedKeysLine represents the key pair's public key
SetBits(int) KeyPairBuilder // as a line in OpenSSH authorized_keys.
PublicKeyAuthorizedKeysLine []byte
// SetName sets the name of the key pair. This is primarily // Name is the key pair's name. This is used to identify
// used to identify the public key in the authorized_keys file. // the key pair in the SSH server's 'authorized_keys'.
SetName(string) KeyPairBuilder Name string
// SetPrivateKey takes an existing private key in PEM format.
// It overrides key generation details specified by SetType()
// and SetBits().
SetPrivateKey([]byte) KeyPairBuilder
// Build returns a SSH key pair.
//
// The following defaults are used if not specified:
// Default type: ECDSA
// Default bits of entropy:
// - RSA: 4096
// - ECDSA: 521
// Default name: (empty string)
Build() (KeyPair, error)
} }
type defaultKeyPairBuilder struct { // KeyPairFromPrivateKey returns a KeyPair loaded from an existing private key.
// kind describes the resulting key pair's type. //
kind KeyPairType // Supported key pair types include:
// - DSA
// bits is the resulting key pair's bits of entropy. // - ECDSA
bits int // - ED25519
// - RSA
// name is the resulting key pair's name. func KeyPairFromPrivateKey(config FromPrivateKeyConfig) (KeyPair, error) {
name string privateKey, err := gossh.ParseRawPrivateKey(config.RawPrivateKeyPemBlock)
// privatePemBytes is the supplied key data when the builder
// is working from a preallocated key.
privatePemBytes []byte
}
func (o *defaultKeyPairBuilder) SetType(kind KeyPairType) KeyPairBuilder {
o.kind = kind
return o
}
func (o *defaultKeyPairBuilder) SetBits(bits int) KeyPairBuilder {
o.bits = bits
return o
}
func (o *defaultKeyPairBuilder) SetName(name string) KeyPairBuilder {
o.name = name
return o
}
func (o *defaultKeyPairBuilder) SetPrivateKey(privateBytes []byte) KeyPairBuilder {
o.privatePemBytes = privateBytes
return o
}
func (o *defaultKeyPairBuilder) Build() (KeyPair, error) {
if o.privatePemBytes != nil {
return o.preallocatedKeyPair()
}
switch o.kind {
case Rsa:
return o.newRsaKeyPair()
case Ecdsa, Default:
return o.newEcdsaKeyPair()
}
return nil, fmt.Errorf("Cannot generate SSH key pair - unsupported key pair type: %s", o.kind.String())
}
// preallocatedKeyPair returns an SSH key pair based on user
// supplied PEM data.
func (o *defaultKeyPairBuilder) preallocatedKeyPair() (KeyPair, error) {
privateKey, err := gossh.ParseRawPrivateKey(o.privatePemBytes)
if err != nil { if err != nil {
return nil, err return KeyPair{}, err
} }
switch pk := privateKey.(type) { switch pk := privateKey.(type) {
case *rsa.PrivateKey: case *rsa.PrivateKey:
publicKey, err := gossh.NewPublicKey(&pk.PublicKey) publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil { if err != nil {
return nil, err return KeyPair{}, err
} }
return &rsaKeyPair{ return KeyPair{
privateKey: pk, PrivateKeyPemBlock: config.RawPrivateKeyPemBlock,
publicKey: publicKey, PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name),
name: o.name,
privatePemBlock: o.privatePemBytes,
}, nil }, nil
case *ecdsa.PrivateKey: case *ecdsa.PrivateKey:
publicKey, err := gossh.NewPublicKey(&pk.PublicKey) publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil { if err != nil {
return nil, err return KeyPair{}, err
} }
return &ecdsaKeyPair{ return KeyPair{
privateKey: pk, PrivateKeyPemBlock: config.RawPrivateKeyPemBlock,
publicKey: publicKey, PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name),
name: o.name,
privatePemBlock: o.privatePemBytes,
}, nil }, nil
case *dsa.PrivateKey: case *dsa.PrivateKey:
publicKey, err := gossh.NewPublicKey(&pk.PublicKey) publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil { if err != nil {
return nil, err return KeyPair{}, err
} }
return &dsaKeyPair{ return KeyPair{
privateKey: pk, PrivateKeyPemBlock: config.RawPrivateKeyPemBlock,
publicKey: publicKey, PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name),
name: o.name,
privatePemBlock: o.privatePemBytes,
}, nil }, nil
case *ed25519.PrivateKey: case *ed25519.PrivateKey:
publicKey, err := gossh.NewPublicKey(pk.Public()) publicKey, err := gossh.NewPublicKey(pk.Public())
if err != nil { if err != nil {
return nil, err return KeyPair{}, err
} }
return &ed25519KeyPair{ return KeyPair{
privateKey: pk, PrivateKeyPemBlock: config.RawPrivateKeyPemBlock,
publicKey: publicKey, PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name),
name: o.name,
privatePemBlock: o.privatePemBytes,
}, nil }, nil
} }
return nil, fmt.Errorf("Cannot parse preallocated key pair - unknown ssh key pair type") return KeyPair{}, fmt.Errorf("Cannot parse existing SSH key pair - unknown key pair type")
}
// NewKeyPair generates a new SSH key pair using the specified
// CreateKeyPairConfig.
func NewKeyPair(config CreateKeyPairConfig) (KeyPair, error) {
if config.Type == Default {
config.Type = Ecdsa
}
switch config.Type {
case Ecdsa:
return newEcdsaKeyPair(config)
case Rsa:
return newRsaKeyPair(config)
}
return KeyPair{}, fmt.Errorf("Unable to generate new key pair, type %s is not supported",
config.Type.String())
} }
// newEcdsaKeyPair returns a new ECDSA SSH key pair. // newEcdsaKeyPair returns a new ECDSA SSH key pair.
func (o *defaultKeyPairBuilder) newEcdsaKeyPair() (KeyPair, error) { func newEcdsaKeyPair(config CreateKeyPairConfig) (KeyPair, error) {
var curve elliptic.Curve var curve elliptic.Curve
switch o.bits { switch config.Bits {
case 0: case 0:
o.bits = 521 config.Bits = 521
fallthrough fallthrough
case 521: case 521:
curve = elliptic.P521() curve = elliptic.P521()
@ -216,26 +169,24 @@ func (o *defaultKeyPairBuilder) newEcdsaKeyPair() (KeyPair, error) {
curve = elliptic.P256() curve = elliptic.P256()
case 224: case 224:
// Not supported by "golang.org/x/crypto/ssh". // Not supported by "golang.org/x/crypto/ssh".
return &ecdsaKeyPair{}, errors.New("golang.org/x/crypto/ssh does not support " + return KeyPair{}, fmt.Errorf("golang.org/x/crypto/ssh does not support %d bits", config.Bits)
strconv.Itoa(o.bits) + " bits")
default: default:
return &ecdsaKeyPair{}, errors.New("crypto/elliptic does not support " + return KeyPair{}, fmt.Errorf("crypto/elliptic does not support %d bits", config.Bits)
strconv.Itoa(o.bits) + " bits")
} }
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader) privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil { if err != nil {
return &ecdsaKeyPair{}, err return KeyPair{}, err
} }
sshPublicKey, err := gossh.NewPublicKey(&privateKey.PublicKey) sshPublicKey, err := gossh.NewPublicKey(&privateKey.PublicKey)
if err != nil { if err != nil {
return &ecdsaKeyPair{}, err return KeyPair{}, err
} }
privateRaw, err := x509.MarshalECPrivateKey(privateKey) privateRaw, err := x509.MarshalECPrivateKey(privateKey)
if err != nil { if err != nil {
return &ecdsaKeyPair{}, err return KeyPair{}, err
} }
privatePem, err := rawPemBlock(&pem.Block{ privatePem, err := rawPemBlock(&pem.Block{
@ -244,31 +195,30 @@ func (o *defaultKeyPairBuilder) newEcdsaKeyPair() (KeyPair, error) {
Bytes: privateRaw, Bytes: privateRaw,
}) })
if err != nil { if err != nil {
return &ecdsaKeyPair{}, err return KeyPair{}, err
} }
return &ecdsaKeyPair{ return KeyPair{
privateKey: privateKey, PrivateKeyPemBlock: privatePem,
publicKey: sshPublicKey, PublicKeyAuthorizedKeysLine: authorizedKeysLine(sshPublicKey, config.Name),
name: o.name, Name: config.Name,
privatePemBlock: privatePem,
}, nil }, nil
} }
// newRsaKeyPair returns a new RSA SSH key pair. // newRsaKeyPair returns a new RSA SSH key pair.
func (o *defaultKeyPairBuilder) newRsaKeyPair() (KeyPair, error) { func newRsaKeyPair(config CreateKeyPairConfig) (KeyPair, error) {
if o.bits == 0 { if config.Bits == 0 {
o.bits = defaultRsaBits config.Bits = defaultRsaBits
} }
privateKey, err := rsa.GenerateKey(rand.Reader, o.bits) privateKey, err := rsa.GenerateKey(rand.Reader, config.Bits)
if err != nil { if err != nil {
return &rsaKeyPair{}, err return KeyPair{}, err
} }
sshPublicKey, err := gossh.NewPublicKey(&privateKey.PublicKey) sshPublicKey, err := gossh.NewPublicKey(&privateKey.PublicKey)
if err != nil { if err != nil {
return &rsaKeyPair{}, err return KeyPair{}, err
} }
privatePemBlock, err := rawPemBlock(&pem.Block{ privatePemBlock, err := rawPemBlock(&pem.Block{
@ -277,48 +227,16 @@ func (o *defaultKeyPairBuilder) newRsaKeyPair() (KeyPair, error) {
Bytes: x509.MarshalPKCS1PrivateKey(privateKey), Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
}) })
if err != nil { if err != nil {
return &rsaKeyPair{}, err return KeyPair{}, err
} }
return &rsaKeyPair{ return KeyPair{
privateKey: privateKey, PrivateKeyPemBlock: privatePemBlock,
publicKey: sshPublicKey, PublicKeyAuthorizedKeysLine: authorizedKeysLine(sshPublicKey, config.Name),
name: o.name, Name: config.Name,
privatePemBlock: privatePemBlock,
}, nil }, nil
} }
// KeyPair represents a SSH key pair.
type KeyPair interface {
// Type returns the key pair's type.
Type() KeyPairType
// Bits returns the bits of entropy.
Bits() int
// Name returns the key pair's name. An empty string is
// returned if no name was specified.
Name() string
// Description returns a brief description of the key pair that
// is suitable for log messages or printing.
Description() string
// PrivateKeyPemBlock returns a slice of bytes representing
// the private key in ASN.1 Distinguished Encoding Rules (DER)
// format in a Privacy-Enhanced Mail (PEM) block.
PrivateKeyPemBlock() []byte
// PublicKeyAuthorizedKeysLine returns a slice of bytes
// representing the public key as a line in OpenSSH authorized_keys
// format with the specified new line.
PublicKeyAuthorizedKeysLine(NewLineOption) []byte
}
func NewKeyPairBuilder() KeyPairBuilder {
return &defaultKeyPairBuilder{}
}
// rawPemBlock encodes a pem.Block to a slice of bytes. // rawPemBlock encodes a pem.Block to a slice of bytes.
func rawPemBlock(block *pem.Block) ([]byte, error) { func rawPemBlock(block *pem.Block) ([]byte, error) {
buffer := bytes.NewBuffer(nil) buffer := bytes.NewBuffer(nil)
@ -331,26 +249,10 @@ func rawPemBlock(block *pem.Block) ([]byte, error) {
return buffer.Bytes(), nil return buffer.Bytes(), nil
} }
// description returns a string describing a key pair. // authorizedKeysLine returns a slice of bytes representing an SSH public key
func description(kp KeyPair) string { // as a line in OpenSSH authorized_keys format. No line break is appended.
buffer := bytes.NewBuffer(nil) func authorizedKeysLine(sshPublicKey gossh.PublicKey, name string) []byte {
result := gossh.MarshalAuthorizedKey(sshPublicKey)
buffer.WriteString(strconv.Itoa(kp.Bits()))
buffer.WriteString(" bit ")
buffer.WriteString(kp.Type().String())
if len(kp.Name()) > 0 {
buffer.WriteString(" named ")
buffer.WriteString(kp.Name())
}
return buffer.String()
}
// publicKeyAuthorizedKeysLine returns a slice of bytes representing a SSH
// public key as a line in OpenSSH authorized_keys format.
func publicKeyAuthorizedKeysLine(publicKey gossh.PublicKey, name string, nl NewLineOption) []byte {
result := gossh.MarshalAuthorizedKey(publicKey)
// Remove the mandatory unix new line. // Remove the mandatory unix new line.
// Awful, but the go ssh library automatically appends // Awful, but the go ssh library automatically appends
@ -362,14 +264,5 @@ func publicKeyAuthorizedKeysLine(publicKey gossh.PublicKey, name string, nl NewL
result = append(result, name...) result = append(result, name...)
} }
switch nl {
case WindowsNewLine:
result = append(result, nl.Bytes()...)
case UnixNewLine:
// This is how all the other "SSH key pair" code works in
// the different builders.
result = append(result, UnixNewLine.Bytes()...)
}
return result return result
} }

View File

@ -2,24 +2,17 @@ package ssh
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/dsa"
"errors" "crypto/ecdsa"
"strconv" "crypto/rsa"
"fmt"
"testing" "testing"
"github.com/hashicorp/packer/common/uuid" "github.com/hashicorp/packer/common/uuid"
"golang.org/x/crypto/ed25519"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
) )
// expected contains the data that the key pair should contain.
type expected struct {
kind KeyPairType
bits int
desc string
name string
data []byte
}
const ( const (
pemRsa1024 = `-----BEGIN RSA PRIVATE KEY----- pemRsa1024 = `-----BEGIN RSA PRIVATE KEY-----
MIICWwIBAAKBgQDJEMFPpTBiWNDb3qEIPTSeEnIP8FZdBpG8njOrclcMoQQNhzZ+ MIICWwIBAAKBgQDJEMFPpTBiWNDb3qEIPTSeEnIP8FZdBpG8njOrclcMoQQNhzZ+
@ -151,401 +144,306 @@ QBAgM=
` `
) )
func (o expected) matches(kp KeyPair) error { func TestNewKeyPair_Default(t *testing.T) {
if o.kind.String() == "" { kp, err := NewKeyPair(CreateKeyPairConfig{})
return errors.New("expected kind's value cannot be empty")
}
if o.bits <= 0 {
return errors.New("expected bits' value cannot be less than or equal to 0")
}
if o.desc == "" {
return errors.New("expected description's value cannot be empty")
}
if len(o.data) == 0 {
return errors.New("expected random data value cannot be nothing")
}
if kp.Type() != o.kind {
return errors.New("key pair type should be " + o.kind.String() +
" - got '" + kp.Type().String() + "'")
}
if kp.Bits() != o.bits {
return errors.New("key pair bits should be " + strconv.Itoa(o.bits) +
" - got " + strconv.Itoa(kp.Bits()))
}
if len(o.name) > 0 && kp.Name() != o.name {
return errors.New("key pair name should be '" + o.name +
"' - got '" + kp.Name() + "'")
}
if kp.Description() != o.desc {
return errors.New("key pair description should be '" +
o.desc + "' - got '" + kp.Description() + "'")
}
err := o.verifyPublicKeyAuthorizedKeysFormat(kp)
if err != nil {
return err
}
err = o.verifyKeyPair(kp)
if err != nil {
return err
}
return nil
}
func (o expected) verifyPublicKeyAuthorizedKeysFormat(kp KeyPair) error {
newLines := []NewLineOption{
UnixNewLine,
NoNewLine,
WindowsNewLine,
}
for _, nl := range newLines {
publicKeyAk := kp.PublicKeyAuthorizedKeysLine(nl)
if len(publicKeyAk) < 2 {
return errors.New("expected public key in authorized keys format to be at least 2 bytes")
}
switch nl {
case NoNewLine:
if publicKeyAk[len(publicKeyAk)-1] == '\n' {
return errors.New("public key in authorized keys format has trailing new line when none was specified")
}
case UnixNewLine:
if publicKeyAk[len(publicKeyAk)-1] != '\n' {
return errors.New("public key in authorized keys format does not have unix new line when unix was specified")
}
if string(publicKeyAk[len(publicKeyAk)-2:]) == WindowsNewLine.String() {
return errors.New("public key in authorized keys format has windows new line when unix was specified")
}
case WindowsNewLine:
if string(publicKeyAk[len(publicKeyAk)-2:]) != WindowsNewLine.String() {
return errors.New("public key in authorized keys format does not have windows new line when windows was specified")
}
}
if len(o.name) > 0 {
if len(publicKeyAk) < len(o.name) {
return errors.New("public key in authorized keys format is shorter than the key pair's name")
}
suffix := []byte{' '}
suffix = append(suffix, o.name...)
suffix = append(suffix, nl.Bytes()...)
if !bytes.HasSuffix(publicKeyAk, suffix) {
return errors.New("public key in authorized keys format with name does not have name in suffix - got '" +
string(publicKeyAk) + "'")
}
}
}
return nil
}
func (o expected) verifyKeyPair(kp KeyPair) error {
signer, err := gossh.ParsePrivateKey(kp.PrivateKeyPemBlock())
if err != nil {
return errors.New("failed to parse private key during verification - " + err.Error())
}
signature, err := signer.Sign(rand.Reader, o.data)
if err != nil {
return errors.New("failed to sign test data during verification - " + err.Error())
}
err = signer.PublicKey().Verify(o.data, signature)
if err != nil {
return errors.New("failed to verify test data - " + err.Error())
}
return nil
}
func TestDefaultKeyPairBuilder_Build_Default(t *testing.T) {
kp, err := NewKeyPairBuilder().Build()
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
err = expected{ err = verifyEcdsaKeyPair(kp, expectedData{
kind: Ecdsa,
bits: 521, bits: 521,
desc: "521 bit ECDSA", })
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
} }
func TestDefaultKeyPairBuilder_Build_EcdsaDefault(t *testing.T) { func TestNewKeyPair_ECDSA_Default(t *testing.T) {
kp, err := NewKeyPairBuilder(). kp, err := NewKeyPair(CreateKeyPairConfig{
SetType(Ecdsa). Type: Ecdsa,
Build() })
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
err = expected{ err = verifyEcdsaKeyPair(kp, expectedData{
kind: Ecdsa,
bits: 521, bits: 521,
desc: "521 bit ECDSA", })
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
} }
func TestDefaultKeyPairBuilder_Build_EcdsaSupportedCurves(t *testing.T) { func TestNewKeyPair_ECDSA_Positive(t *testing.T) {
supportedBits := []int{ for _, bits := range []int{521, 384, 256} {
521, config := CreateKeyPairConfig{
384, Type: Ecdsa,
256, Bits: bits,
} Name: uuid.TimeOrderedUUID(),
}
for _, bits := range supportedBits { kp, err := NewKeyPair(config)
kp, err := NewKeyPairBuilder().
SetType(Ecdsa).
SetBits(bits).
Build()
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
err = expected{ err = verifyEcdsaKeyPair(kp, expectedData{
kind: Ecdsa,
bits: bits, bits: bits,
desc: strconv.Itoa(bits) + " bit ECDSA", name: config.Name,
data: []byte(uuid.TimeOrderedUUID()), })
}.matches(kp)
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
} }
} }
func TestDefaultKeyPairBuilder_Build_RsaDefault(t *testing.T) { func TestNewKeyPair_ECDSA_Negative(t *testing.T) {
kp, err := NewKeyPairBuilder(). for _, bits := range []int{224, 1, 2, 3} {
SetType(Rsa). _, err := NewKeyPair(CreateKeyPairConfig{
Build() Type: Ecdsa,
if err != nil { Bits: bits,
t.Fatal(err.Error()) })
} if err == nil {
t.Fatalf("expected key pair generation to fail for %d bits", bits)
err = expected{ }
kind: Rsa,
bits: 4096,
desc: "4096 bit RSA",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
} }
} }
func TestDefaultKeyPairBuilder_Build_NamedEcdsa(t *testing.T) { func TestNewKeyPair_RSA_Positive(t *testing.T) {
name := uuid.TimeOrderedUUID() for _, bits := range []int{4096, 2048} {
config := CreateKeyPairConfig{
Type: Rsa,
Bits: bits,
Name: uuid.TimeOrderedUUID(),
}
kp, err := NewKeyPairBuilder(). kp, err := NewKeyPair(config)
SetType(Ecdsa).
SetName(name).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: Ecdsa,
bits: 521,
desc: "521 bit ECDSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
name: name,
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultKeyPairBuilder_Build_NamedRsa(t *testing.T) {
name := uuid.TimeOrderedUUID()
kp, err := NewKeyPairBuilder().
SetType(Rsa).
SetName(name).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: Rsa,
bits: 4096,
desc: "4096 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
name: name,
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultKeyPairBuilder_SetPrivateKey(t *testing.T) {
name := uuid.TimeOrderedUUID()
pemData := make(map[string]expected)
pemData[pemRsa1024] = expected{
bits: 1024,
kind: Rsa,
name: name,
desc: "1024 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemRsa2048] = expected{
bits: 2048,
kind: Rsa,
name: name,
desc: "2048 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemOpenSshRsa1024] = expected{
bits: 1024,
kind: Rsa,
name: name,
desc: "1024 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemOpenSshRsa2048] = expected{
bits: 2048,
kind: Rsa,
name: name,
desc: "2048 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemDsa] = expected{
bits: 1024,
kind: Dsa,
name: name,
desc: "1024 bit DSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemEcdsa384] = expected{
bits: 384,
kind: Ecdsa,
name: name,
desc: "384 bit ECDSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemEcdsa521] = expected{
bits: 521,
kind: Ecdsa,
name: name,
desc: "521 bit ECDSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemOpenSshEd25519] = expected{
bits: 256,
kind: Ed25519,
name: name,
desc: "256 bit ED25519 named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
for s, l := range pemData {
kp, err := NewKeyPairBuilder().SetPrivateKey([]byte(s)).SetName(name).Build()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err.Error())
} }
err = l.matches(kp)
err = verifyRsaKeyPair(kp, expectedData{
bits: config.Bits,
name: config.Name,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err.Error())
} }
} }
} }
func TestDefaultKeyPairBuilder_SetPrivateKey_Override(t *testing.T) { func TestKeyPairFromPrivateKey(t *testing.T) {
name := uuid.TimeOrderedUUID() m := map[string]fromPrivateExpectedData{
pemData := make(map[string]expected) pemRsa1024: {
pemData[pemRsa1024] = expected{ t: Rsa,
bits: 1024, d: expectedData{
kind: Rsa, bits: 1024,
name: name, },
desc: "1024 bit RSA named " + name, },
data: []byte(uuid.TimeOrderedUUID()), pemRsa2048: {
} t: Rsa,
pemData[pemRsa2048] = expected{ d: expectedData{
bits: 2048, bits: 2048,
kind: Rsa, },
name: name, },
desc: "2048 bit RSA named " + name, pemOpenSshRsa1024: {
data: []byte(uuid.TimeOrderedUUID()), t: Rsa,
} d: expectedData{
pemData[pemOpenSshRsa1024] = expected{ bits: 1024,
bits: 1024, },
kind: Rsa, },
name: name, pemOpenSshRsa2048: {
desc: "1024 bit RSA named " + name, t: Rsa,
data: []byte(uuid.TimeOrderedUUID()), d: expectedData{
} bits: 2048,
pemData[pemOpenSshRsa2048] = expected{ },
bits: 2048, },
kind: Rsa, pemDsa: {
name: name, t: Dsa,
desc: "2048 bit RSA named " + name, d: expectedData{
data: []byte(uuid.TimeOrderedUUID()), bits: 1024,
} },
pemData[pemDsa] = expected{ },
bits: 1024, pemEcdsa384: {
kind: Dsa, t: Ecdsa,
name: name, d: expectedData{
desc: "1024 bit DSA named " + name, bits: 384,
data: []byte(uuid.TimeOrderedUUID()), },
} },
pemData[pemEcdsa384] = expected{ pemEcdsa521: {
bits: 384, t: Ecdsa,
kind: Ecdsa, d: expectedData{
name: name, bits: 521,
desc: "384 bit ECDSA named " + name, },
data: []byte(uuid.TimeOrderedUUID()), },
} pemOpenSshEd25519: {
pemData[pemEcdsa521] = expected{ t: Ed25519,
bits: 521, d: expectedData{
kind: Ecdsa, bits: 256,
name: name, },
desc: "521 bit ECDSA named " + name, },
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemOpenSshEd25519] = expected{
bits: 256,
kind: Ed25519,
name: name,
desc: "256 bit ED25519 named " + name,
data: []byte(uuid.TimeOrderedUUID()),
} }
supportedKeyTypes := []KeyPairType{Rsa, Dsa} for rawPrivateKey, expected := range m {
for _, keyType := range supportedKeyTypes { kp, err := KeyPairFromPrivateKey(FromPrivateKeyConfig{
for pemString, expectedResult := range pemData { RawPrivateKeyPemBlock: []byte(rawPrivateKey),
kp, err := NewKeyPairBuilder(). })
SetPrivateKey([]byte(pemString)). if err != nil {
SetName(name). t.Fatal(err.Error())
SetType(keyType). }
Build()
if err != nil { switch expected.t {
t.Fatal(err) case Dsa:
} err = verifyDsaKeyPair(kp, expected)
err = expectedResult.matches(kp) case Ecdsa:
if err != nil { err = verifyEcdsaKeyPair(kp, expected.d)
t.Fatal(err) case Ed25519:
} err = verifyEd25519KeyPair(kp, expected)
case Rsa:
err = verifyRsaKeyPair(kp, expected.d)
default:
err = fmt.Errorf("unexected SSH key pair type %s", expected.t.String())
}
if err != nil {
t.Fatal(err.Error())
} }
} }
} }
type fromPrivateExpectedData struct {
t KeyPairType
d expectedData
}
type expectedData struct {
bits int
name string
}
func verifyEcdsaKeyPair(kp KeyPair, e expectedData) error {
privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock)
if err != nil {
return err
}
pk, ok := privateKey.(*ecdsa.PrivateKey)
if !ok {
return fmt.Errorf("private key should be *ecdsa.PrivateKey")
}
if pk.Curve.Params().BitSize != e.bits {
return fmt.Errorf("bit size should be %d - got %d", e.bits, pk.Curve.Params().BitSize)
}
publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil {
return err
}
expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n"))
if len(e.name) > 0 {
expectedBytes = append(expectedBytes, ' ')
expectedBytes = append(expectedBytes, e.name...)
}
if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) {
return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'",
string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine))
}
return nil
}
func verifyRsaKeyPair(kp KeyPair, e expectedData) error {
privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock)
if err != nil {
return err
}
pk, ok := privateKey.(*rsa.PrivateKey)
if !ok {
return fmt.Errorf("private key should be *rsa.PrivateKey")
}
if pk.N.BitLen() != e.bits {
return fmt.Errorf("bit size should be %d - got %d", e.bits, pk.N.BitLen())
}
publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil {
return err
}
expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n"))
if len(e.name) > 0 {
expectedBytes = append(expectedBytes, ' ')
expectedBytes = append(expectedBytes, e.name...)
}
if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) {
return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'",
string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine))
}
return nil
}
func verifyDsaKeyPair(kp KeyPair, e fromPrivateExpectedData) error {
privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock)
if err != nil {
return err
}
pk, ok := privateKey.(*dsa.PrivateKey)
if !ok {
return fmt.Errorf("private key should be *rsa.PrivateKey")
}
publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil {
return err
}
expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n"))
if len(e.d.name) > 0 {
expectedBytes = append(expectedBytes, ' ')
expectedBytes = append(expectedBytes, e.d.name...)
}
if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) {
return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'",
string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine))
}
return nil
}
func verifyEd25519KeyPair(kp KeyPair, e fromPrivateExpectedData) error {
privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock)
if err != nil {
return err
}
pk, ok := privateKey.(*ed25519.PrivateKey)
if !ok {
return fmt.Errorf("private key should be *rsa.PrivateKey")
}
publicKey, err := gossh.NewPublicKey(pk.Public())
if err != nil {
return err
}
expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n"))
if len(e.d.name) > 0 {
expectedBytes = append(expectedBytes, ' ')
expectedBytes = append(expectedBytes, e.d.name...)
}
if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) {
return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'",
string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine))
}
return nil
}

View File

@ -1,38 +0,0 @@
package ssh
import (
"crypto/rsa"
gossh "golang.org/x/crypto/ssh"
)
type rsaKeyPair struct {
privateKey *rsa.PrivateKey
publicKey gossh.PublicKey
name string
privatePemBlock []byte
}
func (o rsaKeyPair) Type() KeyPairType {
return Rsa
}
func (o rsaKeyPair) Bits() int {
return o.privateKey.N.BitLen()
}
func (o rsaKeyPair) Name() string {
return o.name
}
func (o rsaKeyPair) Description() string {
return description(o)
}
func (o rsaKeyPair) PrivateKeyPemBlock() []byte {
return o.privatePemBlock
}
func (o rsaKeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte {
return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl)
}