diff --git a/builder/amazon/common/ssm_driver_test.go b/builder/amazon/common/ssm_driver_test.go deleted file mode 100644 index c28844b4d..000000000 --- a/builder/amazon/common/ssm_driver_test.go +++ /dev/null @@ -1,188 +0,0 @@ -package common - -import ( - "context" - "fmt" - "reflect" - "strings" - "testing" -) - -func NewSSMDriverWithMockSvc(svc *MockSSMSvc) *SSMDriver { - config := SSMDriverConfig{ - SvcClient: svc, - Region: "east", - ProfileName: "default", - SvcEndpoint: "example.com", - } - - driver := SSMDriver{ - SSMDriverConfig: config, - pluginCmdFunc: func(ctx context.Context) error { return nil }, - } - - return &driver -} -func TestSSMDriver_StartSession(t *testing.T) { - mockSvc := MockSSMSvc{} - driver := NewSSMDriverWithMockSvc(&mockSvc) - - if driver.SvcClient == nil { - t.Fatalf("SvcClient for driver should not be nil") - } - - session, err := driver.StartSession(context.TODO(), MockStartSessionInput("fakeinstance")) - if err != nil { - t.Fatalf("calling StartSession should not error but got %v", err) - } - - if !mockSvc.StartSessionCalled { - t.Fatalf("expected test to call ssm mocks but didn't") - } - - if session == nil { - t.Errorf("expected session to be set after a successful call to StartSession") - } - - if !reflect.DeepEqual(session, MockStartSessionOutput()) { - t.Errorf("expected session to be %v but got %v", MockStartSessionOutput(), session) - } -} - -func TestSSMDriver_StartSessionWithError(t *testing.T) { - mockSvc := MockSSMSvc{StartSessionError: fmt.Errorf("bogus error")} - driver := NewSSMDriverWithMockSvc(&mockSvc) - - if driver.SvcClient == nil { - t.Fatalf("SvcClient for driver should not be nil") - } - - session, err := driver.StartSession(context.TODO(), MockStartSessionInput("fakeinstance")) - if err == nil { - t.Fatalf("StartSession should have thrown an error but didn't") - } - - if !mockSvc.StartSessionCalled { - t.Errorf("expected test to call StartSession mock but didn't") - } - - if session != nil { - t.Errorf("expected session to be nil after a bad StartSession call, but got %v", session) - } -} - -func TestSSMDriver_StopSession(t *testing.T) { - mockSvc := MockSSMSvc{} - driver := NewSSMDriverWithMockSvc(&mockSvc) - - if driver.SvcClient == nil { - t.Fatalf("SvcClient for driver should not be nil") - } - - // Calling StopSession before StartSession should fail - err := driver.StopSession() - if err == nil { - t.Fatalf("calling StopSession() on a driver that has no started session should fail") - } - - if driver.session != nil { - t.Errorf("expected session to be default to nil") - } - - if mockSvc.TerminateSessionCalled { - t.Fatalf("a call to TerminateSession should not occur when there is no valid SSM session") - } - - // Lets try calling start session, then stopping to see what happens. - session, err := driver.StartSession(context.TODO(), MockStartSessionInput("fakeinstance")) - if err != nil { - t.Fatalf("calling StartSession should not error but got %v", err) - } - - if !mockSvc.StartSessionCalled { - t.Fatalf("expected test to call StartSession mock but didn't") - } - - if session == nil || driver.session != session { - t.Errorf("expected session to be set after a successful call to StartSession") - } - - if !reflect.DeepEqual(session, MockStartSessionOutput()) { - t.Errorf("expected session to be %v but got %v", MockStartSessionOutput(), session) - } - - err = driver.StopSession() - if err != nil { - t.Errorf("calling StopSession() on a driver on a started session should not fail") - } - - if !mockSvc.TerminateSessionCalled { - t.Fatalf("expected test to call StopSession mock but didn't") - } - -} - -func TestSSMDriver_Args(t *testing.T) { - tt := []struct { - Name string - ProfileName string - SkipStartSession bool - ErrorExpected bool - }{ - { - Name: "NilSession", - SkipStartSession: true, - ErrorExpected: true, - }, - { - Name: "NonNilSession", - ErrorExpected: false, - }, - { - Name: "SessionWithProfileName", - ProfileName: "default", - ErrorExpected: false, - }, - } - - for _, tc := range tt { - tc := tc - t.Run(tc.Name, func(t *testing.T) { - mockSvc := MockSSMSvc{} - driver := NewSSMDriverWithMockSvc(&mockSvc) - driver.ProfileName = tc.ProfileName - - if driver.SvcClient == nil { - t.Fatalf("svcclient for driver should not be nil") - } - - if !tc.SkipStartSession { - _, err := driver.StartSession(context.TODO(), MockStartSessionInput("fakeinstance")) - if err != nil { - t.Fatalf("got an error when calling StartSession %v", err) - } - } - - args, err := driver.Args() - if tc.ErrorExpected && err == nil { - t.Fatalf("Driver.Args with a %q should have failed but instead no error was returned", tc.Name) - } - - if tc.ErrorExpected { - return - } - - if err != nil { - t.Fatalf("got an error when it should've worked %v", err) - } - - // validate launch script - expectedArgString := fmt.Sprintf(`{"SessionId":"packerid","StreamUrl":"http://packer.io","TokenValue":"packer-token"} east StartSession %s {"DocumentName":"AWS-StartPortForwardingSession","Parameters":{"localPortNumber":["8001"],"portNumber":["22"]},"Target":"fakeinstance"} example.com`, tc.ProfileName) - argString := strings.Join(args, " ") - if argString != expectedArgString { - t.Errorf("Expected launch script to be %q but got %q", expectedArgString, argString) - } - - }) - } -} diff --git a/builder/amazon/common/step_create_ssm_tunnel_test.go b/builder/amazon/common/step_create_ssm_tunnel_test.go deleted file mode 100644 index 531f24765..000000000 --- a/builder/amazon/common/step_create_ssm_tunnel_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package common - -import ( - "context" - "reflect" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/hashicorp/packer/packer" -) - -func TestStepCreateSSMTunnel_Run(t *testing.T) { - mockSvc := MockSSMSvc{} - config := SSMDriverConfig{ - SvcClient: &mockSvc, - SvcEndpoint: "example.com", - } - - mockDriver := NewSSMDriver(config) - mockDriver.pluginCmdFunc = MockPluginCmdFunc - - state := testState() - state.Put("ui", &packer.NoopUi{}) - state.Put("instance", &ec2.Instance{InstanceId: aws.String("i-something")}) - - step := StepCreateSSMTunnel{ - driver: mockDriver, - } - - step.Run(context.Background(), state) - - err := state.Get("error") - if err != nil { - err = err.(error) - t.Fatalf("the call to Run failed with an error when it should've executed: %v", err) - } - - if mockSvc.StartSessionCalled { - t.Errorf("StartSession should not be called when SSMAgentEnabled is false") - } - - // Run when SSMAgentEnabled is true - step.SSMAgentEnabled = true - step.Run(context.Background(), state) - - err = state.Get("error") - if err != nil { - err = err.(error) - t.Fatalf("the call to Run failed with an error when it should've executed: %v", err) - } - - if !mockSvc.StartSessionCalled { - t.Errorf("calling run with the correct inputs should call StartSession") - } - - step.Cleanup(state) - if !mockSvc.TerminateSessionCalled { - t.Errorf("calling cleanup on a successful run should call TerminateSession") - } -} - -func TestStepCreateSSMTunnel_Cleanup(t *testing.T) { - mockSvc := MockSSMSvc{} - config := SSMDriverConfig{ - SvcClient: &mockSvc, - SvcEndpoint: "example.com", - } - - mockDriver := NewSSMDriver(config) - mockDriver.pluginCmdFunc = MockPluginCmdFunc - - step := StepCreateSSMTunnel{ - SSMAgentEnabled: true, - driver: mockDriver, - } - - state := testState() - state.Put("ui", &packer.NoopUi{}) - state.Put("instance", &ec2.Instance{InstanceId: aws.String("i-something")}) - - step.Cleanup(state) - - if mockSvc.TerminateSessionCalled { - t.Fatalf("calling cleanup on a non started session should not call TerminateSession") - } - -} - -func TestStepCreateSSMTunnel_BuildTunnelInputForInstance(t *testing.T) { - step := StepCreateSSMTunnel{ - Region: "region", - LocalPortNumber: 8001, - RemotePortNumber: 22, - SSMAgentEnabled: true, - } - - input := step.BuildTunnelInputForInstance("i-something") - - target := aws.StringValue(input.Target) - if target != "i-something" { - t.Errorf("input should contain instance id as target but it got %q", target) - } - - params := map[string][]*string{ - "portNumber": []*string{aws.String("22")}, - "localPortNumber": []*string{aws.String("8001")}, - } - if !reflect.DeepEqual(input.Parameters, params) { - t.Errorf("input should contain the expected port parameters but it got %v", input.Parameters) - } - -} - -func TestStepCreateSSMTunnel_ConfigureLocalHostPort(t *testing.T) { - tt := []struct { - Name string - Step StepCreateSSMTunnel - PortCheck func(int) bool - }{ - {"WithLocalPortNumber", StepCreateSSMTunnel{LocalPortNumber: 9001}, func(port int) bool { return port == 9001 }}, - {"WithNoLocalPortNumber", StepCreateSSMTunnel{}, func(port int) bool { return port >= 8000 && port <= 9000 }}, - } - - for _, tc := range tt { - tc := tc - t.Run(tc.Name, func(t *testing.T) { - step := tc.Step - if err := step.ConfigureLocalHostPort(context.TODO()); err != nil { - t.Errorf("failed to configure a port on localhost") - } - - if !tc.PortCheck(step.LocalPortNumber) { - t.Errorf("failed to configure a port on localhost") - } - }) - } - -}