add close chan to avoid unwanted retries

This commit is contained in:
sylviamoss 2020-09-29 16:15:47 +02:00 committed by Wilken Rivera
parent 1f62249097
commit aa73cc7d7e
1 changed files with 18 additions and 12 deletions

View File

@ -39,6 +39,7 @@ type SSMDriver struct {
retryConnection chan bool retryConnection chan bool
retryAfterTermination chan bool retryAfterTermination chan bool
closeRetryPolling chan bool
wg sync.WaitGroup wg sync.WaitGroup
} }
@ -63,25 +64,28 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu
d.retryConnection = make(chan bool, 1) d.retryConnection = make(chan bool, 1)
d.retryAfterTermination = make(chan bool, 1) d.retryAfterTermination = make(chan bool, 1)
// Starts go routine that will keep listening to a retry channel and retry the session creation when needed. d.closeRetryPolling = make(chan bool, 1)
// The log polling process will add data to the retry channel whenever a retryable error happens to session. // Starts go routine that will keep listening to channels and retry the session creation/connection when needed.
// The log polling process will add data to the channels whenever a retryable error happens to the session or if it's terminated.
go func(ctx context.Context, driver *SSMDriver, input ssm.StartSessionInput) { go func(ctx context.Context, driver *SSMDriver, input ssm.StartSessionInput) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case ok := <-d.closeRetryPolling:
if ok {
return
}
case r := <-driver.retryAfterTermination: case r := <-driver.retryAfterTermination:
if r { if r {
d.wg.Wait() d.wg.Wait()
log.Printf("[DEBUG] Restablishing SSM connection") log.Printf("[DEBUG] Restablishing SSM connection")
_, _ = driver.StartSession(ctx, input) _, _ = driver.StartSession(ctx, input)
// Close channels and end goroutine. Another goroutine will start // End this routine. Another routine will start.
// and the channels wil be reopened.
close(driver.retryConnection)
close(driver.retryAfterTermination)
return return
} }
case r := <-driver.retryConnection: case r := <-driver.retryConnection:
// Tunnel is still open an we want to try to reconnect
if r { if r {
d.wg.Wait() d.wg.Wait()
log.Printf("[DEBUG] Retrying to establish SSM connection") log.Printf("[DEBUG] Retrying to establish SSM connection")
@ -111,6 +115,7 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu
func (d *SSMDriver) StartSessionWithContext(ctx context.Context, input ssm.StartSessionInput) (*ssm.StartSessionOutput, error) { func (d *SSMDriver) StartSessionWithContext(ctx context.Context, input ssm.StartSessionInput) (*ssm.StartSessionOutput, error) {
d.wg.Add(1) d.wg.Add(1)
defer d.wg.Done() defer d.wg.Done()
var output *ssm.StartSessionOutput var output *ssm.StartSessionOutput
err := retry.Config{ err := retry.Config{
ShouldRetry: func(err error) bool { return IsAWSErr(err, "TargetNotConnected", "") }, ShouldRetry: func(err error) bool { return IsAWSErr(err, "TargetNotConnected", "") },
@ -119,6 +124,7 @@ func (d *SSMDriver) StartSessionWithContext(ctx context.Context, input ssm.Start
output, err = d.SvcClient.StartSessionWithContext(ctx, &input) output, err = d.SvcClient.StartSessionWithContext(ctx, &input)
return err return err
}) })
return output, err return output, err
} }
@ -221,6 +227,9 @@ func (d *SSMDriver) StopSession() error {
d.wg.Add(1) d.wg.Add(1)
defer d.wg.Done() defer d.wg.Done()
// Stop retry polling process to avoid unwanted retries at this point
d.closeRetryPolling <- true
if d.session == nil || d.session.SessionId == nil { if d.session == nil || d.session.SessionId == nil {
return fmt.Errorf("Unable to find a valid session to instance %q; skipping the termination step", return fmt.Errorf("Unable to find a valid session to instance %q; skipping the termination step",
aws.StringValue(d.sessionParams.Target)) aws.StringValue(d.sessionParams.Target))
@ -231,12 +240,9 @@ func (d *SSMDriver) StopSession() error {
err = fmt.Errorf("Error terminating SSM Session %q. Please terminate the session manually: %s", aws.StringValue(d.session.SessionId), err) err = fmt.Errorf("Error terminating SSM Session %q. Please terminate the session manually: %s", aws.StringValue(d.session.SessionId), err)
} }
if d.retryConnection != nil { close(d.closeRetryPolling)
close(d.retryConnection)
}
if d.retryAfterTermination != nil {
close(d.retryAfterTermination) close(d.retryAfterTermination)
} close(d.retryConnection)
return err return err
} }