Add tests for driver and ssm tunnel step
This commit is contained in:
parent
e53d6aea66
commit
743df19af2
|
@ -3,52 +3,50 @@ package common
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os/exec"
|
||||
|
||||
"github.com/hashicorp/packer/packer"
|
||||
"github.com/aws/aws-sdk-go/service/ssm"
|
||||
)
|
||||
|
||||
const SessionManagerPluginName string = "session-manager-plugin"
|
||||
const sessionManagerPluginName string = "session-manager-plugin"
|
||||
|
||||
//sessionCommand is the AWS-SDK equivalent to the command you would specify to `aws ssm ...`
|
||||
const sessionCommand string = "StartSession"
|
||||
|
||||
type SSMDriver struct {
|
||||
Ui packer.Ui
|
||||
Ctx context.Context
|
||||
// Provided for testing purposes; if not specified it defaults to SessionManagerPluginName
|
||||
Region string
|
||||
ProfileName string
|
||||
Session *ssm.StartSessionOutput
|
||||
SessionParams ssm.StartSessionInput
|
||||
SessionEndpoint string
|
||||
// Provided for testing purposes; if not specified it defaults to sessionManagerPluginName
|
||||
PluginName string
|
||||
}
|
||||
|
||||
// StartSession starts an interactive Systems Manager session with a remote instance via the AWS session-manager-plugin
|
||||
func (s *SSMDriver) StartSession(sessionData, region, profile, params, endpoint string) error {
|
||||
func (sd *SSMDriver) StartSession(ctx context.Context) error {
|
||||
var stdout bytes.Buffer
|
||||
var stderr bytes.Buffer
|
||||
|
||||
args := []string{
|
||||
sessionData,
|
||||
region,
|
||||
"StartSession",
|
||||
profile,
|
||||
params,
|
||||
endpoint,
|
||||
if sd.PluginName == "" {
|
||||
sd.PluginName = sessionManagerPluginName
|
||||
}
|
||||
|
||||
if s.PluginName == "" {
|
||||
s.PluginName = SessionManagerPluginName
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath(s.PluginName); err != nil {
|
||||
args, err := sd.Args()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error encountered validating session details: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Attempting to start session with the following args: %v", args)
|
||||
cmd := exec.CommandContext(s.Ctx, s.PluginName, args...)
|
||||
cmd := exec.CommandContext(ctx, sd.PluginName, args...)
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
err = fmt.Errorf("error encountered when calling %s: %s\nStderr: %s", s.PluginName, err, stderr.String())
|
||||
s.Ui.Error(err.Error())
|
||||
err = fmt.Errorf("error encountered when calling %s: %s\nStderr: %s", sd.PluginName, err, stderr.String())
|
||||
return err
|
||||
}
|
||||
// TODO capture logging for testing
|
||||
|
@ -56,3 +54,31 @@ func (s *SSMDriver) StartSession(sessionData, region, profile, params, endpoint
|
|||
|
||||
return nil
|
||||
}
|
||||
func (sd *SSMDriver) Args() ([]string, error) {
|
||||
if sd.Session == nil {
|
||||
return nil, fmt.Errorf("an active Amazon SSM Session is required before trying to open a session tunnel")
|
||||
}
|
||||
|
||||
// AWS session-manager-plugin requires a valid session be passed in JSON.
|
||||
sessionDetails, err := json.Marshal(sd.Session)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error encountered in reading session details %s", err)
|
||||
}
|
||||
|
||||
// AWS session-manager-plugin requires the parameters used in the session to be passed in JSON as well.
|
||||
sessionParameters, err := json.Marshal(sd.SessionParams)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error encountered in reading session parameter details %s", err)
|
||||
}
|
||||
|
||||
args := []string{
|
||||
string(sessionDetails),
|
||||
sd.Region,
|
||||
sessionCommand,
|
||||
sd.ProfileName,
|
||||
string(sessionParameters),
|
||||
sd.SessionEndpoint,
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/ssm"
|
||||
)
|
||||
|
||||
func TestStartSession(t *testing.T) {
|
||||
func TestSSMDriver_StartSession(t *testing.T) {
|
||||
tt := []struct {
|
||||
Name string
|
||||
PluginName string
|
||||
|
@ -17,13 +22,75 @@ func TestStartSession(t *testing.T) {
|
|||
for _, tc := range tt {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
driver := SSMDriver{PluginName: "someboguspluginname"}
|
||||
driver := SSMDriver{
|
||||
Region: "region",
|
||||
Session: new(ssm.StartSessionOutput),
|
||||
SessionParams: ssm.StartSessionInput{},
|
||||
SessionEndpoint: "endpoint",
|
||||
PluginName: tc.PluginName}
|
||||
|
||||
err := driver.StartSession("sessionData", "region", "profile", "params", "bogus-endpoint")
|
||||
ctx := context.TODO()
|
||||
err := driver.StartSession(ctx)
|
||||
|
||||
if tc.ErrorExpected && err == nil {
|
||||
t.Fatalf("Executing %q should have failed but instead no error was returned", tc.PluginName)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSMDriver_Args(t *testing.T) {
|
||||
tt := []struct {
|
||||
Name string
|
||||
Session *ssm.StartSessionOutput
|
||||
ProfileName string
|
||||
ErrorExpected bool
|
||||
}{
|
||||
{
|
||||
Name: "NilSession",
|
||||
ErrorExpected: true,
|
||||
},
|
||||
{
|
||||
Name: "NonNilSession",
|
||||
Session: new(ssm.StartSessionOutput),
|
||||
ErrorExpected: false,
|
||||
},
|
||||
{
|
||||
Name: "SessionWithProfileName",
|
||||
Session: new(ssm.StartSessionOutput),
|
||||
ProfileName: "default",
|
||||
ErrorExpected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
driver := SSMDriver{
|
||||
Region: "region",
|
||||
ProfileName: tc.ProfileName,
|
||||
Session: tc.Session,
|
||||
SessionParams: ssm.StartSessionInput{},
|
||||
SessionEndpoint: "amazon.com/sessions",
|
||||
}
|
||||
|
||||
args, err := driver.Args()
|
||||
if tc.ErrorExpected && err == nil {
|
||||
t.Fatalf("SSMDriver.Args with a %q should have failed but instead no error was returned", tc.Name)
|
||||
}
|
||||
|
||||
if tc.ErrorExpected {
|
||||
return
|
||||
}
|
||||
|
||||
// validate launch script
|
||||
expectedArgString := fmt.Sprintf(`{"SessionId":null,"StreamUrl":null,"TokenValue":null} %s StartSession %s {"DocumentName":null,"Parameters":null,"Target":null} %s`, driver.Region, driver.ProfileName, driver.SessionEndpoint)
|
||||
argString := strings.Join(args, " ")
|
||||
if argString != expectedArgString {
|
||||
t.Errorf("Expected launch script to be %q but got %q", expectedArgString, argString)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,8 +2,8 @@ package common
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
|
@ -18,61 +18,46 @@ import (
|
|||
)
|
||||
|
||||
type StepCreateSSMTunnel struct {
|
||||
AWSSession *session.Session
|
||||
DstPort int
|
||||
SSMAgentEnabled bool
|
||||
instanceId string
|
||||
ssmSession *ssm.StartSessionOutput
|
||||
AWSSession *session.Session
|
||||
Region string
|
||||
LocalPortNumber int
|
||||
RemotePortNumber int
|
||||
SSMAgentEnabled bool
|
||||
instanceId string
|
||||
session *ssm.StartSessionOutput
|
||||
}
|
||||
|
||||
// Run executes the Packer build step that creates a session tunnel.
|
||||
func (s *StepCreateSSMTunnel) Run(ctx context.Context, state multistep.StateBag) multistep.StepAction {
|
||||
if !s.SSMAgentEnabled {
|
||||
return multistep.ActionContinue
|
||||
}
|
||||
|
||||
ui := state.Get("ui").(packer.Ui)
|
||||
// Find an available TCP port for our HTTP server
|
||||
l, err := net.ListenRangeConfig{
|
||||
Min: 8000,
|
||||
Max: 9000,
|
||||
Addr: "0.0.0.0",
|
||||
Network: "tcp",
|
||||
}.Listen(ctx)
|
||||
if err != nil {
|
||||
if err := s.ConfigureLocalHostPort(ctx); err != nil {
|
||||
err := fmt.Errorf("error finding an available port to initiate a session tunnel: %s", err)
|
||||
state.Put("error", err)
|
||||
ui.Error(err.Error())
|
||||
return multistep.ActionHalt
|
||||
}
|
||||
|
||||
dst, src := strconv.Itoa(s.DstPort), strconv.Itoa(l.Port)
|
||||
params := map[string][]*string{
|
||||
"portNumber": []*string{aws.String(dst)},
|
||||
"localPortNumber": []*string{aws.String(src)},
|
||||
}
|
||||
|
||||
instance, ok := state.Get("instance").(*ec2.Instance)
|
||||
if !ok {
|
||||
err := fmt.Errorf("error encountered in obtaining target instance id for SSM tunnel")
|
||||
err := fmt.Errorf("error encountered in obtaining target instance id for session tunnel")
|
||||
ui.Error(err.Error())
|
||||
state.Put("error", err)
|
||||
return multistep.ActionHalt
|
||||
}
|
||||
|
||||
s.instanceId = aws.StringValue(instance.InstanceId)
|
||||
ssmconn := ssm.New(s.AWSSession)
|
||||
input := ssm.StartSessionInput{
|
||||
DocumentName: aws.String("AWS-StartPortForwardingSession"),
|
||||
Parameters: params,
|
||||
Target: aws.String(s.instanceId),
|
||||
}
|
||||
|
||||
ui.Message(fmt.Sprintf("Starting PortForwarding session to instance %q on local port %q to remote port %q", s.instanceId, src, dst))
|
||||
log.Printf("Starting PortForwarding session to instance %q on local port %q to remote port %q", s.instanceId, s.LocalPortNumber, s.RemotePortNumber)
|
||||
input := s.BuildTunnelInputForInstance(s.instanceId)
|
||||
ssmconn := ssm.New(s.AWSSession)
|
||||
var output *ssm.StartSessionOutput
|
||||
err = retry.Config{
|
||||
err := retry.Config{
|
||||
ShouldRetry: func(err error) bool { return isAWSErr(err, "TargetNotConnected", "") },
|
||||
RetryDelay: (&retry.Backoff{InitialBackoff: 200 * time.Millisecond, MaxBackoff: 60 * time.Second, Multiplier: 2}).Linear,
|
||||
}.Run(ctx, func(ctx context.Context) error {
|
||||
}.Run(ctx, func(ctx context.Context) (err error) {
|
||||
output, err = ssmconn.StartSessionWithContext(ctx, &input)
|
||||
return err
|
||||
})
|
||||
|
@ -83,56 +68,82 @@ func (s *StepCreateSSMTunnel) Run(ctx context.Context, state multistep.StateBag)
|
|||
state.Put("error", err)
|
||||
return multistep.ActionHalt
|
||||
}
|
||||
s.ssmSession = output
|
||||
|
||||
// AWS session-manager-plugin requires a valid session be passed in JSON
|
||||
sessionDetails, err := json.Marshal(s.ssmSession)
|
||||
if err != nil {
|
||||
ui.Error(err.Error())
|
||||
state.Put("error encountered in reading session details", err)
|
||||
return multistep.ActionHalt
|
||||
driver := SSMDriver{
|
||||
Region: s.Region,
|
||||
Session: output,
|
||||
SessionParams: input,
|
||||
SessionEndpoint: ssmconn.Endpoint,
|
||||
}
|
||||
|
||||
sessionParameters, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
ui.Error(err.Error())
|
||||
state.Put("error encountered in reading session parameter details", err)
|
||||
return multistep.ActionHalt
|
||||
}
|
||||
|
||||
// Stop listening on selected port so that the AWS session-manager-plugin can use it.
|
||||
// The port is closed right before we start the session to avoid two Packer builds from getting the same port - fingers-crossed
|
||||
l.Close()
|
||||
|
||||
driver := SSMDriver{Ui: ui, Ctx: ctx}
|
||||
// sessionDetails, region, "StartSession", profile, paramJson, endpoint
|
||||
region := aws.StringValue(s.AWSSession.Config.Region)
|
||||
// how to best get Profile name
|
||||
if err := driver.StartSession(string(sessionDetails), region, "", string(sessionParameters), ssmconn.Endpoint); err != nil {
|
||||
err = fmt.Errorf("error encountered in establishing a tunnel with the session-manager-plugin: %s", err)
|
||||
if err := driver.StartSession(ctx); err != nil {
|
||||
err = fmt.Errorf("error encountered in establishing a tunnel %s", err)
|
||||
ui.Error(err.Error())
|
||||
state.Put("error", err)
|
||||
return multistep.ActionHalt
|
||||
}
|
||||
|
||||
ui.Message(fmt.Sprintf("PortForwarding session to instance %q established!", s.instanceId))
|
||||
state.Put("sessionPort", l.Port)
|
||||
ui.Message(fmt.Sprintf("PortForwarding session tunnel to instance %q established!", s.instanceId))
|
||||
state.Put("sessionPort", s.LocalPortNumber)
|
||||
|
||||
return multistep.ActionContinue
|
||||
}
|
||||
|
||||
// Cleanup terminates an active session on AWS, which in turn terminates the associated tunnel process running on the local machine.
|
||||
func (s *StepCreateSSMTunnel) Cleanup(state multistep.StateBag) {
|
||||
if s.ssmSession == nil {
|
||||
if s.session == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ui := state.Get("ui").(packer.Ui)
|
||||
ssmconn := ssm.New(s.AWSSession)
|
||||
_, err := ssmconn.TerminateSession(&ssm.TerminateSessionInput{SessionId: s.ssmSession.SessionId})
|
||||
_, err := ssmconn.TerminateSession(&ssm.TerminateSessionInput{SessionId: s.session.SessionId})
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("Error terminating SSM Session %q. Please terminate the session manually: %s",
|
||||
aws.StringValue(s.ssmSession.SessionId), err)
|
||||
aws.StringValue(s.session.SessionId), err)
|
||||
ui.Error(msg)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// ConfigureLocalHostPort finds an available port on the localhost that can be used for the remote tunnel.
|
||||
// Defaults to using s.LocalPortNumber if it is set.
|
||||
func (s *StepCreateSSMTunnel) ConfigureLocalHostPort(ctx context.Context) error {
|
||||
if s.LocalPortNumber != 0 {
|
||||
return nil
|
||||
}
|
||||
// Find an available TCP port for our HTTP server
|
||||
l, err := net.ListenRangeConfig{
|
||||
Min: 8000,
|
||||
Max: 9000,
|
||||
Addr: "0.0.0.0",
|
||||
Network: "tcp",
|
||||
}.Listen(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.LocalPortNumber = l.Port
|
||||
// Stop listening on selected port so that the AWS session-manager-plugin can use it.
|
||||
// The port is closed right before we start the session to avoid two Packer builds from getting the same port - fingers-crossed
|
||||
l.Close()
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (s *StepCreateSSMTunnel) BuildTunnelInputForInstance(instance string) ssm.StartSessionInput {
|
||||
dst, src := strconv.Itoa(s.RemotePortNumber), strconv.Itoa(s.LocalPortNumber)
|
||||
params := map[string][]*string{
|
||||
"portNumber": []*string{aws.String(dst)},
|
||||
"localPortNumber": []*string{aws.String(src)},
|
||||
}
|
||||
|
||||
input := ssm.StartSessionInput{
|
||||
DocumentName: aws.String("AWS-StartPortForwardingSession"),
|
||||
Parameters: params,
|
||||
Target: aws.String(instance),
|
||||
}
|
||||
|
||||
return input
|
||||
}
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
)
|
||||
|
||||
func TestStepCreateSSMTunnel_BuildTunnelInputForInstance(t *testing.T) {
|
||||
step := StepCreateSSMTunnel{
|
||||
Region: "region",
|
||||
LocalPortNumber: 8001,
|
||||
RemotePortNumber: 22,
|
||||
SSMAgentEnabled: true,
|
||||
}
|
||||
|
||||
input := step.BuildTunnelInputForInstance("i-something")
|
||||
|
||||
target := aws.StringValue(input.Target)
|
||||
if target != "i-something" {
|
||||
t.Errorf("input should contain instance id as target but it got %q", target)
|
||||
}
|
||||
|
||||
params := map[string][]*string{
|
||||
"portNumber": []*string{aws.String("22")},
|
||||
"localPortNumber": []*string{aws.String("8001")},
|
||||
}
|
||||
if !reflect.DeepEqual(input.Parameters, params) {
|
||||
t.Errorf("input should contain the expected port parameters but it got %v", input.Parameters)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestStepCreateSSMTunnel_ConfigureLocalHostPort(t *testing.T) {
|
||||
tun := StepCreateSSMTunnel{}
|
||||
|
||||
ctx := context.TODO()
|
||||
if err := tun.ConfigureLocalHostPort(ctx); err != nil {
|
||||
t.Errorf("failed to configure a port on localhost")
|
||||
}
|
||||
|
||||
if tun.LocalPortNumber == 0 {
|
||||
t.Errorf("failed to configure a port on localhost")
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue