140 lines
3.4 KiB
Go
140 lines
3.4 KiB
Go
package common
|
|
|
|
import (
|
|
"context"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/service/ec2"
|
|
"github.com/hashicorp/packer/packer"
|
|
)
|
|
|
|
func TestStepCreateSSMTunnel_Run(t *testing.T) {
|
|
mockSvc := MockSSMSvc{}
|
|
config := SSMDriverConfig{
|
|
SvcClient: &mockSvc,
|
|
SvcEndpoint: "example.com",
|
|
}
|
|
|
|
mockDriver := NewSSMDriver(config)
|
|
mockDriver.pluginCmdFunc = MockPluginCmdFunc
|
|
|
|
state := testState()
|
|
state.Put("ui", &packer.NoopUi{})
|
|
state.Put("instance", &ec2.Instance{InstanceId: aws.String("i-something")})
|
|
|
|
step := StepCreateSSMTunnel{
|
|
driver: mockDriver,
|
|
}
|
|
|
|
step.Run(context.Background(), state)
|
|
|
|
err := state.Get("error")
|
|
if err != nil {
|
|
err = err.(error)
|
|
t.Fatalf("the call to Run failed with an error when it should've executed: %v", err)
|
|
}
|
|
|
|
if mockSvc.StartSessionCalled {
|
|
t.Errorf("StartSession should not be called when SSMAgentEnabled is false")
|
|
}
|
|
|
|
// Run when SSMAgentEnabled is true
|
|
step.SSMAgentEnabled = true
|
|
step.Run(context.Background(), state)
|
|
|
|
err = state.Get("error")
|
|
if err != nil {
|
|
err = err.(error)
|
|
t.Fatalf("the call to Run failed with an error when it should've executed: %v", err)
|
|
}
|
|
|
|
if !mockSvc.StartSessionCalled {
|
|
t.Errorf("calling run with the correct inputs should call StartSession")
|
|
}
|
|
|
|
step.Cleanup(state)
|
|
if !mockSvc.TerminateSessionCalled {
|
|
t.Errorf("calling cleanup on a successful run should call TerminateSession")
|
|
}
|
|
}
|
|
|
|
func TestStepCreateSSMTunnel_Cleanup(t *testing.T) {
|
|
mockSvc := MockSSMSvc{}
|
|
config := SSMDriverConfig{
|
|
SvcClient: &mockSvc,
|
|
SvcEndpoint: "example.com",
|
|
}
|
|
|
|
mockDriver := NewSSMDriver(config)
|
|
mockDriver.pluginCmdFunc = MockPluginCmdFunc
|
|
|
|
step := StepCreateSSMTunnel{
|
|
SSMAgentEnabled: true,
|
|
driver: mockDriver,
|
|
}
|
|
|
|
state := testState()
|
|
state.Put("ui", &packer.NoopUi{})
|
|
state.Put("instance", &ec2.Instance{InstanceId: aws.String("i-something")})
|
|
|
|
step.Cleanup(state)
|
|
|
|
if mockSvc.TerminateSessionCalled {
|
|
t.Fatalf("calling cleanup on a non started session should not call TerminateSession")
|
|
}
|
|
|
|
}
|
|
|
|
func TestStepCreateSSMTunnel_BuildTunnelInputForInstance(t *testing.T) {
|
|
step := StepCreateSSMTunnel{
|
|
Region: "region",
|
|
LocalPortNumber: 8001,
|
|
RemotePortNumber: 22,
|
|
SSMAgentEnabled: true,
|
|
}
|
|
|
|
input := step.BuildTunnelInputForInstance("i-something")
|
|
|
|
target := aws.StringValue(input.Target)
|
|
if target != "i-something" {
|
|
t.Errorf("input should contain instance id as target but it got %q", target)
|
|
}
|
|
|
|
params := map[string][]*string{
|
|
"portNumber": []*string{aws.String("22")},
|
|
"localPortNumber": []*string{aws.String("8001")},
|
|
}
|
|
if !reflect.DeepEqual(input.Parameters, params) {
|
|
t.Errorf("input should contain the expected port parameters but it got %v", input.Parameters)
|
|
}
|
|
|
|
}
|
|
|
|
func TestStepCreateSSMTunnel_ConfigureLocalHostPort(t *testing.T) {
|
|
tt := []struct {
|
|
Name string
|
|
Step StepCreateSSMTunnel
|
|
PortCheck func(int) bool
|
|
}{
|
|
{"WithLocalPortNumber", StepCreateSSMTunnel{LocalPortNumber: 9001}, func(port int) bool { return port == 9001 }},
|
|
{"WithNoLocalPortNumber", StepCreateSSMTunnel{}, func(port int) bool { return port >= 8000 && port <= 9000 }},
|
|
}
|
|
|
|
for _, tc := range tt {
|
|
tc := tc
|
|
t.Run(tc.Name, func(t *testing.T) {
|
|
step := tc.Step
|
|
if err := step.ConfigureLocalHostPort(context.TODO()); err != nil {
|
|
t.Errorf("failed to configure a port on localhost")
|
|
}
|
|
|
|
if !tc.PortCheck(step.LocalPortNumber) {
|
|
t.Errorf("failed to configure a port on localhost")
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|