simplify code

This commit is contained in:
Adrien Delorme 2020-10-29 12:18:41 +01:00
parent b058de072a
commit f329cb5b93
2 changed files with 4 additions and 50 deletions

View File

@ -1,34 +0,0 @@
package common
import (
"context"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
)
const (
sessionManagerPluginName string = "session-manager-plugin"
//sessionCommand is the AWS-SDK equivalent to the command you would specify to `aws ssm ...`
sessionCommand string = "StartSession"
)
type SSMDriverConfig struct {
SvcClient ssmiface.SSMAPI
Region string
ProfileName string
SvcEndpoint string
}
type SSMDriver struct {
SSMDriverConfig
session *ssm.StartSessionOutput
sessionParams ssm.StartSessionInput
pluginCmdFunc func(context.Context) error
}
func NewSSMDriver(config SSMDriverConfig) *SSMDriver {
d := SSMDriver{SSMDriverConfig: config}
return &d
}

View File

@ -24,7 +24,6 @@ type StepCreateSSMTunnel struct {
SSMAgentEnabled bool SSMAgentEnabled bool
instanceId string instanceId string
PauseBeforeSSM time.Duration PauseBeforeSSM time.Duration
driver *SSMDriver
stopSSMCommand func() stopSSMCommand func()
} }
@ -65,16 +64,6 @@ func (s *StepCreateSSMTunnel) Run(ctx context.Context, state multistep.StateBag)
} }
s.instanceId = aws.StringValue(instance.InstanceId) s.instanceId = aws.StringValue(instance.InstanceId)
if s.driver == nil {
ssmconn := ssm.New(s.AWSSession)
cfg := SSMDriverConfig{
SvcClient: ssmconn,
Region: s.Region,
SvcEndpoint: ssmconn.Endpoint,
}
driver := SSMDriver{SSMDriverConfig: cfg}
s.driver = &driver
}
state.Put("sessionPort", s.LocalPortNumber) state.Put("sessionPort", s.LocalPortNumber)
input := s.BuildTunnelInputForInstance(s.instanceId) input := s.BuildTunnelInputForInstance(s.instanceId)
@ -83,16 +72,15 @@ func (s *StepCreateSSMTunnel) Run(ctx context.Context, state multistep.StateBag)
s.stopSSMCommand = ssmCancel s.stopSSMCommand = ssmCancel
go func() { go func() {
ssmconn := ssm.New(s.AWSSession)
err := pssm.Session{ err := pssm.Session{
SvcClient: s.driver.SvcClient, SvcClient: ssmconn,
Input: input, Input: input,
Region: s.driver.Region, Region: s.Region,
}.Start(ssmCtx, ui) }.Start(ssmCtx, ui)
if err != nil { if err != nil {
err = fmt.Errorf("error encountered in establishing a tunnel %s", err) ui.Error(fmt.Sprintf("ssm error: %s", err))
ui.Error(err.Error())
state.Put("error", err)
} }
}() }()