simplify things a bit more

This commit is contained in:
Adrien Delorme 2020-10-29 13:11:07 +01:00
parent aef3d24213
commit a4bd744955
2 changed files with 29 additions and 30 deletions

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"log" "log"
"os/exec" "os/exec"
"strconv"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -18,19 +19,36 @@ import (
) )
type Session struct { type Session struct {
SvcClient ssmiface.SSMAPI SvcClient ssmiface.SSMAPI
Region string Region string
Input ssm.StartSessionInput InstanceID string
LocalPort, RemotePort int
}
func (s Session) buildTunnelInput() *ssm.StartSessionInput {
portNumber, localPortNumber := strconv.Itoa(s.RemotePort), strconv.Itoa(s.LocalPort)
params := map[string][]*string{
"portNumber": []*string{aws.String(portNumber)},
"localPortNumber": []*string{aws.String(localPortNumber)},
}
return &ssm.StartSessionInput{
DocumentName: aws.String("AWS-StartPortForwardingSession"),
Parameters: params,
Target: aws.String(s.InstanceID),
}
} }
// getCommand return a valid ordered set of arguments to pass to the driver command. // getCommand return a valid ordered set of arguments to pass to the driver command.
func (s Session) getCommand(ctx context.Context) ([]string, string, error) { func (s Session) getCommand(ctx context.Context) ([]string, string, error) {
input := s.buildTunnelInput()
var session *ssm.StartSessionOutput var session *ssm.StartSessionOutput
err := retry.Config{ err := retry.Config{
ShouldRetry: func(err error) bool { return awserrors.Matches(err, "TargetNotConnected", "") }, ShouldRetry: func(err error) bool { return awserrors.Matches(err, "TargetNotConnected", "") },
RetryDelay: (&retry.Backoff{InitialBackoff: 200 * time.Millisecond, MaxBackoff: 60 * time.Second, Multiplier: 2}).Linear, RetryDelay: (&retry.Backoff{InitialBackoff: 200 * time.Millisecond, MaxBackoff: 60 * time.Second, Multiplier: 2}).Linear,
}.Run(ctx, func(ctx context.Context) (err error) { }.Run(ctx, func(ctx context.Context) (err error) {
session, err = s.SvcClient.StartSessionWithContext(ctx, &s.Input) session, err = s.SvcClient.StartSessionWithContext(ctx, input)
return err return err
}) })
@ -49,7 +67,7 @@ func (s Session) getCommand(ctx context.Context) ([]string, string, error) {
} }
// AWS session-manager-plugin requires the parameters used in the session to be passed in JSON as well. // AWS session-manager-plugin requires the parameters used in the session to be passed in JSON as well.
sessionParameters, err := json.Marshal(s.Input) sessionParameters, err := json.Marshal(input)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("error encountered in reading session parameter details %s", err) return nil, "", fmt.Errorf("error encountered in reading session parameter details %s", err)
} }
@ -73,7 +91,7 @@ func (s Session) getCommand(ctx context.Context) ([]string, string, error) {
// created from calling StartSession. // created from calling StartSession.
func (s Session) Start(ctx context.Context, ui packer.Ui) error { func (s Session) Start(ctx context.Context, ui packer.Ui) error {
for ctx.Err() == nil { for ctx.Err() == nil {
log.Printf("ssm: Starting PortForwarding session to instance %q", *s.Input.Target) log.Printf("ssm: Starting PortForwarding session to instance %s", s.InstanceID)
args, sessionID, err := s.getCommand(ctx) args, sessionID, err := s.getCommand(ctx)
if sessionID != "" { if sessionID != "" {
defer func() { defer func() {

View File

@ -3,7 +3,6 @@ package common
import ( import (
"context" "context"
"fmt" "fmt"
"strconv"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -22,7 +21,6 @@ type StepCreateSSMTunnel struct {
LocalPortNumber int LocalPortNumber int
RemotePortNumber int RemotePortNumber int
SSMAgentEnabled bool SSMAgentEnabled bool
instanceId string
PauseBeforeSSM time.Duration PauseBeforeSSM time.Duration
stopSSMCommand func() stopSSMCommand func()
} }
@ -62,21 +60,20 @@ func (s *StepCreateSSMTunnel) Run(ctx context.Context, state multistep.StateBag)
state.Put("error", err) state.Put("error", err)
return multistep.ActionHalt return multistep.ActionHalt
} }
s.instanceId = aws.StringValue(instance.InstanceId)
state.Put("sessionPort", s.LocalPortNumber) state.Put("sessionPort", s.LocalPortNumber)
input := s.BuildTunnelInputForInstance(s.instanceId)
ssmCtx, ssmCancel := context.WithCancel(ctx) ssmCtx, ssmCancel := context.WithCancel(ctx)
s.stopSSMCommand = ssmCancel s.stopSSMCommand = ssmCancel
go func() { go func() {
ssmconn := ssm.New(s.AWSSession) ssmconn := ssm.New(s.AWSSession)
err := pssm.Session{ err := pssm.Session{
SvcClient: ssmconn, SvcClient: ssmconn,
Input: input, InstanceID: aws.StringValue(instance.InstanceId),
Region: s.Region, RemotePort: s.RemotePortNumber,
LocalPort: s.LocalPortNumber,
Region: s.Region,
}.Start(ssmCtx, ui) }.Start(ssmCtx, ui)
if err != nil { if err != nil {
@ -127,19 +124,3 @@ func (s *StepCreateSSMTunnel) ConfigureLocalHostPort(ctx context.Context) error
return nil return nil
} }
func (s *StepCreateSSMTunnel) BuildTunnelInputForInstance(instance string) ssm.StartSessionInput {
dst, src := strconv.Itoa(s.RemotePortNumber), strconv.Itoa(s.LocalPortNumber)
params := map[string][]*string{
"portNumber": []*string{aws.String(dst)},
"localPortNumber": []*string{aws.String(src)},
}
input := ssm.StartSessionInput{
DocumentName: aws.String("AWS-StartPortForwardingSession"),
Parameters: params,
Target: aws.String(instance),
}
return input
}