add retry terminated session chan

This commit is contained in:
sylviamoss 2020-09-29 12:31:10 +02:00 committed by Wilken Rivera
parent 8e3f3e514c
commit 1f62249097
1 changed files with 46 additions and 20 deletions

View File

@ -38,6 +38,7 @@ type SSMDriver struct {
pluginCmdFunc func(context.Context) error
retryConnection chan bool
retryAfterTermination chan bool
wg sync.WaitGroup
}
@ -55,35 +56,42 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu
defer d.wg.Done()
log.Printf("Starting PortForwarding session to instance %q", aws.StringValue(input.Target))
var output *ssm.StartSessionOutput
err := retry.Config{
ShouldRetry: func(err error) bool { return IsAWSErr(err, "TargetNotConnected", "") },
RetryDelay: (&retry.Backoff{InitialBackoff: 200 * time.Millisecond, MaxBackoff: 60 * time.Second, Multiplier: 2}).Linear,
}.Run(ctx, func(ctx context.Context) (err error) {
output, err = d.SvcClient.StartSessionWithContext(ctx, &input)
return err
})
output, err := d.StartSessionWithContext(ctx, input)
if err != nil {
return nil, fmt.Errorf("error encountered in starting session for instance %q: %s", aws.StringValue(input.Target), err)
}
d.retryConnection = make(chan bool, 1)
d.retryAfterTermination = make(chan bool, 1)
// Starts go routine that will keep listening to a retry channel and retry the session creation when needed.
// The log loop will add data to the retry channel whenever a retryable error happens to session.
// The log polling process will add data to the retry channel whenever a retryable error happens to session.
go func(ctx context.Context, driver *SSMDriver, input ssm.StartSessionInput) {
for {
select {
case <-ctx.Done():
return
case <-driver.retryConnection:
case r := <-driver.retryAfterTermination:
if r {
d.wg.Wait()
log.Printf("[DEBUG] Restablishing SSM connection")
_, _ = driver.StartSession(ctx, input)
// Close channels and end goroutine. Another goroutine will start
// and the channels wil be reopened.
close(driver.retryConnection)
close(driver.retryAfterTermination)
return
}
case r := <-driver.retryConnection:
if r {
d.wg.Wait()
log.Printf("[DEBUG] Retrying to establish SSM connection")
_, err := driver.StartSession(ctx, input)
_, err := driver.StartSessionWithContext(ctx, input)
if err != nil {
return
}
}
}
}
}(ctx, d, input)
d.session = output
@ -100,6 +108,20 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu
return d.session, nil
}
func (d *SSMDriver) StartSessionWithContext(ctx context.Context, input ssm.StartSessionInput) (*ssm.StartSessionOutput, error) {
d.wg.Add(1)
defer d.wg.Done()
var output *ssm.StartSessionOutput
err := retry.Config{
ShouldRetry: func(err error) bool { return IsAWSErr(err, "TargetNotConnected", "") },
RetryDelay: (&retry.Backoff{InitialBackoff: 200 * time.Millisecond, MaxBackoff: 60 * time.Second, Multiplier: 2}).Linear,
}.Run(ctx, func(ctx context.Context) (err error) {
output, err = d.SvcClient.StartSessionWithContext(ctx, &input)
return err
})
return output, err
}
func (d *SSMDriver) openTunnelForSession(ctx context.Context) error {
args, err := d.Args()
if err != nil {
@ -163,6 +185,7 @@ func (d *SSMDriver) openTunnelForSession(ctx context.Context) error {
if stdoutCh == nil && stderrCh == nil {
log.Printf("[DEBUG] %s: %s", prefix, "active session has been terminated; stopping all log polling processes.")
d.retryAfterTermination <- true
return
}
}
@ -198,10 +221,6 @@ func (d *SSMDriver) StopSession() error {
d.wg.Add(1)
defer d.wg.Done()
if d.retryConnection != nil {
close(d.retryConnection)
}
if d.session == nil || d.session.SessionId == nil {
return fmt.Errorf("Unable to find a valid session to instance %q; skipping the termination step",
aws.StringValue(d.sessionParams.Target))
@ -211,6 +230,13 @@ func (d *SSMDriver) StopSession() error {
if err != nil {
err = fmt.Errorf("Error terminating SSM Session %q. Please terminate the session manually: %s", aws.StringValue(d.session.SessionId), err)
}
if d.retryConnection != nil {
close(d.retryConnection)
}
if d.retryAfterTermination != nil {
close(d.retryAfterTermination)
}
return err
}