Merge pull request #3865 from dpiddy/amazon-ssh-retry-fix

amazon/common/ssh: fix saving of instance to state
This commit is contained in:
Rickard von Essen 2016-09-12 15:04:35 +02:00 committed by GitHub
commit 120b60ae37
2 changed files with 134 additions and 5 deletions

View File

@ -10,17 +10,28 @@ import (
"golang.org/x/crypto/ssh"
)
type ec2Describer interface {
DescribeInstances(*ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
}
var (
// modified in tests
sshHostSleepDuration = time.Second
)
// SSHHost returns a function that can be given to the SSH communicator
// for determining the SSH address based on the instance DNS name.
func SSHHost(e *ec2.EC2, private bool) func(multistep.StateBag) (string, error) {
func SSHHost(e ec2Describer, private bool) func(multistep.StateBag) (string, error) {
return func(state multistep.StateBag) (string, error) {
for j := 0; j < 2; j++ {
const tries = 2
// <= with current structure to check result of describing `tries` times
for j := 0; j <= tries; j++ {
var host string
i := state.Get("instance").(*ec2.Instance)
if i.VpcId != nil && *i.VpcId != "" {
if i.PublicIpAddress != nil && *i.PublicIpAddress != "" && !private {
host = *i.PublicIpAddress
} else {
} else if i.PrivateIpAddress != nil && *i.PrivateIpAddress != "" {
host = *i.PrivateIpAddress
}
} else if i.PublicDnsName != nil && *i.PublicDnsName != "" {
@ -42,8 +53,8 @@ func SSHHost(e *ec2.EC2, private bool) func(multistep.StateBag) (string, error)
return "", fmt.Errorf("instance not found: %s", *i.InstanceId)
}
state.Put("instance", &r.Reservations[0].Instances[0])
time.Sleep(1 * time.Second)
state.Put("instance", r.Reservations[0].Instances[0])
time.Sleep(sshHostSleepDuration)
}
return "", errors.New("couldn't determine IP address for instance")

View File

@ -0,0 +1,118 @@
package common
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/mitchellh/multistep"
)
const (
privateIP = "10.0.0.1"
publicIP = "192.168.1.1"
publicDNS = "public.dns.test"
)
func TestSSHHost(t *testing.T) {
origSshHostSleepDuration := sshHostSleepDuration
defer func() { sshHostSleepDuration = origSshHostSleepDuration }()
sshHostSleepDuration = 0
var cases = []struct {
allowTries int
vpcId string
private bool
ok bool
wantHost string
}{
{1, "", false, true, publicDNS},
{1, "", true, true, publicDNS},
{1, "vpc-id", false, true, publicIP},
{1, "vpc-id", true, true, privateIP},
{2, "", false, true, publicDNS},
{2, "", true, true, publicDNS},
{2, "vpc-id", false, true, publicIP},
{2, "vpc-id", true, true, privateIP},
{3, "", false, false, ""},
{3, "", true, false, ""},
{3, "vpc-id", false, false, ""},
{3, "vpc-id", true, false, ""},
}
for _, c := range cases {
testSSHHost(t, c.allowTries, c.vpcId, c.private, c.ok, c.wantHost)
}
}
func testSSHHost(t *testing.T, allowTries int, vpcId string, private, ok bool, wantHost string) {
t.Logf("allowTries=%d vpcId=%s private=%t ok=%t wantHost=%q", allowTries, vpcId, private, ok, wantHost)
e := &fakeEC2Describer{
allowTries: allowTries,
vpcId: vpcId,
privateIP: privateIP,
publicIP: publicIP,
publicDNS: publicDNS,
}
f := SSHHost(e, private)
st := &multistep.BasicStateBag{}
st.Put("instance", &ec2.Instance{
InstanceId: aws.String("instance-id"),
})
host, err := f(st)
if e.tries > allowTries {
t.Fatalf("got %d ec2 DescribeInstances tries, want %d", e.tries, allowTries)
}
switch {
case ok && err != nil:
t.Fatalf("expected no error, got %+v", err)
case !ok && err == nil:
t.Fatalf("expected error, got none and host %s", host)
}
if host != wantHost {
t.Fatalf("got host %s, want %s", host, wantHost)
}
}
type fakeEC2Describer struct {
allowTries int
tries int
vpcId string
privateIP, publicIP, publicDNS string
}
func (d *fakeEC2Describer) DescribeInstances(in *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) {
d.tries++
instance := &ec2.Instance{
InstanceId: aws.String("instance-id"),
}
if d.vpcId != "" {
instance.VpcId = aws.String(d.vpcId)
}
if d.tries >= d.allowTries {
instance.PublicIpAddress = aws.String(d.publicIP)
instance.PrivateIpAddress = aws.String(d.privateIP)
instance.PublicDnsName = aws.String(d.publicDNS)
}
out := &ec2.DescribeInstancesOutput{
Reservations: []*ec2.Reservation{
{
Instances: []*ec2.Instance{instance},
},
},
}
return out, nil
}