diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowNodeTask.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowNodeTask.java
index 1ec1ec6b1c1..d20f4cb7d9b 100644
--- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowNodeTask.java
+++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowNodeTask.java
@@ -21,6 +21,7 @@ import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.transport.NetworkExceptionHelper;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.index.seqno.SequenceNumbers;
 import org.elasticsearch.index.shard.IllegalIndexShardStateException;
 import org.elasticsearch.index.shard.ShardId;
@@ -499,7 +500,7 @@ public abstract class ShardFollowNodeTask extends AllocatedPersistentTask {
 
     private void handleFailure(Exception e, AtomicInteger retryCounter, Runnable task) {
         assert e != null;
-        if (shouldRetry(params.getRemoteCluster(), e)) {
+        if (shouldRetry(e)) {
             if (isStopped() == false) {
                 // Only retry is the shard follow task is not stopped.
                 int currentRetry = retryCounter.incrementAndGet();
@@ -528,7 +529,7 @@ public abstract class ShardFollowNodeTask extends AllocatedPersistentTask {
         return Math.min(backOffDelay, maxRetryDelayInMillis);
     }
 
-    static boolean shouldRetry(String remoteCluster, Exception e) {
+    static boolean shouldRetry(final Exception e) {
         if (NetworkExceptionHelper.isConnectException(e)) {
             return true;
         } else if (NetworkExceptionHelper.isCloseConnectionException(e)) {
@@ -546,7 +547,8 @@ public abstract class ShardFollowNodeTask extends AllocatedPersistentTask {
             actual instanceof IndexClosedException || // If follow index is closed
             actual instanceof ConnectTransportException ||
             actual instanceof NodeClosedException ||
-            actual instanceof NoSuchRemoteClusterException;
+            actual instanceof NoSuchRemoteClusterException ||
+            actual instanceof EsRejectedExecutionException;
     }
 
     // These methods are protected for testing purposes:
diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java
index 34883d61c6b..cf6b161bc01 100644
--- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java
+++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java
@@ -512,7 +512,7 @@ public class ShardFollowTasksExecutor extends PersistentTasksExecutor<ShardFollo
                 return;
             }
 
-            if (ShardFollowNodeTask.shouldRetry(params.getRemoteCluster(), e)) {
+            if (ShardFollowNodeTask.shouldRetry(e)) {
                 logger.debug(new ParameterizedMessage("failed to fetch follow shard global {} checkpoint and max sequence number",
                     shardFollowNodeTask), e);
                 threadPool.schedule(() -> nodeOperation(task, params, state), params.getMaxRetryDelay(), Ccr.CCR_THREAD_POOL_NAME);
diff --git a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowNodeTaskTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowNodeTaskTests.java
index ef1dc43869a..7b336e7eeab 100644
--- a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowNodeTaskTests.java
+++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowNodeTaskTests.java
@@ -12,6 +12,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeUnit;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.index.seqno.SequenceNumbers;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardNotFoundException;
@@ -257,8 +258,16 @@ public class ShardFollowNodeTaskTests extends ESTestCase {
         startTask(task, 63, -1);
 
         int max = randomIntBetween(1, 30);
+        final Exception[] exceptions = new Exception[max];
         for (int i = 0; i < max; i++) {
-            readFailures.add(new ShardNotFoundException(new ShardId("leader_index", "", 0)));
+            final Exception exception;
+            if (randomBoolean()) {
+                exception = new ShardNotFoundException(new ShardId("leader_index", "", 0));
+            } else {
+                exception = new EsRejectedExecutionException("leader_index rejected");
+            }
+            exceptions[i] = exception;
+            readFailures.add(exception);
         }
         mappingVersions.add(1L);
         leaderGlobalCheckpoints.add(63L);
@@ -274,10 +283,17 @@ public class ShardFollowNodeTaskTests extends ESTestCase {
                 final Map.Entry<Long, Tuple<Integer, ElasticsearchException>> entry = status.readExceptions().entrySet().iterator().next();
                 assertThat(entry.getValue().v1(), equalTo(Math.toIntExact(retryCounter.get())));
                 assertThat(entry.getKey(), equalTo(0L));
-                assertThat(entry.getValue().v2(), instanceOf(ShardNotFoundException.class));
-                final ShardNotFoundException shardNotFoundException = (ShardNotFoundException) entry.getValue().v2();
-                assertThat(shardNotFoundException.getShardId().getIndexName(), equalTo("leader_index"));
-                assertThat(shardNotFoundException.getShardId().getId(), equalTo(0));
+                if (exceptions[Math.toIntExact(retryCounter.get()) - 1] instanceof ShardNotFoundException) {
+                    assertThat(entry.getValue().v2(), instanceOf(ShardNotFoundException.class));
+                    final ShardNotFoundException shardNotFoundException = (ShardNotFoundException) entry.getValue().v2();
+                    assertThat(shardNotFoundException.getShardId().getIndexName(), equalTo("leader_index"));
+                    assertThat(shardNotFoundException.getShardId().getId(), equalTo(0));
+                } else {
+                    assertThat(entry.getValue().v2().getCause(), instanceOf(EsRejectedExecutionException.class));
+                    final EsRejectedExecutionException rejectedExecutionException =
+                        (EsRejectedExecutionException) entry.getValue().v2().getCause();
+                    assertThat(rejectedExecutionException.getMessage(), equalTo("leader_index rejected"));
+                }
             }
             retryCounter.incrementAndGet();
         };