Check for closed channels as opposed to using a separate closeRetry channel

This commit is contained in:
Wilken Rivera 2020-10-07 11:26:27 -04:00
parent aa73cc7d7e
commit eb11009e2a
1 changed files with 11 additions and 13 deletions

View File

@ -39,7 +39,6 @@ type SSMDriver struct {
retryConnection chan bool retryConnection chan bool
retryAfterTermination chan bool retryAfterTermination chan bool
closeRetryPolling chan bool
wg sync.WaitGroup wg sync.WaitGroup
} }
@ -64,7 +63,6 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu
d.retryConnection = make(chan bool, 1) d.retryConnection = make(chan bool, 1)
d.retryAfterTermination = make(chan bool, 1) d.retryAfterTermination = make(chan bool, 1)
d.closeRetryPolling = make(chan bool, 1)
// Starts go routine that will keep listening to channels and retry the session creation/connection when needed. // Starts go routine that will keep listening to channels and retry the session creation/connection when needed.
// The log polling process will add data to the channels whenever a retryable error happens to the session or if it's terminated. // The log polling process will add data to the channels whenever a retryable error happens to the session or if it's terminated.
go func(ctx context.Context, driver *SSMDriver, input ssm.StartSessionInput) { go func(ctx context.Context, driver *SSMDriver, input ssm.StartSessionInput) {
@ -72,11 +70,10 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case ok := <-d.closeRetryPolling: case r, ok := <-driver.retryAfterTermination:
if ok { if !ok {
return return
} }
case r := <-driver.retryAfterTermination:
if r { if r {
d.wg.Wait() d.wg.Wait()
log.Printf("[DEBUG] Restablishing SSM connection") log.Printf("[DEBUG] Restablishing SSM connection")
@ -84,7 +81,11 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu
// End this routine. Another routine will start. // End this routine. Another routine will start.
return return
} }
case r := <-driver.retryConnection: case r, ok := <-driver.retryConnection:
if !ok {
return
}
// Tunnel is still open an we want to try to reconnect // Tunnel is still open an we want to try to reconnect
if r { if r {
d.wg.Wait() d.wg.Wait()
@ -174,9 +175,9 @@ func (d *SSMDriver) openTunnelForSession(ctx context.Context) error {
if isRetryableError(output) { if isRetryableError(output) {
log.Printf("[ERROR] Retryable error - %s: %s", prefix, output) log.Printf("[ERROR] Retryable error - %s: %s", prefix, output)
d.retryConnection <- true d.retryConnection <- true
} else { continue
log.Printf("[ERROR] %s: %s", prefix, output)
} }
log.Printf("[ERROR] %s: %s", prefix, output)
} }
case output, ok := <-stdoutCh: case output, ok := <-stdoutCh:
if !ok { if !ok {
@ -227,9 +228,6 @@ func (d *SSMDriver) StopSession() error {
d.wg.Add(1) d.wg.Add(1)
defer d.wg.Done() defer d.wg.Done()
// Stop retry polling process to avoid unwanted retries at this point
d.closeRetryPolling <- true
if d.session == nil || d.session.SessionId == nil { if d.session == nil || d.session.SessionId == nil {
return fmt.Errorf("Unable to find a valid session to instance %q; skipping the termination step", return fmt.Errorf("Unable to find a valid session to instance %q; skipping the termination step",
aws.StringValue(d.sessionParams.Target)) aws.StringValue(d.sessionParams.Target))
@ -239,10 +237,10 @@ func (d *SSMDriver) StopSession() error {
if err != nil { if err != nil {
err = fmt.Errorf("Error terminating SSM Session %q. Please terminate the session manually: %s", aws.StringValue(d.session.SessionId), err) err = fmt.Errorf("Error terminating SSM Session %q. Please terminate the session manually: %s", aws.StringValue(d.session.SessionId), err)
} }
// Stop retry polling process to avoid unwanted retries at this point
close(d.closeRetryPolling)
close(d.retryAfterTermination) close(d.retryAfterTermination)
close(d.retryConnection) close(d.retryConnection)
return err return err
} }