packer-cn/builder/amazon/common/ssm_driver_test.go

189 lines
4.9 KiB
Go

package common
import (
"context"
"fmt"
"reflect"
"strings"
"testing"
)
// NewSSMDriverWithMockSvc returns an SSMDriver whose service client is the
// given mock. Region, profile name, and endpoint are fixed test values
// ("east", "default", "example.com"). The plugin command hook is replaced by a
// stub that always returns nil.
func NewSSMDriverWithMockSvc(svc *MockSSMSvc) *SSMDriver {
	return &SSMDriver{
		SSMDriverConfig: SSMDriverConfig{
			SvcClient:   svc,
			Region:      "east",
			ProfileName: "default",
			SvcEndpoint: "example.com",
		},
		pluginCmdFunc: func(context.Context) error { return nil },
	}
}
// TestSSMDriver_StartSession checks that a successful StartSession call
// reaches the mock service and returns the mock's canned session output.
func TestSSMDriver_StartSession(t *testing.T) {
	svc := MockSSMSvc{}
	d := NewSSMDriverWithMockSvc(&svc)
	if d.SvcClient == nil {
		t.Fatalf("SvcClient for driver should not be nil")
	}

	input := MockStartSessionInput("fakeinstance")
	got, err := d.StartSession(context.TODO(), input)
	if err != nil {
		t.Fatalf("calling StartSession should not error but got %v", err)
	}
	if !svc.StartSessionCalled {
		t.Fatalf("expected test to call ssm mocks but didn't")
	}
	if got == nil {
		t.Errorf("expected session to be set after a successful call to StartSession")
	}
	if want := MockStartSessionOutput(); !reflect.DeepEqual(got, want) {
		t.Errorf("expected session to be %v but got %v", want, got)
	}
}
// TestSSMDriver_StartSessionWithError checks that an error injected into the
// mock's StartSession propagates to the caller and that no session is
// returned alongside it.
func TestSSMDriver_StartSessionWithError(t *testing.T) {
	svc := MockSSMSvc{StartSessionError: fmt.Errorf("bogus error")}
	d := NewSSMDriverWithMockSvc(&svc)
	if d.SvcClient == nil {
		t.Fatalf("SvcClient for driver should not be nil")
	}

	got, err := d.StartSession(context.TODO(), MockStartSessionInput("fakeinstance"))
	if err == nil {
		t.Fatalf("StartSession should have thrown an error but didn't")
	}
	if !svc.StartSessionCalled {
		t.Errorf("expected test to call StartSession mock but didn't")
	}
	if got != nil {
		t.Errorf("expected session to be nil after a bad StartSession call, but got %v", got)
	}
}
// TestSSMDriver_StopSession exercises StopSession in two phases.
//
// Before any session is started, StopSession must return an error, the
// driver's session must still be nil, and the mock's TerminateSession must not
// have been called.
//
// After a successful StartSession, StopSession must succeed and the mock's
// TerminateSession must have been called.
func TestSSMDriver_StopSession(t *testing.T) {
	mockSvc := MockSSMSvc{}
	driver := NewSSMDriverWithMockSvc(&mockSvc)
	if driver.SvcClient == nil {
		t.Fatalf("SvcClient for driver should not be nil")
	}

	// Calling StopSession before StartSession should fail.
	err := driver.StopSession()
	if err == nil {
		t.Fatalf("calling StopSession() on a driver that has no started session should fail")
	}
	if driver.session != nil {
		t.Errorf("expected session to default to nil")
	}
	if mockSvc.TerminateSessionCalled {
		t.Fatalf("a call to TerminateSession should not occur when there is no valid SSM session")
	}

	// Start a session, then stop it; this time StopSession should succeed
	// and reach the mock's TerminateSession.
	session, err := driver.StartSession(context.TODO(), MockStartSessionInput("fakeinstance"))
	if err != nil {
		t.Fatalf("calling StartSession should not error but got %v", err)
	}
	if !mockSvc.StartSessionCalled {
		t.Fatalf("expected test to call StartSession mock but didn't")
	}
	if session == nil || driver.session != session {
		t.Errorf("expected session to be set after a successful call to StartSession")
	}
	if !reflect.DeepEqual(session, MockStartSessionOutput()) {
		t.Errorf("expected session to be %v but got %v", MockStartSessionOutput(), session)
	}

	if err := driver.StopSession(); err != nil {
		t.Errorf("calling StopSession() on a driver on a started session should not fail, but got %v", err)
	}
	// The flag checked here is the mock's TerminateSession, so the failure
	// message must name that mock rather than StopSession.
	if !mockSvc.TerminateSessionCalled {
		t.Fatalf("expected test to call TerminateSession mock but didn't")
	}
}
// TestSSMDriver_Args is a table test for Driver.Args.
//
// Without a started session, Args must return an error. Once a session has
// been started, Args must return arguments that, joined with single spaces,
// equal the expected launch string. The configured profile name (possibly
// empty) appears in the fourth position of that string.
func TestSSMDriver_Args(t *testing.T) {
	cases := []struct {
		Name             string
		ProfileName      string
		SkipStartSession bool
		ErrorExpected    bool
	}{
		{Name: "NilSession", SkipStartSession: true, ErrorExpected: true},
		{Name: "NonNilSession", ErrorExpected: false},
		{Name: "SessionWithProfileName", ProfileName: "default", ErrorExpected: false},
	}

	for _, c := range cases {
		c := c // capture per-iteration value for the subtest closure
		t.Run(c.Name, func(t *testing.T) {
			svc := MockSSMSvc{}
			d := NewSSMDriverWithMockSvc(&svc)
			d.ProfileName = c.ProfileName
			if d.SvcClient == nil {
				t.Fatalf("svcclient for driver should not be nil")
			}

			if !c.SkipStartSession {
				if _, err := d.StartSession(context.TODO(), MockStartSessionInput("fakeinstance")); err != nil {
					t.Fatalf("got an error when calling StartSession %v", err)
				}
			}

			args, err := d.Args()
			switch {
			case c.ErrorExpected:
				if err == nil {
					t.Fatalf("Driver.Args with a %q should have failed but instead no error was returned", c.Name)
				}
				return
			case err != nil:
				t.Fatalf("got an error when it should've worked %v", err)
			}

			// validate launch script
			const launchFmt = `{"SessionId":"packerid","StreamUrl":"http://packer.io","TokenValue":"packer-token"} east StartSession %s {"DocumentName":"AWS-StartPortForwardingSession","Parameters":{"localPortNumber":["8001"],"portNumber":["22"]},"Target":"fakeinstance"} example.com`
			want := fmt.Sprintf(launchFmt, c.ProfileName)
			if got := strings.Join(args, " "); got != want {
				t.Errorf("Expected launch script to be %q but got %q", want, got)
			}
		})
	}
}