Merge pull request #4 from stephen-fox/refactor-ssh-key-pair-logic

Initial take on code review feedback from @azr.
This commit is contained in:
Chris Marget 2019-02-27 17:05:44 -05:00 committed by GitHub
commit 0f1bde760c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 383 additions and 744 deletions

View File

@ -36,19 +36,19 @@ func (s *StepSshKeyPair) Run(_ context.Context, state multistep.StateBag) multis
return multistep.ActionHalt
}
kp, err := ssh.NewKeyPairBuilder().
SetPrivateKey(privateKeyBytes).
SetName(fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID())).
Build()
kp, err := ssh.KeyPairFromPrivateKey(ssh.FromPrivateKeyConfig{
RawPrivateKeyPemBlock: privateKeyBytes,
Name: fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID()),
})
if err != nil {
state.Put("error", err)
return multistep.ActionHalt
}
s.Comm.SSHPrivateKey = privateKeyBytes
s.Comm.SSHKeyPairName = kp.Name()
s.Comm.SSHTemporaryKeyPairName = kp.Name()
s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine(ssh.UnixNewLine)
s.Comm.SSHKeyPairName = kp.Name
s.Comm.SSHTemporaryKeyPairName = kp.Name
s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine
return multistep.ActionContinue
}
@ -60,21 +60,21 @@ func (s *StepSshKeyPair) Run(_ context.Context, state multistep.StateBag) multis
ui.Say("Creating ephemeral key pair for SSH communicator...")
kp, err := ssh.NewKeyPairBuilder().
SetName(fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID())).
Build()
kp, err := ssh.NewKeyPair(ssh.CreateKeyPairConfig{
Name: fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID()),
})
if err != nil {
state.Put("error", fmt.Errorf("Error creating temporary keypair: %s", err))
return multistep.ActionHalt
}
s.Comm.SSHKeyPairName = kp.Name()
s.Comm.SSHTemporaryKeyPairName = kp.Name()
s.Comm.SSHPrivateKey = kp.PrivateKeyPemBlock()
s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine(ssh.UnixNewLine)
s.Comm.SSHKeyPairName = kp.Name
s.Comm.SSHTemporaryKeyPairName = kp.Name
s.Comm.SSHPrivateKey = kp.PrivateKeyPemBlock
s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine
s.Comm.SSHClearAuthorizedKeys = true
ui.Say(fmt.Sprintf("Created ephemeral SSH key pair of type %s", kp.Description()))
ui.Say("Created ephemeral SSH key pair for communicator")
// If we're in debug mode, output the private key to the working
// directory.
@ -90,7 +90,7 @@ func (s *StepSshKeyPair) Run(_ context.Context, state multistep.StateBag) multis
defer f.Close()
// Write the key out
if _, err := f.Write(kp.PrivateKeyPemBlock()); err != nil {
if _, err := f.Write(kp.PrivateKeyPemBlock); err != nil {
state.Put("error", fmt.Errorf("Error saving debug key: %s", err))
return multistep.ActionHalt
}

View File

@ -1,38 +0,0 @@
package ssh
import (
"crypto/dsa"
gossh "golang.org/x/crypto/ssh"
)
type dsaKeyPair struct {
privateKey *dsa.PrivateKey
publicKey gossh.PublicKey
name string
privatePemBlock []byte
}
func (o dsaKeyPair) Type() KeyPairType {
return Dsa
}
func (o dsaKeyPair) Bits() int {
return 1024
}
func (o dsaKeyPair) Name() string {
return o.name
}
func (o dsaKeyPair) Description() string {
return description(o)
}
func (o dsaKeyPair) PrivateKeyPemBlock() []byte {
return o.privatePemBlock
}
func (o dsaKeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte {
return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl)
}

View File

@ -1,38 +0,0 @@
package ssh
import (
"crypto/ecdsa"
gossh "golang.org/x/crypto/ssh"
)
type ecdsaKeyPair struct {
privateKey *ecdsa.PrivateKey
publicKey gossh.PublicKey
name string
privatePemBlock []byte
}
func (o ecdsaKeyPair) Type() KeyPairType {
return Ecdsa
}
func (o ecdsaKeyPair) Bits() int {
return o.privateKey.Curve.Params().BitSize
}
func (o ecdsaKeyPair) Name() string {
return o.name
}
func (o ecdsaKeyPair) Description() string {
return description(o)
}
func (o ecdsaKeyPair) PrivateKeyPemBlock() []byte {
return o.privatePemBlock
}
func (o ecdsaKeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte {
return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl)
}

View File

@ -1,38 +0,0 @@
package ssh
import (
"golang.org/x/crypto/ed25519"
gossh "golang.org/x/crypto/ssh"
)
type ed25519KeyPair struct {
privateKey *ed25519.PrivateKey
publicKey gossh.PublicKey
name string
privatePemBlock []byte
}
func (o ed25519KeyPair) Type() KeyPairType {
return Ed25519
}
func (o ed25519KeyPair) Bits() int {
return 256
}
func (o ed25519KeyPair) Name() string {
return o.name
}
func (o ed25519KeyPair) Description() string {
return description(o)
}
func (o ed25519KeyPair) PrivateKeyPemBlock() []byte {
return o.privatePemBlock
}
func (o ed25519KeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte {
return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl)
}

View File

@ -9,9 +9,7 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"strconv"
"strings"
"golang.org/x/crypto/ed25519"
@ -38,175 +36,130 @@ func (o KeyPairType) String() string {
return string(o)
}
const (
// UnixNewLine is a unix new line.
UnixNewLine NewLineOption = "\n"
// CreateKeyPairConfig describes how an SSH key pair should be created.
type CreateKeyPairConfig struct {
// Type describes the key pair's type.
Type KeyPairType
// WindowsNewLine is a Windows new line.
WindowsNewLine NewLineOption = "\r\n"
// Bits represents the key pair's bits of entropy. E.g., 4096 for
// a 4096 bit RSA key pair, or 521 for a ECDSA key pair with a
// 521-bit curve.
Bits int
// NoNewLine will not append a new line.
NoNewLine NewLineOption = ""
)
// NewLineOption specifies the type of new line to append to a string.
// See the 'const' block for choices.
type NewLineOption string
func (o NewLineOption) String() string {
return string(o)
// Name is the resulting key pair's name. This is used to identify
// the key pair in the SSH server's 'authorized_keys'.
Name string
}
func (o NewLineOption) Bytes() []byte {
return []byte(o)
// FromPrivateKeyConfig describes how an SSH key pair should be loaded from an
// existing private key.
type FromPrivateKeyConfig struct {
// RawPrivateKeyPemBlock is the raw private key that the key pair
// should be loaded from.
RawPrivateKeyPemBlock []byte
// Name is the resulting key pair's name. This is used to identify
// the key pair in the SSH server's 'authorized_keys'.
Name string
}
// KeyPairBuilder builds SSH key pairs.
// It can generate new keys of type RSA and ECDSA.
// It can parse user supplied keys of type DSA, RSA, ECDSA,
// and ED25519.
type KeyPairBuilder interface {
// SetType sets the key pair type.
SetType(KeyPairType) KeyPairBuilder
// KeyPair represents an SSH key pair.
// TODO: Maybe a field for a description? Maybe save the type?
type KeyPair struct {
// PrivateKeyPemBlock represents the key pair's private key in
// ASN.1 Distinguished Encoding Rules (DER) format in a
// Privacy-Enhanced Mail (PEM) block.
PrivateKeyPemBlock []byte
// SetBits sets the key pair's bits of entropy.
SetBits(int) KeyPairBuilder
// PublicKeyAuthorizedKeysLine represents the key pair's public key
// as a line in OpenSSH authorized_keys.
PublicKeyAuthorizedKeysLine []byte
// SetName sets the name of the key pair. This is primarily
// used to identify the public key in the authorized_keys file.
SetName(string) KeyPairBuilder
// SetPrivateKey takes an existing private key in PEM format.
// It overrides key generation details specified by SetType()
// and SetBits().
SetPrivateKey([]byte) KeyPairBuilder
// Build returns a SSH key pair.
//
// The following defaults are used if not specified:
// Default type: ECDSA
// Default bits of entropy:
// - RSA: 4096
// - ECDSA: 521
// Default name: (empty string)
Build() (KeyPair, error)
// Name is the key pair's name. This is used to identify
// the key pair in the SSH server's 'authorized_keys'.
Name string
}
type defaultKeyPairBuilder struct {
// kind describes the resulting key pair's type.
kind KeyPairType
// bits is the resulting key pair's bits of entropy.
bits int
// name is the resulting key pair's name.
name string
// privatePemBytes is the supplied key data when the builder
// is working from a preallocated key.
privatePemBytes []byte
}
func (o *defaultKeyPairBuilder) SetType(kind KeyPairType) KeyPairBuilder {
o.kind = kind
return o
}
func (o *defaultKeyPairBuilder) SetBits(bits int) KeyPairBuilder {
o.bits = bits
return o
}
func (o *defaultKeyPairBuilder) SetName(name string) KeyPairBuilder {
o.name = name
return o
}
func (o *defaultKeyPairBuilder) SetPrivateKey(privateBytes []byte) KeyPairBuilder {
o.privatePemBytes = privateBytes
return o
}
func (o *defaultKeyPairBuilder) Build() (KeyPair, error) {
if o.privatePemBytes != nil {
return o.preallocatedKeyPair()
}
switch o.kind {
case Rsa:
return o.newRsaKeyPair()
case Ecdsa, Default:
return o.newEcdsaKeyPair()
}
return nil, fmt.Errorf("Cannot generate SSH key pair - unsupported key pair type: %s", o.kind.String())
}
// preallocatedKeyPair returns an SSH key pair based on user
// supplied PEM data.
func (o *defaultKeyPairBuilder) preallocatedKeyPair() (KeyPair, error) {
privateKey, err := gossh.ParseRawPrivateKey(o.privatePemBytes)
// KeyPairFromPrivateKey returns a KeyPair loaded from an existing private key.
//
// Supported key pair types include:
// - DSA
// - ECDSA
// - ED25519
// - RSA
func KeyPairFromPrivateKey(config FromPrivateKeyConfig) (KeyPair, error) {
privateKey, err := gossh.ParseRawPrivateKey(config.RawPrivateKeyPemBlock)
if err != nil {
return nil, err
return KeyPair{}, err
}
switch pk := privateKey.(type) {
case *rsa.PrivateKey:
publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil {
return nil, err
return KeyPair{}, err
}
return &rsaKeyPair{
privateKey: pk,
publicKey: publicKey,
name: o.name,
privatePemBlock: o.privatePemBytes,
return KeyPair{
PrivateKeyPemBlock: config.RawPrivateKeyPemBlock,
PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name),
}, nil
case *ecdsa.PrivateKey:
publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil {
return nil, err
return KeyPair{}, err
}
return &ecdsaKeyPair{
privateKey: pk,
publicKey: publicKey,
name: o.name,
privatePemBlock: o.privatePemBytes,
return KeyPair{
PrivateKeyPemBlock: config.RawPrivateKeyPemBlock,
PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name),
}, nil
case *dsa.PrivateKey:
publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil {
return nil, err
return KeyPair{}, err
}
return &dsaKeyPair{
privateKey: pk,
publicKey: publicKey,
name: o.name,
privatePemBlock: o.privatePemBytes,
return KeyPair{
PrivateKeyPemBlock: config.RawPrivateKeyPemBlock,
PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name),
}, nil
case *ed25519.PrivateKey:
publicKey, err := gossh.NewPublicKey(pk.Public())
if err != nil {
return nil, err
return KeyPair{}, err
}
return &ed25519KeyPair{
privateKey: pk,
publicKey: publicKey,
name: o.name,
privatePemBlock: o.privatePemBytes,
return KeyPair{
PrivateKeyPemBlock: config.RawPrivateKeyPemBlock,
PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name),
}, nil
}
return nil, fmt.Errorf("Cannot parse preallocated key pair - unknown ssh key pair type")
return KeyPair{}, fmt.Errorf("Cannot parse existing SSH key pair - unknown key pair type")
}
// NewKeyPair generates a new SSH key pair using the specified
// CreateKeyPairConfig.
func NewKeyPair(config CreateKeyPairConfig) (KeyPair, error) {
if config.Type == Default {
config.Type = Ecdsa
}
switch config.Type {
case Ecdsa:
return newEcdsaKeyPair(config)
case Rsa:
return newRsaKeyPair(config)
}
return KeyPair{}, fmt.Errorf("Unable to generate new key pair, type %s is not supported",
config.Type.String())
}
// newEcdsaKeyPair returns a new ECDSA SSH key pair.
func (o *defaultKeyPairBuilder) newEcdsaKeyPair() (KeyPair, error) {
func newEcdsaKeyPair(config CreateKeyPairConfig) (KeyPair, error) {
var curve elliptic.Curve
switch o.bits {
switch config.Bits {
case 0:
o.bits = 521
config.Bits = 521
fallthrough
case 521:
curve = elliptic.P521()
@ -216,26 +169,24 @@ func (o *defaultKeyPairBuilder) newEcdsaKeyPair() (KeyPair, error) {
curve = elliptic.P256()
case 224:
// Not supported by "golang.org/x/crypto/ssh".
return &ecdsaKeyPair{}, errors.New("golang.org/x/crypto/ssh does not support " +
strconv.Itoa(o.bits) + " bits")
return KeyPair{}, fmt.Errorf("golang.org/x/crypto/ssh does not support %d bits", config.Bits)
default:
return &ecdsaKeyPair{}, errors.New("crypto/elliptic does not support " +
strconv.Itoa(o.bits) + " bits")
return KeyPair{}, fmt.Errorf("crypto/elliptic does not support %d bits", config.Bits)
}
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
return &ecdsaKeyPair{}, err
return KeyPair{}, err
}
sshPublicKey, err := gossh.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return &ecdsaKeyPair{}, err
return KeyPair{}, err
}
privateRaw, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return &ecdsaKeyPair{}, err
return KeyPair{}, err
}
privatePem, err := rawPemBlock(&pem.Block{
@ -244,31 +195,30 @@ func (o *defaultKeyPairBuilder) newEcdsaKeyPair() (KeyPair, error) {
Bytes: privateRaw,
})
if err != nil {
return &ecdsaKeyPair{}, err
return KeyPair{}, err
}
return &ecdsaKeyPair{
privateKey: privateKey,
publicKey: sshPublicKey,
name: o.name,
privatePemBlock: privatePem,
return KeyPair{
PrivateKeyPemBlock: privatePem,
PublicKeyAuthorizedKeysLine: authorizedKeysLine(sshPublicKey, config.Name),
Name: config.Name,
}, nil
}
// newRsaKeyPair returns a new RSA SSH key pair.
func (o *defaultKeyPairBuilder) newRsaKeyPair() (KeyPair, error) {
if o.bits == 0 {
o.bits = defaultRsaBits
func newRsaKeyPair(config CreateKeyPairConfig) (KeyPair, error) {
if config.Bits == 0 {
config.Bits = defaultRsaBits
}
privateKey, err := rsa.GenerateKey(rand.Reader, o.bits)
privateKey, err := rsa.GenerateKey(rand.Reader, config.Bits)
if err != nil {
return &rsaKeyPair{}, err
return KeyPair{}, err
}
sshPublicKey, err := gossh.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return &rsaKeyPair{}, err
return KeyPair{}, err
}
privatePemBlock, err := rawPemBlock(&pem.Block{
@ -277,48 +227,16 @@ func (o *defaultKeyPairBuilder) newRsaKeyPair() (KeyPair, error) {
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
if err != nil {
return &rsaKeyPair{}, err
return KeyPair{}, err
}
return &rsaKeyPair{
privateKey: privateKey,
publicKey: sshPublicKey,
name: o.name,
privatePemBlock: privatePemBlock,
return KeyPair{
PrivateKeyPemBlock: privatePemBlock,
PublicKeyAuthorizedKeysLine: authorizedKeysLine(sshPublicKey, config.Name),
Name: config.Name,
}, nil
}
// KeyPair represents a SSH key pair.
type KeyPair interface {
// Type returns the key pair's type.
Type() KeyPairType
// Bits returns the bits of entropy.
Bits() int
// Name returns the key pair's name. An empty string is
// returned if no name was specified.
Name() string
// Description returns a brief description of the key pair that
// is suitable for log messages or printing.
Description() string
// PrivateKeyPemBlock returns a slice of bytes representing
// the private key in ASN.1 Distinguished Encoding Rules (DER)
// format in a Privacy-Enhanced Mail (PEM) block.
PrivateKeyPemBlock() []byte
// PublicKeyAuthorizedKeysLine returns a slice of bytes
// representing the public key as a line in OpenSSH authorized_keys
// format with the specified new line.
PublicKeyAuthorizedKeysLine(NewLineOption) []byte
}
func NewKeyPairBuilder() KeyPairBuilder {
return &defaultKeyPairBuilder{}
}
// rawPemBlock encodes a pem.Block to a slice of bytes.
func rawPemBlock(block *pem.Block) ([]byte, error) {
buffer := bytes.NewBuffer(nil)
@ -331,26 +249,10 @@ func rawPemBlock(block *pem.Block) ([]byte, error) {
return buffer.Bytes(), nil
}
// description returns a string describing a key pair.
func description(kp KeyPair) string {
buffer := bytes.NewBuffer(nil)
buffer.WriteString(strconv.Itoa(kp.Bits()))
buffer.WriteString(" bit ")
buffer.WriteString(kp.Type().String())
if len(kp.Name()) > 0 {
buffer.WriteString(" named ")
buffer.WriteString(kp.Name())
}
return buffer.String()
}
// publicKeyAuthorizedKeysLine returns a slice of bytes representing a SSH
// public key as a line in OpenSSH authorized_keys format.
func publicKeyAuthorizedKeysLine(publicKey gossh.PublicKey, name string, nl NewLineOption) []byte {
result := gossh.MarshalAuthorizedKey(publicKey)
// authorizedKeysLine returns a slice of bytes representing an SSH public key
// as a line in OpenSSH authorized_keys format. No line break is appended.
func authorizedKeysLine(sshPublicKey gossh.PublicKey, name string) []byte {
result := gossh.MarshalAuthorizedKey(sshPublicKey)
// Remove the mandatory unix new line.
// Awful, but the go ssh library automatically appends
@ -362,14 +264,5 @@ func publicKeyAuthorizedKeysLine(publicKey gossh.PublicKey, name string, nl NewL
result = append(result, name...)
}
switch nl {
case WindowsNewLine:
result = append(result, nl.Bytes()...)
case UnixNewLine:
// This is how all the other "SSH key pair" code works in
// the different builders.
result = append(result, UnixNewLine.Bytes()...)
}
return result
}

View File

@ -2,24 +2,17 @@ package ssh
import (
"bytes"
"crypto/rand"
"errors"
"strconv"
"crypto/dsa"
"crypto/ecdsa"
"crypto/rsa"
"fmt"
"testing"
"github.com/hashicorp/packer/common/uuid"
"golang.org/x/crypto/ed25519"
gossh "golang.org/x/crypto/ssh"
)
// expected contains the data that the key pair should contain.
type expected struct {
kind KeyPairType
bits int
desc string
name string
data []byte
}
const (
pemRsa1024 = `-----BEGIN RSA PRIVATE KEY-----
MIICWwIBAAKBgQDJEMFPpTBiWNDb3qEIPTSeEnIP8FZdBpG8njOrclcMoQQNhzZ+
@ -151,401 +144,306 @@ QBAgM=
`
)
func (o expected) matches(kp KeyPair) error {
if o.kind.String() == "" {
return errors.New("expected kind's value cannot be empty")
}
if o.bits <= 0 {
return errors.New("expected bits' value cannot be less than or equal to 0")
}
if o.desc == "" {
return errors.New("expected description's value cannot be empty")
}
if len(o.data) == 0 {
return errors.New("expected random data value cannot be nothing")
}
if kp.Type() != o.kind {
return errors.New("key pair type should be " + o.kind.String() +
" - got '" + kp.Type().String() + "'")
}
if kp.Bits() != o.bits {
return errors.New("key pair bits should be " + strconv.Itoa(o.bits) +
" - got " + strconv.Itoa(kp.Bits()))
}
if len(o.name) > 0 && kp.Name() != o.name {
return errors.New("key pair name should be '" + o.name +
"' - got '" + kp.Name() + "'")
}
if kp.Description() != o.desc {
return errors.New("key pair description should be '" +
o.desc + "' - got '" + kp.Description() + "'")
}
err := o.verifyPublicKeyAuthorizedKeysFormat(kp)
if err != nil {
return err
}
err = o.verifyKeyPair(kp)
if err != nil {
return err
}
return nil
}
func (o expected) verifyPublicKeyAuthorizedKeysFormat(kp KeyPair) error {
newLines := []NewLineOption{
UnixNewLine,
NoNewLine,
WindowsNewLine,
}
for _, nl := range newLines {
publicKeyAk := kp.PublicKeyAuthorizedKeysLine(nl)
if len(publicKeyAk) < 2 {
return errors.New("expected public key in authorized keys format to be at least 2 bytes")
}
switch nl {
case NoNewLine:
if publicKeyAk[len(publicKeyAk)-1] == '\n' {
return errors.New("public key in authorized keys format has trailing new line when none was specified")
}
case UnixNewLine:
if publicKeyAk[len(publicKeyAk)-1] != '\n' {
return errors.New("public key in authorized keys format does not have unix new line when unix was specified")
}
if string(publicKeyAk[len(publicKeyAk)-2:]) == WindowsNewLine.String() {
return errors.New("public key in authorized keys format has windows new line when unix was specified")
}
case WindowsNewLine:
if string(publicKeyAk[len(publicKeyAk)-2:]) != WindowsNewLine.String() {
return errors.New("public key in authorized keys format does not have windows new line when windows was specified")
}
}
if len(o.name) > 0 {
if len(publicKeyAk) < len(o.name) {
return errors.New("public key in authorized keys format is shorter than the key pair's name")
}
suffix := []byte{' '}
suffix = append(suffix, o.name...)
suffix = append(suffix, nl.Bytes()...)
if !bytes.HasSuffix(publicKeyAk, suffix) {
return errors.New("public key in authorized keys format with name does not have name in suffix - got '" +
string(publicKeyAk) + "'")
}
}
}
return nil
}
func (o expected) verifyKeyPair(kp KeyPair) error {
signer, err := gossh.ParsePrivateKey(kp.PrivateKeyPemBlock())
if err != nil {
return errors.New("failed to parse private key during verification - " + err.Error())
}
signature, err := signer.Sign(rand.Reader, o.data)
if err != nil {
return errors.New("failed to sign test data during verification - " + err.Error())
}
err = signer.PublicKey().Verify(o.data, signature)
if err != nil {
return errors.New("failed to verify test data - " + err.Error())
}
return nil
}
func TestDefaultKeyPairBuilder_Build_Default(t *testing.T) {
kp, err := NewKeyPairBuilder().Build()
func TestNewKeyPair_Default(t *testing.T) {
kp, err := NewKeyPair(CreateKeyPairConfig{})
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: Ecdsa,
err = verifyEcdsaKeyPair(kp, expectedData{
bits: 521,
desc: "521 bit ECDSA",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
})
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultKeyPairBuilder_Build_EcdsaDefault(t *testing.T) {
kp, err := NewKeyPairBuilder().
SetType(Ecdsa).
Build()
func TestNewKeyPair_ECDSA_Default(t *testing.T) {
kp, err := NewKeyPair(CreateKeyPairConfig{
Type: Ecdsa,
})
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: Ecdsa,
err = verifyEcdsaKeyPair(kp, expectedData{
bits: 521,
desc: "521 bit ECDSA",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
})
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultKeyPairBuilder_Build_EcdsaSupportedCurves(t *testing.T) {
supportedBits := []int{
521,
384,
256,
}
func TestNewKeyPair_ECDSA_Positive(t *testing.T) {
for _, bits := range []int{521, 384, 256} {
config := CreateKeyPairConfig{
Type: Ecdsa,
Bits: bits,
Name: uuid.TimeOrderedUUID(),
}
for _, bits := range supportedBits {
kp, err := NewKeyPairBuilder().
SetType(Ecdsa).
SetBits(bits).
Build()
kp, err := NewKeyPair(config)
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: Ecdsa,
err = verifyEcdsaKeyPair(kp, expectedData{
bits: bits,
desc: strconv.Itoa(bits) + " bit ECDSA",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
name: config.Name,
})
if err != nil {
t.Fatal(err.Error())
}
}
}
func TestDefaultKeyPairBuilder_Build_RsaDefault(t *testing.T) {
kp, err := NewKeyPairBuilder().
SetType(Rsa).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: Rsa,
bits: 4096,
desc: "4096 bit RSA",
data: []byte(uuid.TimeOrderedUUID()),
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
func TestNewKeyPair_ECDSA_Negative(t *testing.T) {
for _, bits := range []int{224, 1, 2, 3} {
_, err := NewKeyPair(CreateKeyPairConfig{
Type: Ecdsa,
Bits: bits,
})
if err == nil {
t.Fatalf("expected key pair generation to fail for %d bits", bits)
}
}
}
func TestDefaultKeyPairBuilder_Build_NamedEcdsa(t *testing.T) {
name := uuid.TimeOrderedUUID()
func TestNewKeyPair_RSA_Positive(t *testing.T) {
for _, bits := range []int{4096, 2048} {
config := CreateKeyPairConfig{
Type: Rsa,
Bits: bits,
Name: uuid.TimeOrderedUUID(),
}
kp, err := NewKeyPairBuilder().
SetType(Ecdsa).
SetName(name).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: Ecdsa,
bits: 521,
desc: "521 bit ECDSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
name: name,
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultKeyPairBuilder_Build_NamedRsa(t *testing.T) {
name := uuid.TimeOrderedUUID()
kp, err := NewKeyPairBuilder().
SetType(Rsa).
SetName(name).
Build()
if err != nil {
t.Fatal(err.Error())
}
err = expected{
kind: Rsa,
bits: 4096,
desc: "4096 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
name: name,
}.matches(kp)
if err != nil {
t.Fatal(err.Error())
}
}
func TestDefaultKeyPairBuilder_SetPrivateKey(t *testing.T) {
name := uuid.TimeOrderedUUID()
pemData := make(map[string]expected)
pemData[pemRsa1024] = expected{
bits: 1024,
kind: Rsa,
name: name,
desc: "1024 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemRsa2048] = expected{
bits: 2048,
kind: Rsa,
name: name,
desc: "2048 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemOpenSshRsa1024] = expected{
bits: 1024,
kind: Rsa,
name: name,
desc: "1024 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemOpenSshRsa2048] = expected{
bits: 2048,
kind: Rsa,
name: name,
desc: "2048 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemDsa] = expected{
bits: 1024,
kind: Dsa,
name: name,
desc: "1024 bit DSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemEcdsa384] = expected{
bits: 384,
kind: Ecdsa,
name: name,
desc: "384 bit ECDSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemEcdsa521] = expected{
bits: 521,
kind: Ecdsa,
name: name,
desc: "521 bit ECDSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemOpenSshEd25519] = expected{
bits: 256,
kind: Ed25519,
name: name,
desc: "256 bit ED25519 named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
for s, l := range pemData {
kp, err := NewKeyPairBuilder().SetPrivateKey([]byte(s)).SetName(name).Build()
kp, err := NewKeyPair(config)
if err != nil {
t.Fatal(err)
t.Fatal(err.Error())
}
err = l.matches(kp)
err = verifyRsaKeyPair(kp, expectedData{
bits: config.Bits,
name: config.Name,
})
if err != nil {
t.Fatal(err)
t.Fatal(err.Error())
}
}
}
func TestDefaultKeyPairBuilder_SetPrivateKey_Override(t *testing.T) {
name := uuid.TimeOrderedUUID()
pemData := make(map[string]expected)
pemData[pemRsa1024] = expected{
bits: 1024,
kind: Rsa,
name: name,
desc: "1024 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemRsa2048] = expected{
bits: 2048,
kind: Rsa,
name: name,
desc: "2048 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemOpenSshRsa1024] = expected{
bits: 1024,
kind: Rsa,
name: name,
desc: "1024 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemOpenSshRsa2048] = expected{
bits: 2048,
kind: Rsa,
name: name,
desc: "2048 bit RSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemDsa] = expected{
bits: 1024,
kind: Dsa,
name: name,
desc: "1024 bit DSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemEcdsa384] = expected{
bits: 384,
kind: Ecdsa,
name: name,
desc: "384 bit ECDSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemEcdsa521] = expected{
bits: 521,
kind: Ecdsa,
name: name,
desc: "521 bit ECDSA named " + name,
data: []byte(uuid.TimeOrderedUUID()),
}
pemData[pemOpenSshEd25519] = expected{
bits: 256,
kind: Ed25519,
name: name,
desc: "256 bit ED25519 named " + name,
data: []byte(uuid.TimeOrderedUUID()),
func TestKeyPairFromPrivateKey(t *testing.T) {
m := map[string]fromPrivateExpectedData{
pemRsa1024: {
t: Rsa,
d: expectedData{
bits: 1024,
},
},
pemRsa2048: {
t: Rsa,
d: expectedData{
bits: 2048,
},
},
pemOpenSshRsa1024: {
t: Rsa,
d: expectedData{
bits: 1024,
},
},
pemOpenSshRsa2048: {
t: Rsa,
d: expectedData{
bits: 2048,
},
},
pemDsa: {
t: Dsa,
d: expectedData{
bits: 1024,
},
},
pemEcdsa384: {
t: Ecdsa,
d: expectedData{
bits: 384,
},
},
pemEcdsa521: {
t: Ecdsa,
d: expectedData{
bits: 521,
},
},
pemOpenSshEd25519: {
t: Ed25519,
d: expectedData{
bits: 256,
},
},
}
supportedKeyTypes := []KeyPairType{Rsa, Dsa}
for _, keyType := range supportedKeyTypes {
for pemString, expectedResult := range pemData {
kp, err := NewKeyPairBuilder().
SetPrivateKey([]byte(pemString)).
SetName(name).
SetType(keyType).
Build()
if err != nil {
t.Fatal(err)
}
err = expectedResult.matches(kp)
if err != nil {
t.Fatal(err)
}
for rawPrivateKey, expected := range m {
kp, err := KeyPairFromPrivateKey(FromPrivateKeyConfig{
RawPrivateKeyPemBlock: []byte(rawPrivateKey),
})
if err != nil {
t.Fatal(err.Error())
}
switch expected.t {
case Dsa:
err = verifyDsaKeyPair(kp, expected)
case Ecdsa:
err = verifyEcdsaKeyPair(kp, expected.d)
case Ed25519:
err = verifyEd25519KeyPair(kp, expected)
case Rsa:
err = verifyRsaKeyPair(kp, expected.d)
default:
err = fmt.Errorf("unexected SSH key pair type %s", expected.t.String())
}
if err != nil {
t.Fatal(err.Error())
}
}
}
type fromPrivateExpectedData struct {
t KeyPairType
d expectedData
}
type expectedData struct {
bits int
name string
}
func verifyEcdsaKeyPair(kp KeyPair, e expectedData) error {
privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock)
if err != nil {
return err
}
pk, ok := privateKey.(*ecdsa.PrivateKey)
if !ok {
return fmt.Errorf("private key should be *ecdsa.PrivateKey")
}
if pk.Curve.Params().BitSize != e.bits {
return fmt.Errorf("bit size should be %d - got %d", e.bits, pk.Curve.Params().BitSize)
}
publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil {
return err
}
expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n"))
if len(e.name) > 0 {
expectedBytes = append(expectedBytes, ' ')
expectedBytes = append(expectedBytes, e.name...)
}
if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) {
return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'",
string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine))
}
return nil
}
func verifyRsaKeyPair(kp KeyPair, e expectedData) error {
privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock)
if err != nil {
return err
}
pk, ok := privateKey.(*rsa.PrivateKey)
if !ok {
return fmt.Errorf("private key should be *rsa.PrivateKey")
}
if pk.N.BitLen() != e.bits {
return fmt.Errorf("bit size should be %d - got %d", e.bits, pk.N.BitLen())
}
publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil {
return err
}
expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n"))
if len(e.name) > 0 {
expectedBytes = append(expectedBytes, ' ')
expectedBytes = append(expectedBytes, e.name...)
}
if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) {
return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'",
string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine))
}
return nil
}
func verifyDsaKeyPair(kp KeyPair, e fromPrivateExpectedData) error {
privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock)
if err != nil {
return err
}
pk, ok := privateKey.(*dsa.PrivateKey)
if !ok {
return fmt.Errorf("private key should be *rsa.PrivateKey")
}
publicKey, err := gossh.NewPublicKey(&pk.PublicKey)
if err != nil {
return err
}
expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n"))
if len(e.d.name) > 0 {
expectedBytes = append(expectedBytes, ' ')
expectedBytes = append(expectedBytes, e.d.name...)
}
if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) {
return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'",
string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine))
}
return nil
}
func verifyEd25519KeyPair(kp KeyPair, e fromPrivateExpectedData) error {
privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock)
if err != nil {
return err
}
pk, ok := privateKey.(*ed25519.PrivateKey)
if !ok {
return fmt.Errorf("private key should be *rsa.PrivateKey")
}
publicKey, err := gossh.NewPublicKey(pk.Public())
if err != nil {
return err
}
expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n"))
if len(e.d.name) > 0 {
expectedBytes = append(expectedBytes, ' ')
expectedBytes = append(expectedBytes, e.d.name...)
}
if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) {
return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'",
string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine))
}
return nil
}

View File

@ -1,38 +0,0 @@
package ssh
import (
"crypto/rsa"
gossh "golang.org/x/crypto/ssh"
)
type rsaKeyPair struct {
privateKey *rsa.PrivateKey
publicKey gossh.PublicKey
name string
privatePemBlock []byte
}
func (o rsaKeyPair) Type() KeyPairType {
return Rsa
}
func (o rsaKeyPair) Bits() int {
return o.privateKey.N.BitLen()
}
func (o rsaKeyPair) Name() string {
return o.name
}
func (o rsaKeyPair) Description() string {
return description(o)
}
func (o rsaKeyPair) PrivateKeyPemBlock() []byte {
return o.privatePemBlock
}
func (o rsaKeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte {
return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl)
}