diff --git a/builder/amazon/common/ssm_driver.go b/builder/amazon/common/ssm_driver.go index 4eb62c931..86c5d0d16 100644 --- a/builder/amazon/common/ssm_driver.go +++ b/builder/amazon/common/ssm_driver.go @@ -39,6 +39,7 @@ type SSMDriver struct { retryConnection chan bool retryAfterTermination chan bool + closeRetryPolling chan bool wg sync.WaitGroup } @@ -63,25 +64,28 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu d.retryConnection = make(chan bool, 1) d.retryAfterTermination = make(chan bool, 1) - // Starts go routine that will keep listening to a retry channel and retry the session creation when needed. - // The log polling process will add data to the retry channel whenever a retryable error happens to session. + d.closeRetryPolling = make(chan bool, 1) + // Starts go routine that will keep listening to channels and retry the session creation/connection when needed. + // The log polling process will add data to the channels whenever a retryable error happens to the session or if it's terminated. go func(ctx context.Context, driver *SSMDriver, input ssm.StartSessionInput) { for { select { case <-ctx.Done(): return + case ok := <-d.closeRetryPolling: + if ok { + return + } case r := <-driver.retryAfterTermination: if r { d.wg.Wait() log.Printf("[DEBUG] Restablishing SSM connection") _, _ = driver.StartSession(ctx, input) - // Close channels and end goroutine. Another goroutine will start - // and the channels wil be reopened. - close(driver.retryConnection) - close(driver.retryAfterTermination) + // End this routine. Another routine will start. return } case r := <-driver.retryConnection: + // Tunnel is still open an we want to try to reconnect if r { d.wg.Wait() log.Printf("[DEBUG] Retrying to establish SSM connection") @@ -111,6 +115,7 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu func (d *SSMDriver) StartSessionWithContext(ctx context.Context, input ssm.StartSessionInput) (*ssm.StartSessionOutput, error) { d.wg.Add(1) defer d.wg.Done() + var output *ssm.StartSessionOutput err := retry.Config{ ShouldRetry: func(err error) bool { return IsAWSErr(err, "TargetNotConnected", "") }, @@ -119,6 +124,7 @@ func (d *SSMDriver) StartSessionWithContext(ctx context.Context, input ssm.Start output, err = d.SvcClient.StartSessionWithContext(ctx, &input) return err }) + return output, err } @@ -221,6 +227,9 @@ func (d *SSMDriver) StopSession() error { d.wg.Add(1) defer d.wg.Done() + // Stop retry polling process to avoid unwanted retries at this point + d.closeRetryPolling <- true + if d.session == nil || d.session.SessionId == nil { return fmt.Errorf("Unable to find a valid session to instance %q; skipping the termination step", aws.StringValue(d.sessionParams.Target)) @@ -231,12 +240,9 @@ func (d *SSMDriver) StopSession() error { err = fmt.Errorf("Error terminating SSM Session %q. Please terminate the session manually: %s", aws.StringValue(d.session.SessionId), err) } - if d.retryConnection != nil { - close(d.retryConnection) - } - if d.retryAfterTermination != nil { - close(d.retryAfterTermination) - } + close(d.closeRetryPolling) + close(d.retryAfterTermination) + close(d.retryConnection) return err }