From 87076c32e219bd501be4a0dc1968fe7a4fe7b537 Mon Sep 17 00:00:00 2001
From: Tanguy Leroux <tlrx.dev@gmail.com>
Date: Tue, 6 Oct 2020 18:37:05 +0200
Subject: [PATCH] Determine shard size before allocating shards recovering from
 snapshots (#61906) (#63337)

Determines the shard size of shards before allocating shards that are
recovering from snapshots. It ensures during shard allocation that the
target node that is selected as recovery target will have enough free
disk space for the recovery event. This applies to regular restores,
CCR bootstrap from remote, as well as mounting searchable snapshots.

The InternalSnapshotInfoService is responsible for fetching snapshot
shard sizes from repositories. It provides a getShardSize() method
to other components of the system that can be used to retrieve the
latest known shard size. If the latest snapshot shard size retrieval
failed, the getShardSize() returns
ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE. While
we'd like a better way to handle such failures, returning this value
allows to keep the existing behavior for now.

Note that this PR does not address an issues (we already have today)
where a replica is being allocated without knowing how much disk
space is being used by the primary.

Co-authored-by: Yannick Welsch <yannick@welsch.lu>
---
 .../routing/allocation/Allocators.java        |   4 +-
 .../decider/DiskThresholdDeciderIT.java       |  92 ++++-
 ...ansportClusterAllocationExplainAction.java |   8 +-
 .../elasticsearch/cluster/ClusterInfo.java    |   2 -
 .../elasticsearch/cluster/ClusterModule.java  |   5 +-
 .../routing/allocation/AllocationService.java |  22 +-
 .../routing/allocation/RoutingAllocation.java |  10 +-
 .../allocator/BalancedShardsAllocator.java    |   5 +-
 .../decider/DiskThresholdDecider.java         |  19 +-
 .../common/settings/ClusterSettings.java      |   2 +
 .../gateway/BaseGatewayShardAllocator.java    |  16 +-
 .../gateway/PrimaryShardAllocator.java        |  11 +
 .../java/org/elasticsearch/node/Node.java     |  15 +-
 .../snapshots/EmptySnapshotsInfoService.java  |  31 ++
 .../InternalSnapshotsInfoService.java         | 386 ++++++++++++++++++
 .../snapshots/SnapshotShardSizeInfo.java      |  53 +++
 .../snapshots/SnapshotsInfoService.java       |  25 ++
 .../ClusterAllocationExplainActionTests.java  |   6 +-
 .../cluster/reroute/ClusterRerouteTests.java  |   4 +-
 .../shrink/TransportResizeActionTests.java    |  13 +-
 .../cluster/ClusterModuleTests.java           |  13 +-
 .../health/ClusterStateHealthTests.java       |   2 +-
 .../MetadataCreateIndexServiceTests.java      |  10 +-
 .../allocation/AllocationCommandsTests.java   |   5 +-
 .../allocation/AllocationServiceTests.java    |   7 +-
 .../allocation/BalanceConfigurationTests.java |   3 +-
 .../allocation/BalancedSingleShardTests.java  |   3 +-
 .../DecisionsImpactOnClusterHealthTests.java  |   4 +-
 .../MaxRetryAllocationDeciderTests.java       |  12 +-
 .../NodeVersionAllocationDeciderTests.java    |  35 +-
 .../RandomAllocationDeciderTests.java         |   4 +-
 .../ResizeAllocationDeciderTests.java         |  12 +-
 .../allocation/SameShardRoutingTests.java     |   3 +-
 .../allocation/ThrottlingAllocationTests.java |  57 ++-
 .../decider/AllocationDecidersTests.java      |   2 +-
 .../decider/DiskThresholdDeciderTests.java    |  39 +-
 .../DiskThresholdDeciderUnitTests.java        |  14 +-
 .../EnableAllocationShortCircuitTests.java    |   4 +-
 .../decider/FilterAllocationDeciderTests.java |   8 +-
 ...storeInProgressAllocationDeciderTests.java |   2 +-
 .../gateway/GatewayServiceTests.java          |   4 +-
 .../gateway/PrimaryShardAllocatorTests.java   |  44 +-
 .../gateway/ReplicaShardAllocatorTests.java   |   7 +-
 .../indices/cluster/ClusterStateChanges.java  |   3 +-
 .../InternalSnapshotsInfoServiceTests.java    | 350 ++++++++++++++++
 .../snapshots/SnapshotResiliencyTests.java    |  21 +-
 .../cluster/ESAllocationTestCase.java         |  41 +-
 .../MockInternalClusterInfoService.java       |   3 +-
 .../xpack/ccr/repository/CcrRepository.java   |  23 +-
 ...PrimaryFollowerAllocationDeciderTests.java |   3 +-
 .../xpack/core/ilm/AllocationRoutedStep.java  |   2 +-
 .../core/ilm/SetSingleNodeAllocateStep.java   |   2 +-
 .../DataTierAllocationDeciderTests.java       |  32 +-
 .../SearchableSnapshotAllocator.java          |   6 +
 54 files changed, 1329 insertions(+), 180 deletions(-)
 create mode 100644 server/src/main/java/org/elasticsearch/snapshots/EmptySnapshotsInfoService.java
 create mode 100644 server/src/main/java/org/elasticsearch/snapshots/InternalSnapshotsInfoService.java
 create mode 100644 server/src/main/java/org/elasticsearch/snapshots/SnapshotShardSizeInfo.java
 create mode 100644 server/src/main/java/org/elasticsearch/snapshots/SnapshotsInfoService.java
 create mode 100644 server/src/test/java/org/elasticsearch/snapshots/InternalSnapshotsInfoServiceTests.java

diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/routing/allocation/Allocators.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/routing/allocation/Allocators.java
index 6e994e3fbf7..284f01b58a3 100644
--- a/benchmarks/src/main/java/org/elasticsearch/benchmark/routing/allocation/Allocators.java
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/routing/allocation/Allocators.java
@@ -35,6 +35,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.gateway.GatewayAllocator;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 
 import java.util.Collection;
 import java.util.Collections;
@@ -79,7 +80,8 @@ public final class Allocators {
             defaultAllocationDeciders(settings, clusterSettings),
             NoopGatewayAllocator.INSTANCE,
             new BalancedShardsAllocator(settings),
-            EmptyClusterInfoService.INSTANCE
+            EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE
         );
     }
 
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderIT.java
index ce857b1bfb1..146eec61551 100644
--- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderIT.java
+++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderIT.java
@@ -23,6 +23,8 @@ import org.apache.lucene.mockfile.FilterFileStore;
 import org.apache.lucene.mockfile.FilterFileSystemProvider;
 import org.apache.lucene.mockfile.FilterPath;
 import org.apache.lucene.util.Constants;
+import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotResponse;
+import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotResponse;
 import org.elasticsearch.action.admin.indices.stats.ShardStats;
 import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.cluster.ClusterInfoService;
@@ -32,6 +34,7 @@ import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.allocation.DiskThresholdSettings;
+import org.elasticsearch.cluster.routing.allocation.decider.EnableAllocationDecider.Rebalance;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Priority;
 import org.elasticsearch.common.io.PathUtils;
@@ -44,6 +47,10 @@ import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.monitor.fs.FsService;
 import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.repositories.fs.FsRepository;
+import org.elasticsearch.snapshots.RestoreInfo;
+import org.elasticsearch.snapshots.SnapshotInfo;
+import org.elasticsearch.snapshots.SnapshotState;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.InternalSettingsPlugin;
 import org.junit.After;
@@ -62,6 +69,7 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -141,29 +149,95 @@ public class DiskThresholdDeciderIT extends ESIntegTestCase {
         final String dataNode0Id = internalCluster().getInstance(NodeEnvironment.class, dataNodeName).nodeId();
         final Path dataNode0Path = internalCluster().getInstance(Environment.class, dataNodeName).dataFiles()[0];
 
-        createIndex("test", Settings.builder()
+        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        createIndex(indexName, Settings.builder()
                 .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
                 .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 6)
                 .put(INDEX_STORE_STATS_REFRESH_INTERVAL_SETTING.getKey(), "0ms")
                 .build());
-        final long minShardSize = createReasonableSizedShards();
+        final long minShardSize = createReasonableSizedShards(indexName);
 
         // reduce disk size of node 0 so that no shards fit below the high watermark, forcing all shards onto the other data node
         // (subtract the translog size since the disk threshold decider ignores this and may therefore move the shard back again)
         fileSystemProvider.getTestFileStore(dataNode0Path).setTotalSpace(minShardSize + WATERMARK_BYTES - 1L);
         refreshDiskUsage();
-        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id), empty()));
+        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id, indexName), empty()));
 
         // increase disk size of node 0 to allow just enough room for one shard, and check that it's rebalanced back
         fileSystemProvider.getTestFileStore(dataNode0Path).setTotalSpace(minShardSize + WATERMARK_BYTES + 1L);
         refreshDiskUsage();
-        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id), hasSize(1)));
+        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id, indexName), hasSize(1)));
     }
 
-    private Set<ShardRouting> getShardRoutings(String nodeId) {
+    public void testRestoreSnapshotAllocationDoesNotExceedWatermark() throws Exception {
+        internalCluster().startMasterOnlyNode();
+        internalCluster().startDataOnlyNode();
+        final String dataNodeName = internalCluster().startDataOnlyNode();
+        ensureStableCluster(3);
+
+        assertAcked(client().admin().cluster().preparePutRepository("repo")
+            .setType(FsRepository.TYPE)
+            .setSettings(Settings.builder()
+                .put("location", randomRepoPath())
+                .put("compress", randomBoolean())));
+
+        final InternalClusterInfoService clusterInfoService
+            = (InternalClusterInfoService) internalCluster().getMasterNodeInstance(ClusterInfoService.class);
+        internalCluster().getMasterNodeInstance(ClusterService.class).addListener(event -> clusterInfoService.refresh());
+
+        final String dataNode0Id = internalCluster().getInstance(NodeEnvironment.class, dataNodeName).nodeId();
+        final Path dataNode0Path = internalCluster().getInstance(Environment.class, dataNodeName).dataFiles()[0];
+
+        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        createIndex(indexName, Settings.builder()
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 6)
+            .put(INDEX_STORE_STATS_REFRESH_INTERVAL_SETTING.getKey(), "0ms")
+            .build());
+        final long minShardSize = createReasonableSizedShards(indexName);
+
+        final CreateSnapshotResponse createSnapshotResponse = client().admin().cluster().prepareCreateSnapshot("repo", "snap")
+            .setWaitForCompletion(true).get();
+        final SnapshotInfo snapshotInfo = createSnapshotResponse.getSnapshotInfo();
+        assertThat(snapshotInfo.successfulShards(), is(snapshotInfo.totalShards()));
+        assertThat(snapshotInfo.state(), is(SnapshotState.SUCCESS));
+
+        assertAcked(client().admin().indices().prepareDelete(indexName).get());
+
+        // reduce disk size of node 0 so that no shards fit below the low watermark, forcing shards to be assigned to the other data node
+        fileSystemProvider.getTestFileStore(dataNode0Path).setTotalSpace(minShardSize + WATERMARK_BYTES - 1L);
+        refreshDiskUsage();
+
+        assertAcked(client().admin().cluster().prepareUpdateSettings()
+            .setTransientSettings(Settings.builder()
+                .put(EnableAllocationDecider.CLUSTER_ROUTING_REBALANCE_ENABLE_SETTING.getKey(), Rebalance.NONE.toString())
+                .build())
+            .get());
+
+        final RestoreSnapshotResponse restoreSnapshotResponse = client().admin().cluster().prepareRestoreSnapshot("repo", "snap")
+            .setWaitForCompletion(true).get();
+        final RestoreInfo restoreInfo = restoreSnapshotResponse.getRestoreInfo();
+        assertThat(restoreInfo.successfulShards(), is(snapshotInfo.totalShards()));
+        assertThat(restoreInfo.failedShards(), is(0));
+
+        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id, indexName), empty()));
+
+        assertAcked(client().admin().cluster().prepareUpdateSettings()
+            .setTransientSettings(Settings.builder()
+                .putNull(EnableAllocationDecider.CLUSTER_ROUTING_REBALANCE_ENABLE_SETTING.getKey())
+                .build())
+            .get());
+
+        // increase disk size of node 0 to allow just enough room for one shard, and check that it's rebalanced back
+        fileSystemProvider.getTestFileStore(dataNode0Path).setTotalSpace(minShardSize + WATERMARK_BYTES + 1L);
+        refreshDiskUsage();
+        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id, indexName), hasSize(1)));
+    }
+
+    private Set<ShardRouting> getShardRoutings(final String nodeId, final String indexName) {
         final Set<ShardRouting> shardRoutings = new HashSet<>();
         for (IndexShardRoutingTable indexShardRoutingTable : client().admin().cluster().prepareState().clear().setRoutingTable(true)
-                .get().getState().getRoutingTable().index("test")) {
+                .get().getState().getRoutingTable().index(indexName)) {
             for (ShardRouting shard : indexShardRoutingTable.shards()) {
                 assertThat(shard.state(), equalTo(ShardRoutingState.STARTED));
                 if (shard.currentNodeId().equals(nodeId)) {
@@ -177,17 +251,17 @@ public class DiskThresholdDeciderIT extends ESIntegTestCase {
     /**
      * Index documents until all the shards are at least WATERMARK_BYTES in size, and return the size of the smallest shard
      */
-    private long createReasonableSizedShards() throws InterruptedException {
+    private long createReasonableSizedShards(final String indexName) throws InterruptedException {
         while (true) {
             final IndexRequestBuilder[] indexRequestBuilders = new IndexRequestBuilder[scaledRandomIntBetween(100, 10000)];
             for (int i = 0; i < indexRequestBuilders.length; i++) {
-                indexRequestBuilders[i] = client().prepareIndex("test", "_doc").setSource("field", randomAlphaOfLength(10));
+                indexRequestBuilders[i] = client().prepareIndex(indexName, "_doc").setSource("field", randomAlphaOfLength(10));
             }
             indexRandom(true, indexRequestBuilders);
             forceMerge();
             refresh();
 
-            final ShardStats[] shardStatses = client().admin().indices().prepareStats("test")
+            final ShardStats[] shardStatses = client().admin().indices().prepareStats(indexName)
                     .clear().setStore(true).setTranslog(true).get().getShards();
             final long[] shardSizes = new long[shardStatses.length];
             for (ShardStats shardStats : shardStatses) {
diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportClusterAllocationExplainAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportClusterAllocationExplainAction.java
index faf239a15b7..008044534bb 100644
--- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportClusterAllocationExplainAction.java
+++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportClusterAllocationExplainAction.java
@@ -42,6 +42,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 
@@ -58,6 +59,7 @@ public class TransportClusterAllocationExplainAction
     private static final Logger logger = LogManager.getLogger(TransportClusterAllocationExplainAction.class);
 
     private final ClusterInfoService clusterInfoService;
+    private final SnapshotsInfoService snapshotsInfoService;
     private final AllocationDeciders allocationDeciders;
     private final ShardsAllocator shardAllocator;
     private final AllocationService allocationService;
@@ -66,11 +68,13 @@ public class TransportClusterAllocationExplainAction
     public TransportClusterAllocationExplainAction(TransportService transportService, ClusterService clusterService,
                                                    ThreadPool threadPool, ActionFilters actionFilters,
                                                    IndexNameExpressionResolver indexNameExpressionResolver,
-                                                   ClusterInfoService clusterInfoService, AllocationDeciders allocationDeciders,
+                                                   ClusterInfoService clusterInfoService, SnapshotsInfoService snapshotsInfoService,
+                                                   AllocationDeciders allocationDeciders,
                                                    ShardsAllocator shardAllocator, AllocationService allocationService) {
         super(ClusterAllocationExplainAction.NAME, transportService, clusterService, threadPool, actionFilters,
             ClusterAllocationExplainRequest::new, indexNameExpressionResolver);
         this.clusterInfoService = clusterInfoService;
+        this.snapshotsInfoService = snapshotsInfoService;
         this.allocationDeciders = allocationDeciders;
         this.shardAllocator = shardAllocator;
         this.allocationService = allocationService;
@@ -97,7 +101,7 @@ public class TransportClusterAllocationExplainAction
         final RoutingNodes routingNodes = state.getRoutingNodes();
         final ClusterInfo clusterInfo = clusterInfoService.getClusterInfo();
         final RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, state,
-                clusterInfo, System.nanoTime());
+                clusterInfo, snapshotsInfoService.snapshotShardSizes(), System.nanoTime());
 
         ShardRouting shardRouting = findShardToExplain(request, allocation);
         logger.debug("explaining the allocation for [{}], found shard [{}]", request, shardRouting);
diff --git a/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java b/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java
index 59159c6f788..1feba57197f 100644
--- a/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java
+++ b/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java
@@ -22,7 +22,6 @@ package org.elasticsearch.cluster;
 import com.carrotsearch.hppc.ObjectHashSet;
 import com.carrotsearch.hppc.cursors.ObjectCursor;
 import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
-
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -33,7 +32,6 @@ import org.elasticsearch.common.xcontent.ToXContentFragment;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.store.StoreStats;
-
 import java.io.IOException;
 import java.util.Map;
 import java.util.Objects;
diff --git a/server/src/main/java/org/elasticsearch/cluster/ClusterModule.java b/server/src/main/java/org/elasticsearch/cluster/ClusterModule.java
index b043fc17b55..556948bb01e 100644
--- a/server/src/main/java/org/elasticsearch/cluster/ClusterModule.java
+++ b/server/src/main/java/org/elasticsearch/cluster/ClusterModule.java
@@ -76,6 +76,7 @@ import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.persistent.PersistentTasksNodeService;
 import org.elasticsearch.plugins.ClusterPlugin;
 import org.elasticsearch.script.ScriptMetadata;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskResultsService;
 
@@ -110,14 +111,14 @@ public class ClusterModule extends AbstractModule {
     final ShardsAllocator shardsAllocator;
 
     public ClusterModule(Settings settings, ClusterService clusterService, List<ClusterPlugin> clusterPlugins,
-                         ClusterInfoService clusterInfoService) {
+                         ClusterInfoService clusterInfoService, SnapshotsInfoService snapshotsInfoService) {
         this.clusterPlugins = clusterPlugins;
         this.deciderList = createAllocationDeciders(settings, clusterService.getClusterSettings(), clusterPlugins);
         this.allocationDeciders = new AllocationDeciders(deciderList);
         this.shardsAllocator = createShardsAllocator(settings, clusterService.getClusterSettings(), clusterPlugins);
         this.clusterService = clusterService;
         this.indexNameExpressionResolver = new IndexNameExpressionResolver();
-        this.allocationService = new AllocationService(allocationDeciders, shardsAllocator, clusterInfoService);
+        this.allocationService = new AllocationService(allocationDeciders, shardsAllocator, clusterInfoService, snapshotsInfoService);
     }
 
     public static List<Entry> getNamedWriteables() {
diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java
index 5ae5b9f42c5..65c88447ebd 100644
--- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java
+++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java
@@ -44,6 +44,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.gateway.GatewayAllocator;
 import org.elasticsearch.gateway.PriorityComparator;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -74,19 +75,22 @@ public class AllocationService {
     private Map<String, ExistingShardsAllocator> existingShardsAllocators;
     private final ShardsAllocator shardsAllocator;
     private final ClusterInfoService clusterInfoService;
+    private SnapshotsInfoService snapshotsInfoService;
 
     // only for tests that use the GatewayAllocator as the unique ExistingShardsAllocator
     public AllocationService(AllocationDeciders allocationDeciders, GatewayAllocator gatewayAllocator,
-                             ShardsAllocator shardsAllocator, ClusterInfoService clusterInfoService) {
-        this(allocationDeciders, shardsAllocator, clusterInfoService);
+                             ShardsAllocator shardsAllocator, ClusterInfoService clusterInfoService,
+                             SnapshotsInfoService snapshotsInfoService) {
+        this(allocationDeciders, shardsAllocator, clusterInfoService, snapshotsInfoService);
         setExistingShardsAllocators(Collections.singletonMap(GatewayAllocator.ALLOCATOR_NAME, gatewayAllocator));
     }
 
     public AllocationService(AllocationDeciders allocationDeciders, ShardsAllocator shardsAllocator,
-                             ClusterInfoService clusterInfoService) {
+                             ClusterInfoService clusterInfoService, SnapshotsInfoService snapshotsInfoService) {
         this.allocationDeciders = allocationDeciders;
         this.shardsAllocator = shardsAllocator;
         this.clusterInfoService = clusterInfoService;
+        this.snapshotsInfoService = snapshotsInfoService;
     }
 
     /**
@@ -113,7 +117,7 @@ public class AllocationService {
         // shuffle the unassigned nodes, just so we won't have things like poison failed shards
         routingNodes.unassigned().shuffle();
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, clusterState,
-            clusterInfoService.getClusterInfo(), currentNanoTime());
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime());
         // as starting a primary relocation target can reinitialize replica shards, start replicas first
         startedShards = new ArrayList<>(startedShards);
         startedShards.sort(Comparator.comparing(ShardRouting::primary));
@@ -192,7 +196,7 @@ public class AllocationService {
         routingNodes.unassigned().shuffle();
         long currentNanoTime = currentNanoTime();
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, tmpState,
-            clusterInfoService.getClusterInfo(), currentNanoTime);
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime);
 
         for (FailedShard failedShardEntry : failedShards) {
             ShardRouting shardToFail = failedShardEntry.getRoutingEntry();
@@ -246,7 +250,7 @@ public class AllocationService {
         // shuffle the unassigned nodes, just so we won't have things like poison failed shards
         routingNodes.unassigned().shuffle();
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, clusterState,
-            clusterInfoService.getClusterInfo(), currentNanoTime());
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime());
 
         // first, clear from the shards any node id they used to belong to that is now dead
         disassociateDeadNodes(allocation);
@@ -267,7 +271,7 @@ public class AllocationService {
      */
     public ClusterState adaptAutoExpandReplicas(ClusterState clusterState) {
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, clusterState.getRoutingNodes(), clusterState,
-            clusterInfoService.getClusterInfo(), currentNanoTime());
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime());
         final Map<Integer, List<String>> autoExpandReplicaChanges =
             AutoExpandReplicas.getAutoExpandReplicaChanges(clusterState.metadata(), allocation);
         if (autoExpandReplicaChanges.isEmpty()) {
@@ -361,7 +365,7 @@ public class AllocationService {
         // a consistent result of the effect the commands have on the routing
         // this allows systems to dry run the commands, see the resulting cluster state, and act on it
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, clusterState,
-            clusterInfoService.getClusterInfo(), currentNanoTime());
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime());
         // don't short circuit deciders, we want a full explanation
         allocation.debugDecision(true);
         // we ignore disable allocation, because commands are explicit
@@ -392,7 +396,7 @@ public class AllocationService {
         // shuffle the unassigned nodes, just so we won't have things like poison failed shards
         routingNodes.unassigned().shuffle();
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, fixedClusterState,
-            clusterInfoService.getClusterInfo(), currentNanoTime());
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime());
         reroute(allocation);
         if (fixedClusterState == clusterState && allocation.routingNodesChanged() == false) {
             return clusterState;
diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java
index 7ee91637b93..b232fed634f 100644
--- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java
+++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java
@@ -33,6 +33,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.snapshots.RestoreService.RestoreInProgressUpdater;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 import java.util.HashMap;
 import java.util.HashSet;
@@ -63,6 +64,8 @@ public class RoutingAllocation {
 
     private final ClusterInfo clusterInfo;
 
+    private final SnapshotShardSizeInfo shardSizeInfo;
+
     private Map<ShardId, Set<String>> ignoredShardToNodes = null;
 
     private boolean ignoreDisable = false;
@@ -89,7 +92,7 @@ public class RoutingAllocation {
      * @param currentNanoTime the nano time to use for all delay allocation calculation (typically {@link System#nanoTime()})
      */
     public RoutingAllocation(AllocationDeciders deciders, RoutingNodes routingNodes, ClusterState clusterState, ClusterInfo clusterInfo,
-                             long currentNanoTime) {
+                             SnapshotShardSizeInfo shardSizeInfo, long currentNanoTime) {
         this.deciders = deciders;
         this.routingNodes = routingNodes;
         this.metadata = clusterState.metadata();
@@ -97,6 +100,7 @@ public class RoutingAllocation {
         this.nodes = clusterState.nodes();
         this.customs = clusterState.customs();
         this.clusterInfo = clusterInfo;
+        this.shardSizeInfo = shardSizeInfo;
         this.currentNanoTime = currentNanoTime;
     }
 
@@ -149,6 +153,10 @@ public class RoutingAllocation {
         return clusterInfo;
     }
 
+    public SnapshotShardSizeInfo snapshotShardSizeInfo() {
+        return shardSizeInfo;
+    }
+
     public <T extends ClusterState.Custom> T custom(String key) {
         return (T)customs.get(key);
     }
diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java
index 586efd95678..0da8f8f976c 100644
--- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java
+++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java
@@ -810,7 +810,7 @@ public class BalancedShardsAllocator implements ShardsAllocator {
 
                         final long shardSize = DiskThresholdDecider.getExpectedShardSize(shard,
                             ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE,
-                            allocation.clusterInfo(), allocation.metadata(), allocation.routingTable());
+                            allocation.clusterInfo(), allocation.snapshotShardSizeInfo(), allocation.metadata(), allocation.routingTable());
                         shard = routingNodes.initializeShard(shard, minNode.getNodeId(), null, shardSize, allocation.changes());
                         minNode.addShard(shard);
                         if (!shard.primary()) {
@@ -832,7 +832,8 @@ public class BalancedShardsAllocator implements ShardsAllocator {
                             assert allocationDecision.getAllocationStatus() == AllocationStatus.DECIDERS_THROTTLED;
                             final long shardSize = DiskThresholdDecider.getExpectedShardSize(shard,
                                 ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE,
-                                allocation.clusterInfo(), allocation.metadata(), allocation.routingTable());
+                                allocation.clusterInfo(), allocation.snapshotShardSizeInfo(), allocation.metadata(),
+                                allocation.routingTable());
                             minNode.addShard(shard.initialize(minNode.getNodeId(), null, shardSize));
                         } else {
                             if (logger.isTraceEnabled()) {
diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDecider.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDecider.java
index 3d6c2d433c8..965f38d87ed 100644
--- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDecider.java
+++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDecider.java
@@ -43,6 +43,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 import java.util.List;
 import java.util.Set;
@@ -120,7 +121,7 @@ public class DiskThresholdDecider extends AllocationDecider {
             // if we don't yet know the actual path of the incoming shard then conservatively assume it's going to the path with the least
             // free space
             if (actualPath == null || actualPath.equals(dataPath)) {
-                totalSize += getExpectedShardSize(routing, 0L, clusterInfo, metadata, routingTable);
+                totalSize += getExpectedShardSize(routing, 0L, clusterInfo, null, metadata, routingTable);
             }
         }
 
@@ -132,7 +133,7 @@ public class DiskThresholdDecider extends AllocationDecider {
                     actualPath = clusterInfo.getDataPath(routing.cancelRelocation());
                 }
                 if (dataPath.equals(actualPath)) {
-                    totalSize -= getExpectedShardSize(routing, 0L, clusterInfo, metadata, routingTable);
+                    totalSize -= getExpectedShardSize(routing, 0L, clusterInfo, null, metadata, routingTable);
                 }
             }
         }
@@ -153,7 +154,7 @@ public class DiskThresholdDecider extends AllocationDecider {
         final double usedDiskThresholdLow = 100.0 - diskThresholdSettings.getFreeDiskThresholdLow();
         final double usedDiskThresholdHigh = 100.0 - diskThresholdSettings.getFreeDiskThresholdHigh();
 
-        // subtractLeavingShards is passed as false here, because they still use disk space, and therefore should we should be extra careful
+        // subtractLeavingShards is passed as false here, because they still use disk space, and therefore we should be extra careful
         // and take the size into account
         final DiskUsageWithRelocations usage = getDiskUsage(node, allocation, usages, false);
         // First, check that the node currently over the low watermark
@@ -270,7 +271,7 @@ public class DiskThresholdDecider extends AllocationDecider {
 
         // Secondly, check that allocating the shard to this node doesn't put it above the high watermark
         final long shardSize = getExpectedShardSize(shardRouting, 0L,
-            allocation.clusterInfo(), allocation.metadata(), allocation.routingTable());
+            allocation.clusterInfo(), allocation.snapshotShardSizeInfo(), allocation.metadata(), allocation.routingTable());
         assert shardSize >= 0 : shardSize;
         double freeSpaceAfterShard = freeDiskPercentageAfterShardAssigned(usage, shardSize);
         long freeBytesAfterShard = freeBytes - shardSize;
@@ -466,8 +467,9 @@ public class DiskThresholdDecider extends AllocationDecider {
      * Returns the expected shard size for the given shard or the default value provided if not enough information are available
      * to estimate the shards size.
      */
-    public static long getExpectedShardSize(ShardRouting shard, long defaultValue, ClusterInfo clusterInfo, Metadata metadata,
-                                            RoutingTable routingTable) {
+    public static long getExpectedShardSize(ShardRouting shard, long defaultValue, ClusterInfo clusterInfo,
+                                            SnapshotShardSizeInfo snapshotShardSizeInfo,
+                                            Metadata metadata, RoutingTable routingTable) {
         final IndexMetadata indexMetadata = metadata.getIndexSafe(shard.index());
         if (indexMetadata.getResizeSourceIndex() != null && shard.active() == false &&
             shard.recoverySource().getType() == RecoverySource.Type.LOCAL_SHARDS) {
@@ -487,6 +489,11 @@ public class DiskThresholdDecider extends AllocationDecider {
             }
             return targetShardSize == 0 ? defaultValue : targetShardSize;
         } else {
+            if (shard.unassigned() && shard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
+                final Long shardSize = snapshotShardSizeInfo.getShardSize(shard);
+                assert shardSize != null : "no shard size provided for " + shard;
+                return shardSize;
+            }
             return clusterInfo.getShardSize(shard, defaultValue);
         }
     }
diff --git a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java
index a3cdd406850..a27b6b05223 100644
--- a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java
+++ b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java
@@ -116,6 +116,7 @@ import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.aggregations.MultiBucketConsumerService;
 import org.elasticsearch.search.fetch.subphase.highlight.FastVectorHighlighter;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService;
 import org.elasticsearch.snapshots.SnapshotsService;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.ProxyConnectionStrategy;
@@ -254,6 +255,7 @@ public final class ClusterSettings extends AbstractScopedSettings {
             ShardStateAction.FOLLOW_UP_REROUTE_PRIORITY_SETTING,
             InternalClusterInfoService.INTERNAL_CLUSTER_INFO_UPDATE_INTERVAL_SETTING,
             InternalClusterInfoService.INTERNAL_CLUSTER_INFO_TIMEOUT_SETTING,
+            InternalSnapshotsInfoService.INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING,
             DestructiveOperations.REQUIRES_NAME_SETTING,
             DiscoverySettings.PUBLISH_TIMEOUT_SETTING,
             DiscoverySettings.PUBLISH_DIFF_ENABLE_SETTING,
diff --git a/server/src/main/java/org/elasticsearch/gateway/BaseGatewayShardAllocator.java b/server/src/main/java/org/elasticsearch/gateway/BaseGatewayShardAllocator.java
index 30e6c200402..103292913e7 100644
--- a/server/src/main/java/org/elasticsearch/gateway/BaseGatewayShardAllocator.java
+++ b/server/src/main/java/org/elasticsearch/gateway/BaseGatewayShardAllocator.java
@@ -21,6 +21,7 @@ package org.elasticsearch.gateway;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.allocation.AllocateUnassignedDecision;
@@ -64,14 +65,25 @@ public abstract class BaseGatewayShardAllocator {
         if (allocateUnassignedDecision.getAllocationDecision() == AllocationDecision.YES) {
             unassignedAllocationHandler.initialize(allocateUnassignedDecision.getTargetNode().getId(),
                 allocateUnassignedDecision.getAllocationId(),
-                shardRouting.primary() ? ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE :
-                                         allocation.clusterInfo().getShardSize(shardRouting, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE),
+                getExpectedShardSize(shardRouting, allocation),
                 allocation.changes());
         } else {
             unassignedAllocationHandler.removeAndIgnore(allocateUnassignedDecision.getAllocationStatus(), allocation.changes());
         }
     }
 
+    protected long getExpectedShardSize(ShardRouting shardRouting, RoutingAllocation allocation) {
+        if (shardRouting.primary()) {
+            if (shardRouting.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
+                return allocation.snapshotShardSizeInfo().getShardSize(shardRouting, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE);
+            } else {
+                return ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE;
+            }
+        } else {
+            return allocation.clusterInfo().getShardSize(shardRouting, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE);
+        }
+    }
+
     /**
      * Make a decision on the allocation of an unassigned shard.  This method is used by
      * {@link #allocateUnassigned(ShardRouting, RoutingAllocation, ExistingShardsAllocator.UnassignedAllocationHandler)} to make decisions
diff --git a/server/src/main/java/org/elasticsearch/gateway/PrimaryShardAllocator.java b/server/src/main/java/org/elasticsearch/gateway/PrimaryShardAllocator.java
index 34398fbaf71..383b554fb21 100644
--- a/server/src/main/java/org/elasticsearch/gateway/PrimaryShardAllocator.java
+++ b/server/src/main/java/org/elasticsearch/gateway/PrimaryShardAllocator.java
@@ -27,6 +27,7 @@ import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.RoutingNodes;
 import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.cluster.routing.UnassignedInfo.AllocationStatus;
 import org.elasticsearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.elasticsearch.cluster.routing.allocation.NodeAllocationResult;
@@ -83,6 +84,16 @@ public abstract class PrimaryShardAllocator extends BaseGatewayShardAllocator {
         }
 
         final boolean explain = allocation.debugDecision();
+
+        if (unassignedShard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT &&
+            allocation.snapshotShardSizeInfo().getShardSize(unassignedShard) == null) {
+            List<NodeAllocationResult> nodeDecisions = null;
+            if (explain) {
+                nodeDecisions = buildDecisionsForAllNodes(unassignedShard, allocation);
+            }
+            return AllocateUnassignedDecision.no(UnassignedInfo.AllocationStatus.FETCHING_SHARD_DATA, nodeDecisions);
+        }
+
         final FetchResult<NodeGatewayStartedShards> shardState = fetchData(unassignedShard, allocation);
         if (shardState.hasData() == false) {
             allocation.setHasPendingAsyncFetch();
diff --git a/server/src/main/java/org/elasticsearch/node/Node.java b/server/src/main/java/org/elasticsearch/node/Node.java
index f585fc25ca8..ed5e9c843b0 100644
--- a/server/src/main/java/org/elasticsearch/node/Node.java
+++ b/server/src/main/java/org/elasticsearch/node/Node.java
@@ -161,8 +161,10 @@ import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.aggregations.support.AggregationUsageService;
 import org.elasticsearch.search.fetch.FetchPhase;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService;
 import org.elasticsearch.snapshots.RestoreService;
 import org.elasticsearch.snapshots.SnapshotShardsService;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 import org.elasticsearch.snapshots.SnapshotsService;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskCancellationService;
@@ -407,6 +409,7 @@ public class Node implements Closeable {
             final IngestService ingestService = new IngestService(clusterService, threadPool, this.environment,
                 scriptService, analysisModule.getAnalysisRegistry(),
                 pluginsService.filterPlugins(IngestPlugin.class), client);
+            final SetOnce<RepositoriesService> repositoriesServiceReference = new SetOnce<>();
             final ClusterInfoService clusterInfoService = newClusterInfoService(settings, clusterService, threadPool, client);
             final UsageService usageService = new UsageService();
 
@@ -418,7 +421,11 @@ public class Node implements Closeable {
             final MonitorService monitorService = new MonitorService(settings, nodeEnvironment, threadPool);
             final FsHealthService fsHealthService = new FsHealthService(settings, clusterService.getClusterSettings(), threadPool,
                 nodeEnvironment);
-            ClusterModule clusterModule = new ClusterModule(settings, clusterService, clusterPlugins, clusterInfoService);
+            final SetOnce<RerouteService> rerouteServiceReference = new SetOnce<>();
+            final InternalSnapshotsInfoService snapshotsInfoService = new InternalSnapshotsInfoService(settings, clusterService,
+                repositoriesServiceReference::get, rerouteServiceReference::get);
+            final ClusterModule clusterModule = new ClusterModule(settings, clusterService, clusterPlugins, clusterInfoService,
+                snapshotsInfoService);
             modules.add(clusterModule);
             IndicesModule indicesModule = new IndicesModule(pluginsService.filterPlugins(MapperPlugin.class));
             modules.add(indicesModule);
@@ -496,6 +503,7 @@ public class Node implements Closeable {
 
             final RerouteService rerouteService
                 = new BatchedRerouteService(clusterService, clusterModule.getAllocationService()::reroute);
+            rerouteServiceReference.set(rerouteService);
             clusterService.setRerouteService(rerouteService);
 
             final IndicesService indicesService =
@@ -529,7 +537,6 @@ public class Node implements Closeable {
             final MetadataCreateDataStreamService metadataCreateDataStreamService =
                 new MetadataCreateDataStreamService(threadPool, clusterService, metadataCreateIndexService);
 
-            final SetOnce<RepositoriesService> repositoriesServiceReference = new SetOnce<>();
             Collection<Object> pluginComponents = pluginsService.filterPlugins(Plugin.class).stream()
                 .flatMap(p -> p.createComponents(client, clusterService, threadPool, resourceWatcherService,
                                                  scriptService, xContentRegistry, environment, nodeEnvironment,
@@ -653,6 +660,7 @@ public class Node implements Closeable {
                     b.bind(UpdateHelper.class).toInstance(new UpdateHelper(scriptService));
                     b.bind(MetadataIndexUpgradeService.class).toInstance(metadataIndexUpgradeService);
                     b.bind(ClusterInfoService.class).toInstance(clusterInfoService);
+                    b.bind(SnapshotsInfoService.class).toInstance(snapshotsInfoService);
                     b.bind(GatewayMetaState.class).toInstance(gatewayMetaState);
                     b.bind(Discovery.class).toInstance(discoveryModule.getDiscovery());
                     {
@@ -1150,7 +1158,8 @@ public class Node implements Closeable {
     /** Constructs a ClusterInfoService which may be mocked for tests. */
     protected ClusterInfoService newClusterInfoService(Settings settings, ClusterService clusterService,
                                                        ThreadPool threadPool, NodeClient client) {
-        final InternalClusterInfoService service = new InternalClusterInfoService(settings, clusterService, threadPool, client);
+        final InternalClusterInfoService service =
+            new InternalClusterInfoService(settings, clusterService, threadPool, client);
         // listen for state changes (this node starts/stops being the elected master, or new nodes are added)
         clusterService.addListener(service);
         return service;
diff --git a/server/src/main/java/org/elasticsearch/snapshots/EmptySnapshotsInfoService.java b/server/src/main/java/org/elasticsearch/snapshots/EmptySnapshotsInfoService.java
new file mode 100644
index 00000000000..d81fee6a72f
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/snapshots/EmptySnapshotsInfoService.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.snapshots;
+
+import org.elasticsearch.common.collect.ImmutableOpenMap;
+
+public class EmptySnapshotsInfoService implements SnapshotsInfoService {
+    public static final EmptySnapshotsInfoService INSTANCE = new EmptySnapshotsInfoService();
+
+    @Override
+    public SnapshotShardSizeInfo snapshotShardSizes() {
+        return new SnapshotShardSizeInfo(ImmutableOpenMap.of());
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/snapshots/InternalSnapshotsInfoService.java b/server/src/main/java/org/elasticsearch/snapshots/InternalSnapshotsInfoService.java
new file mode 100644
index 00000000000..355b6b168eb
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/snapshots/InternalSnapshotsInfoService.java
@@ -0,0 +1,386 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.snapshots;
+
+import com.carrotsearch.hppc.cursors.ObjectCursor;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.ClusterChangedEvent;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ClusterStateListener;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.RecoverySource;
+import org.elasticsearch.cluster.routing.RerouteService;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.ShardRoutingState;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Priority;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
+import org.elasticsearch.common.settings.ClusterSettings;
+import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.repositories.RepositoriesService;
+import org.elasticsearch.repositories.Repository;
+import org.elasticsearch.threadpool.ThreadPool;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.Objects;
+import java.util.Queue;
+import java.util.Set;
+import java.util.function.Supplier;
+
+public class InternalSnapshotsInfoService implements ClusterStateListener, SnapshotsInfoService {
+
+    public static final Setting<Integer> INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING =
+        Setting.intSetting("cluster.snapshot.info.max_concurrent_fetches", 5, 1,
+            Setting.Property.Dynamic, Setting.Property.NodeScope);
+
+    private static final Logger logger = LogManager.getLogger(InternalSnapshotsInfoService.class);
+
+    private static final ActionListener<ClusterState> REROUTE_LISTENER = ActionListener.wrap(
+        r -> logger.trace("reroute after snapshot shard size update completed"),
+        e -> logger.debug("reroute after snapshot shard size update failed", e)
+    );
+
+    private final ThreadPool threadPool;
+    private final Supplier<RepositoriesService> repositoriesService;
+    private final Supplier<RerouteService> rerouteService;
+
+    /** contains the snapshot shards for which the size is known **/
+    private volatile ImmutableOpenMap<SnapshotShard, Long> knownSnapshotShardSizes;
+
+    private volatile boolean isMaster;
+
+    /** contains the snapshot shards for which the size is unknown and must be fetched (or is being fetched) **/
+    private final Set<SnapshotShard> unknownSnapshotShards;
+
+    /** a blocking queue used for concurrent fetching **/
+    private final Queue<SnapshotShard> queue;
+
+    /** contains the snapshot shards for which the snapshot shard size retrieval failed **/
+    private final Set<SnapshotShard> failedSnapshotShards;
+
+    private volatile int maxConcurrentFetches;
+    private int activeFetches;
+
+    private final Object mutex;
+
+    public InternalSnapshotsInfoService(
+        final Settings settings,
+        final ClusterService clusterService,
+        final Supplier<RepositoriesService> repositoriesServiceSupplier,
+        final Supplier<RerouteService> rerouteServiceSupplier
+    ) {
+        this.threadPool = clusterService.getClusterApplierService().threadPool();
+        this.repositoriesService = repositoriesServiceSupplier;
+        this.rerouteService = rerouteServiceSupplier;
+        this.knownSnapshotShardSizes = ImmutableOpenMap.of();
+        this.unknownSnapshotShards  = new LinkedHashSet<>();
+        this.failedSnapshotShards  = new LinkedHashSet<>();
+        this.queue = new LinkedList<>();
+        this.mutex = new Object();
+        this.activeFetches = 0;
+        this.maxConcurrentFetches = INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING.get(settings);
+        final ClusterSettings clusterSettings = clusterService.getClusterSettings();
+        clusterSettings.addSettingsUpdateConsumer(INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING, this::setMaxConcurrentFetches);
+        if (DiscoveryNode.isMasterNode(settings)) {
+            clusterService.addListener(this);
+        }
+    }
+
+    private void setMaxConcurrentFetches(Integer maxConcurrentFetches) {
+        this.maxConcurrentFetches = maxConcurrentFetches;
+    }
+
+    @Override
+    public SnapshotShardSizeInfo snapshotShardSizes() {
+        synchronized (mutex){
+            final ImmutableOpenMap.Builder<SnapshotShard, Long> snapshotShardSizes = ImmutableOpenMap.builder(knownSnapshotShardSizes);
+            for (SnapshotShard snapshotShard : failedSnapshotShards) {
+                Long previous = snapshotShardSizes.put(snapshotShard, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE);
+                assert previous == null : "snapshot shard size already known for " + snapshotShard;
+            }
+            return new SnapshotShardSizeInfo(snapshotShardSizes.build());
+        }
+    }
+
+    @Override
+    public void clusterChanged(ClusterChangedEvent event) {
+        if (event.localNodeMaster()) {
+            final Set<SnapshotShard> onGoingSnapshotRecoveries = listOfSnapshotShards(event.state());
+
+            int unknownShards = 0;
+            synchronized (mutex) {
+                isMaster = true;
+                for (SnapshotShard snapshotShard : onGoingSnapshotRecoveries) {
+                    // check if already populated entry
+                    if (knownSnapshotShardSizes.containsKey(snapshotShard) == false) {
+                        // check if already fetching snapshot info in progress
+                        if (unknownSnapshotShards.add(snapshotShard)) {
+                            failedSnapshotShards.remove(snapshotShard); // retry the failed shard
+                            queue.add(snapshotShard);
+                            unknownShards += 1;
+                        }
+                    }
+                }
+                // Clean up keys from knownSnapshotShardSizes that are no longer needed for recoveries
+                cleanUpKnownSnapshotShardSizes(onGoingSnapshotRecoveries);
+            }
+
+            final int nbFetchers = Math.min(unknownShards, maxConcurrentFetches);
+            for (int i = 0; i < nbFetchers; i++) {
+                fetchNextSnapshotShard();
+            }
+
+        } else if (event.previousState().nodes().isLocalNodeElectedMaster()) {
+            // TODO Maybe just clear out non-ongoing snapshot recoveries is the node is master eligible, so that we don't
+            // have to repopulate the data over and over in an unstable master situation?
+            synchronized (mutex) {
+                // information only needed on current master
+                knownSnapshotShardSizes = ImmutableOpenMap.of();
+                failedSnapshotShards.clear();
+                isMaster = false;
+                SnapshotShard snapshotShard;
+                while ((snapshotShard = queue.poll()) != null) {
+                    final boolean removed = unknownSnapshotShards.remove(snapshotShard);
+                    assert removed : "snapshot shard to remove does not exist " + snapshotShard;
+                }
+                assert invariant();
+            }
+        } else {
+            synchronized (mutex) {
+                assert unknownSnapshotShards.isEmpty() || unknownSnapshotShards.size() == activeFetches;
+                assert knownSnapshotShardSizes.isEmpty();
+                assert failedSnapshotShards.isEmpty();
+                assert isMaster == false;
+                assert queue.isEmpty();
+            }
+        }
+    }
+
+    private void fetchNextSnapshotShard() {
+        synchronized (mutex) {
+            if (activeFetches < maxConcurrentFetches) {
+                final SnapshotShard snapshotShard = queue.poll();
+                if (snapshotShard != null) {
+                    activeFetches += 1;
+                    threadPool.generic().execute(new FetchingSnapshotShardSizeRunnable(snapshotShard));
+                }
+            }
+            assert invariant();
+        }
+    }
+
+    private class FetchingSnapshotShardSizeRunnable extends AbstractRunnable {
+
+        private final SnapshotShard snapshotShard;
+        private boolean removed;
+
+        FetchingSnapshotShardSizeRunnable(SnapshotShard snapshotShard) {
+            super();
+            this.snapshotShard = snapshotShard;
+            this.removed = false;
+        }
+
+        @Override
+        protected void doRun() throws Exception {
+            final RepositoriesService repositories = repositoriesService.get();
+            assert repositories != null;
+            final Repository repository = repositories.repository(snapshotShard.snapshot.getRepository());
+
+            logger.debug("fetching snapshot shard size for {}", snapshotShard);
+            final long snapshotShardSize = repository.getShardSnapshotStatus(
+                snapshotShard.snapshot().getSnapshotId(),
+                snapshotShard.index(),
+                snapshotShard.shardId()
+            ).asCopy().getTotalSize();
+
+            logger.debug("snapshot shard size for {}: {} bytes", snapshotShard, snapshotShardSize);
+
+            boolean updated = false;
+            synchronized (mutex) {
+                removed = unknownSnapshotShards.remove(snapshotShard);
+                assert removed : "snapshot shard to remove does not exist " + snapshotShardSize;
+                if (isMaster) {
+                    final ImmutableOpenMap.Builder<SnapshotShard, Long> newSnapshotShardSizes =
+                        ImmutableOpenMap.builder(knownSnapshotShardSizes);
+                    updated = newSnapshotShardSizes.put(snapshotShard, snapshotShardSize) == null;
+                    assert updated : "snapshot shard size already exists for " + snapshotShard;
+                    knownSnapshotShardSizes = newSnapshotShardSizes.build();
+                }
+                activeFetches -= 1;
+                assert invariant();
+            }
+            if (updated) {
+                rerouteService.get().reroute("snapshot shard size updated", Priority.HIGH, REROUTE_LISTENER);
+            }
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            logger.warn(() -> new ParameterizedMessage("failed to retrieve shard size for {}", snapshotShard), e);
+            synchronized (mutex) {
+                if (isMaster) {
+                    final boolean added = failedSnapshotShards.add(snapshotShard);
+                    assert added : "snapshot shard size already failed for " + snapshotShard;
+                }
+                if (removed == false) {
+                    unknownSnapshotShards.remove(snapshotShard);
+                }
+                activeFetches -= 1;
+                assert invariant();
+            }
+        }
+
+        @Override
+        public void onAfter() {
+            fetchNextSnapshotShard();
+        }
+    }
+
+    private void cleanUpKnownSnapshotShardSizes(Set<SnapshotShard> requiredSnapshotShards) {
+        assert Thread.holdsLock(mutex);
+        ImmutableOpenMap.Builder<SnapshotShard, Long> newSnapshotShardSizes = null;
+        for (ObjectCursor<SnapshotShard> shard : knownSnapshotShardSizes.keys()) {
+            if (requiredSnapshotShards.contains(shard.value) == false) {
+                if (newSnapshotShardSizes == null) {
+                    newSnapshotShardSizes = ImmutableOpenMap.builder(knownSnapshotShardSizes);
+                }
+                newSnapshotShardSizes.remove(shard.value);
+            }
+        }
+        if (newSnapshotShardSizes != null) {
+            knownSnapshotShardSizes = newSnapshotShardSizes.build();
+        }
+    }
+
+    private boolean invariant() {
+        assert Thread.holdsLock(mutex);
+        assert activeFetches >= 0 : "active fetches should be greater than or equal to zero but got: " + activeFetches;
+        assert activeFetches <= maxConcurrentFetches : activeFetches + " <= " + maxConcurrentFetches;
+        for (ObjectCursor<SnapshotShard> cursor : knownSnapshotShardSizes.keys()) {
+            assert unknownSnapshotShards.contains(cursor.value) == false : "cannot be known and unknown at same time: " + cursor.value;
+            assert failedSnapshotShards.contains(cursor.value) == false : "cannot be known and failed at same time: " + cursor.value;
+        }
+        for (SnapshotShard shard : unknownSnapshotShards) {
+            assert knownSnapshotShardSizes.keys().contains(shard) == false : "cannot be unknown and known at same time: " + shard;
+            assert failedSnapshotShards.contains(shard) == false : "cannot be unknown and failed at same time: " + shard;
+        }
+        for (SnapshotShard shard : failedSnapshotShards) {
+            assert knownSnapshotShardSizes.keys().contains(shard) == false : "cannot be failed and known at same time: " + shard;
+            assert unknownSnapshotShards.contains(shard) == false : "cannot be failed and unknown at same time: " + shard;
+        }
+        return true;
+    }
+
+    // used in tests
+    int numberOfUnknownSnapshotShardSizes() {
+        synchronized (mutex) {
+            return unknownSnapshotShards.size();
+        }
+    }
+
+    // used in tests
+    int numberOfFailedSnapshotShardSizes() {
+        synchronized (mutex) {
+            return failedSnapshotShards.size();
+        }
+    }
+
+    // used in tests
+    int numberOfKnownSnapshotShardSizes() {
+        return knownSnapshotShardSizes.size();
+    }
+
+    private static Set<SnapshotShard> listOfSnapshotShards(final ClusterState state) {
+        final Set<SnapshotShard> snapshotShards = new HashSet<>();
+        for (ShardRouting shardRouting : state.routingTable().shardsWithState(ShardRoutingState.UNASSIGNED)) {
+            if (shardRouting.primary() && shardRouting.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
+                final RecoverySource.SnapshotRecoverySource snapshotRecoverySource =
+                    (RecoverySource.SnapshotRecoverySource) shardRouting.recoverySource();
+                final SnapshotShard snapshotShard = new SnapshotShard(snapshotRecoverySource.snapshot(),
+                    snapshotRecoverySource.index(), shardRouting.shardId());
+                snapshotShards.add(snapshotShard);
+            }
+        }
+        return Collections.unmodifiableSet(snapshotShards);
+    }
+
+    public static class SnapshotShard {
+
+        private final Snapshot snapshot;
+        private final IndexId index;
+        private final ShardId shardId;
+
+        public SnapshotShard(Snapshot snapshot, IndexId index, ShardId shardId) {
+            this.snapshot = snapshot;
+            this.index = index;
+            this.shardId = shardId;
+        }
+
+        public Snapshot snapshot() {
+            return snapshot;
+        }
+
+        public IndexId index() {
+            return index;
+        }
+
+        public ShardId shardId() {
+            return shardId;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            final SnapshotShard that = (SnapshotShard) o;
+            return shardId.equals(that.shardId)
+                && snapshot.equals(that.snapshot)
+                && index.equals(that.index);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(snapshot, index, shardId);
+        }
+
+        @Override
+        public String toString() {
+            return "[" +
+                "snapshot=" + snapshot +
+                ", index=" + index +
+                ", shard=" + shardId +
+                ']';
+        }
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardSizeInfo.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardSizeInfo.java
new file mode 100644
index 00000000000..0534b62ea07
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardSizeInfo.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.snapshots;
+
+import org.elasticsearch.cluster.routing.RecoverySource;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
+
+public class SnapshotShardSizeInfo {
+
+    public static final SnapshotShardSizeInfo EMPTY = new SnapshotShardSizeInfo(ImmutableOpenMap.of());
+
+    private final ImmutableOpenMap<InternalSnapshotsInfoService.SnapshotShard, Long> snapshotShardSizes;
+
+    public SnapshotShardSizeInfo(ImmutableOpenMap<InternalSnapshotsInfoService.SnapshotShard, Long> snapshotShardSizes) {
+        this.snapshotShardSizes = snapshotShardSizes;
+    }
+
+    public Long getShardSize(ShardRouting shardRouting) {
+        if (shardRouting.primary()
+            && shardRouting.active() == false
+            && shardRouting.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
+            final RecoverySource.SnapshotRecoverySource snapshotRecoverySource =
+                (RecoverySource.SnapshotRecoverySource) shardRouting.recoverySource();
+            return snapshotShardSizes.get(new InternalSnapshotsInfoService.SnapshotShard(
+                snapshotRecoverySource.snapshot(), snapshotRecoverySource.index(), shardRouting.shardId()));
+        }
+        assert false : "Expected shard with snapshot recovery source but was " + shardRouting;
+        return null;
+    }
+
+    public long getShardSize(ShardRouting shardRouting, long fallback) {
+        final Long shardSize = getShardSize(shardRouting);
+        return shardSize == null ? fallback : shardSize;
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotsInfoService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotsInfoService.java
new file mode 100644
index 00000000000..0b925b31f86
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotsInfoService.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.snapshots;
+
+@FunctionalInterface
+public interface SnapshotsInfoService {
+    SnapshotShardSizeInfo snapshotShardSizes();
+}
diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java
index 2e7e76702d4..53d66d4a9b5 100644
--- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java
+++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java
@@ -56,7 +56,7 @@ public class ClusterAllocationExplainActionTests extends ESTestCase {
         ClusterState clusterState = ClusterStateCreationUtils.state("idx", randomBoolean(), shardRoutingState);
         ShardRouting shard = clusterState.getRoutingTable().index("idx").shard(0).primaryShard();
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
-            clusterState.getRoutingNodes(), clusterState, null, System.nanoTime());
+            clusterState.getRoutingNodes(), clusterState, null, null, System.nanoTime());
         ClusterAllocationExplanation cae = TransportClusterAllocationExplainAction.explainShard(shard, allocation, null, randomBoolean(),
             new AllocationService(null, new TestGatewayAllocator(), new ShardsAllocator() {
                 @Override
@@ -72,7 +72,7 @@ public class ClusterAllocationExplainActionTests extends ESTestCase {
                         throw new UnsupportedOperationException("cannot explain");
                     }
                 }
-            }, null));
+            }, null, null));
 
         assertEquals(shard.currentNodeId(), cae.getCurrentNode().getId());
         assertFalse(cae.getShardAllocationDecision().isDecisionTaken());
@@ -178,6 +178,6 @@ public class ClusterAllocationExplainActionTests extends ESTestCase {
     }
 
     private static RoutingAllocation routingAllocation(ClusterState clusterState) {
-        return new RoutingAllocation(NOOP_DECIDERS, clusterState.getRoutingNodes(), clusterState, null, System.nanoTime());
+        return new RoutingAllocation(NOOP_DECIDERS, clusterState.getRoutingNodes(), clusterState, null, null, System.nanoTime());
     }
 }
diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteTests.java
index 78148c65a5f..a6400e54b08 100644
--- a/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteTests.java
+++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteTests.java
@@ -41,6 +41,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.network.NetworkModule;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 import java.io.IOException;
@@ -81,7 +82,8 @@ public class ClusterRerouteTests extends ESAllocationTestCase {
     public void testClusterStateUpdateTask() {
         AllocationService allocationService = new AllocationService(
             new AllocationDeciders(Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
         ClusterState clusterState = createInitialClusterState(allocationService);
         ClusterRerouteRequest req = new ClusterRerouteRequest();
         req.dryRun(true);
diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/shrink/TransportResizeActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/shrink/TransportResizeActionTests.java
index 18b55660ace..eedf4ac3759 100644
--- a/server/src/test/java/org/elasticsearch/action/admin/indices/shrink/TransportResizeActionTests.java
+++ b/server/src/test/java/org/elasticsearch/action/admin/indices/shrink/TransportResizeActionTests.java
@@ -40,6 +40,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.MaxRetryAllocationDecider;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.shard.DocsStats;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
@@ -111,7 +112,8 @@ public class TransportResizeActionTests extends ESTestCase {
             .build();
         AllocationService service = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
@@ -129,7 +131,8 @@ public class TransportResizeActionTests extends ESTestCase {
             .build();
         AllocationService service = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
@@ -158,7 +161,8 @@ public class TransportResizeActionTests extends ESTestCase {
             .build();
         AllocationService service = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
@@ -192,7 +196,8 @@ public class TransportResizeActionTests extends ESTestCase {
             .build();
         AllocationService service = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
diff --git a/server/src/test/java/org/elasticsearch/cluster/ClusterModuleTests.java b/server/src/test/java/org/elasticsearch/cluster/ClusterModuleTests.java
index 98e8970dfb5..bd4761acf97 100644
--- a/server/src/test/java/org/elasticsearch/cluster/ClusterModuleTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/ClusterModuleTests.java
@@ -121,7 +121,7 @@ public class ClusterModuleTests extends ModuleTestCase {
                     public Collection<AllocationDecider> createAllocationDeciders(Settings settings, ClusterSettings clusterSettings) {
                         return Collections.singletonList(new EnableAllocationDecider(settings, clusterSettings));
                     }
-                }), clusterInfoService));
+                }), clusterInfoService, null));
         assertEquals(e.getMessage(),
             "Cannot specify allocation decider [" + EnableAllocationDecider.class.getName() + "] twice");
     }
@@ -133,7 +133,7 @@ public class ClusterModuleTests extends ModuleTestCase {
                 public Collection<AllocationDecider> createAllocationDeciders(Settings settings, ClusterSettings clusterSettings) {
                     return Collections.singletonList(new FakeAllocationDecider());
                 }
-            }), clusterInfoService);
+            }), clusterInfoService, null);
         assertTrue(module.deciderList.stream().anyMatch(d -> d.getClass().equals(FakeAllocationDecider.class)));
     }
 
@@ -145,7 +145,7 @@ public class ClusterModuleTests extends ModuleTestCase {
                     return Collections.singletonMap(name, supplier);
                 }
             }
-        ), clusterInfoService);
+        ), clusterInfoService, null);
     }
 
     public void testRegisterShardsAllocator() {
@@ -163,7 +163,7 @@ public class ClusterModuleTests extends ModuleTestCase {
     public void testUnknownShardsAllocator() {
         Settings settings = Settings.builder().put(ClusterModule.SHARDS_ALLOCATOR_TYPE_SETTING.getKey(), "dne").build();
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () ->
-            new ClusterModule(settings, clusterService, Collections.emptyList(), clusterInfoService));
+            new ClusterModule(settings, clusterService, Collections.emptyList(), clusterInfoService, null));
         assertEquals("Unknown ShardsAllocator [dne]", e.getMessage());
     }
 
@@ -231,13 +231,14 @@ public class ClusterModuleTests extends ModuleTestCase {
 
     public void testRejectsReservedExistingShardsAllocatorName() {
         final ClusterModule clusterModule = new ClusterModule(Settings.EMPTY, clusterService,
-            Collections.singletonList(existingShardsAllocatorPlugin(GatewayAllocator.ALLOCATOR_NAME)), clusterInfoService);
+            Collections.singletonList(existingShardsAllocatorPlugin(GatewayAllocator.ALLOCATOR_NAME)), clusterInfoService, null);
         expectThrows(IllegalArgumentException.class, () -> clusterModule.setExistingShardsAllocators(new TestGatewayAllocator()));
     }
 
     public void testRejectsDuplicateExistingShardsAllocatorName() {
         final ClusterModule clusterModule = new ClusterModule(Settings.EMPTY, clusterService,
-            Arrays.asList(existingShardsAllocatorPlugin("duplicate"), existingShardsAllocatorPlugin("duplicate")), clusterInfoService);
+            Arrays.asList(existingShardsAllocatorPlugin("duplicate"), existingShardsAllocatorPlugin("duplicate")), clusterInfoService,
+            null);
         expectThrows(IllegalArgumentException.class, () -> clusterModule.setExistingShardsAllocators(new TestGatewayAllocator()));
     }
 
diff --git a/server/src/test/java/org/elasticsearch/cluster/health/ClusterStateHealthTests.java b/server/src/test/java/org/elasticsearch/cluster/health/ClusterStateHealthTests.java
index 9d94943ed77..f4a20fc64c9 100644
--- a/server/src/test/java/org/elasticsearch/cluster/health/ClusterStateHealthTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/health/ClusterStateHealthTests.java
@@ -143,7 +143,7 @@ public class ClusterStateHealthTests extends ESTestCase {
 
         TransportClusterHealthAction action = new TransportClusterHealthAction(transportService,
             clusterService, threadPool, new ActionFilters(new HashSet<>()), indexNameExpressionResolver,
-            new AllocationService(null, new TestGatewayAllocator(), null, null));
+            new AllocationService(null, new TestGatewayAllocator(), null, null, null));
         PlainActionFuture<ClusterHealthResponse> listener = new PlainActionFuture<>();
         action.execute(new ClusterHealthRequest().waitForGreenStatus(), listener);
 
diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java
index aa9c7ab9b2c..bd84243b36f 100644
--- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java
@@ -60,6 +60,7 @@ import org.elasticsearch.indices.InvalidIndexNameException;
 import org.elasticsearch.indices.ShardLimitValidator;
 import org.elasticsearch.indices.SystemIndexDescriptor;
 import org.elasticsearch.indices.SystemIndices;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.ClusterServiceUtils;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.VersionUtils;
@@ -232,7 +233,8 @@ public class MetadataCreateIndexServiceTests extends ESTestCase {
             .build();
         AllocationService service = new AllocationService(new AllocationDeciders(
             singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
@@ -294,7 +296,8 @@ public class MetadataCreateIndexServiceTests extends ESTestCase {
             .nodes(DiscoveryNodes.builder().add(newNode("node1"))).build();
         AllocationService service = new AllocationService(new AllocationDeciders(
             singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
@@ -430,7 +433,8 @@ public class MetadataCreateIndexServiceTests extends ESTestCase {
                 new AllocationDeciders(singleton(new MaxRetryAllocationDecider())),
                 new TestGatewayAllocator(),
                 new BalancedShardsAllocator(Settings.EMPTY),
-                EmptyClusterInfoService.INSTANCE);
+                EmptyClusterInfoService.INSTANCE,
+                EmptySnapshotsInfoService.INSTANCE);
 
         final RoutingTable initialRoutingTable = service.reroute(initialClusterState, "reroute").routingTable();
         final ClusterState routingTableClusterState = ClusterState.builder(initialClusterState).routingTable(initialRoutingTable).build();
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationCommandsTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationCommandsTests.java
index 263be83d437..ed656e848e7 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationCommandsTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationCommandsTests.java
@@ -59,6 +59,7 @@ import org.elasticsearch.index.Index;
 import org.elasticsearch.index.IndexNotFoundException;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardNotFoundException;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 import java.util.Arrays;
 import java.util.Collections;
@@ -635,7 +636,7 @@ public class AllocationCommandsTests extends ESAllocationTestCase {
         Index index = clusterState.getMetadata().index("test").getIndex();
         MoveAllocationCommand command = new MoveAllocationCommand(index.getName(), 0, "node1", "node2");
         RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
-            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, System.nanoTime());
+            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY, System.nanoTime());
         logger.info("--> executing move allocation command to non-data node");
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> command.execute(routingAllocation, false));
         assertEquals("[move_allocation] can't move [test][0] from " + node1 + " to " +
@@ -674,7 +675,7 @@ public class AllocationCommandsTests extends ESAllocationTestCase {
         Index index = clusterState.getMetadata().index("test").getIndex();
         MoveAllocationCommand command = new MoveAllocationCommand(index.getName(), 0, "node2", "node1");
         RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
-            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, System.nanoTime());
+            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY, System.nanoTime());
         logger.info("--> executing move allocation command from non-data node");
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> command.execute(routingAllocation, false));
         assertEquals("[move_allocation] can't move [test][0] from " + node2 + " to " + node1 +
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationServiceTests.java
index dbca0127337..810a5a9b044 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationServiceTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationServiceTests.java
@@ -42,6 +42,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.ThrottlingAllocation
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.gateway.GatewayAllocator;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
@@ -135,7 +136,7 @@ public class AllocationServiceTests extends ESTestCase {
                 public ShardAllocationDecision decideShardAllocation(ShardRouting shard, RoutingAllocation allocation) {
                     return ShardAllocationDecision.NOT_TAKEN;
                 }
-            }, new EmptyClusterInfoService());
+            }, new EmptyClusterInfoService(), EmptySnapshotsInfoService.INSTANCE);
 
         final String unrealisticAllocatorName = "unrealistic";
         final Map<String, ExistingShardsAllocator> allocatorMap = new HashMap<>();
@@ -222,7 +223,7 @@ public class AllocationServiceTests extends ESTestCase {
     }
 
     public void testExplainsNonAllocationOfShardWithUnknownAllocator() {
-        final AllocationService allocationService = new AllocationService(null, null, null);
+        final AllocationService allocationService = new AllocationService(null, null, null, null);
         allocationService.setExistingShardsAllocators(
             Collections.singletonMap(GatewayAllocator.ALLOCATOR_NAME, new TestGatewayAllocator()));
 
@@ -242,7 +243,7 @@ public class AllocationServiceTests extends ESTestCase {
             .build();
 
         final RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
-            clusterState.getRoutingNodes(), clusterState, ClusterInfo.EMPTY, 0L);
+            clusterState.getRoutingNodes(), clusterState, ClusterInfo.EMPTY, null,0L);
         allocation.setDebugMode(randomBoolean() ? RoutingAllocation.DebugMode.ON : RoutingAllocation.DebugMode.EXCLUDE_YES_DECISIONS);
 
         final ShardAllocationDecision shardAllocationDecision
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalanceConfigurationTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalanceConfigurationTests.java
index 52737b41464..7418895bba6 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalanceConfigurationTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalanceConfigurationTests.java
@@ -41,6 +41,7 @@ import org.elasticsearch.cluster.routing.allocation.allocator.ShardsAllocator;
 import org.elasticsearch.cluster.routing.allocation.decider.ClusterRebalanceAllocationDecider;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.hamcrest.Matchers;
 
@@ -347,7 +348,7 @@ public class BalanceConfigurationTests extends ESAllocationTestCase {
             public ShardAllocationDecision decideShardAllocation(ShardRouting shard, RoutingAllocation allocation) {
                 throw new UnsupportedOperationException("explain not supported");
             }
-        }, EmptyClusterInfoService.INSTANCE);
+        }, EmptyClusterInfoService.INSTANCE, EmptySnapshotsInfoService.INSTANCE);
         Metadata.Builder metadataBuilder = Metadata.builder();
         RoutingTable.Builder routingTableBuilder = RoutingTable.builder();
         IndexMetadata.Builder indexMeta = IndexMetadata.builder("test")
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalancedSingleShardTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalancedSingleShardTests.java
index e4a7fa47025..d8474078c6c 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalancedSingleShardTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalancedSingleShardTests.java
@@ -37,6 +37,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision.Type;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 import java.util.Arrays;
 import java.util.Collections;
@@ -368,7 +369,7 @@ public class BalancedSingleShardTests extends ESAllocationTestCase {
 
     private RoutingAllocation newRoutingAllocation(AllocationDeciders deciders, ClusterState state) {
         RoutingAllocation allocation = new RoutingAllocation(
-            deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, System.nanoTime());
+            deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY, System.nanoTime());
         allocation.debugDecision(true);
         return allocation;
     }
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DecisionsImpactOnClusterHealthTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DecisionsImpactOnClusterHealthTests.java
index 4b469e69882..414dc259cf0 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DecisionsImpactOnClusterHealthTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DecisionsImpactOnClusterHealthTests.java
@@ -39,6 +39,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.env.Environment;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 import java.io.IOException;
@@ -155,7 +156,8 @@ public class DecisionsImpactOnClusterHealthTests extends ESAllocationTestCase {
         return new AllocationService(new AllocationDeciders(deciders),
                                      new TestGatewayAllocator(),
                                      new BalancedShardsAllocator(settings),
-                                     EmptyClusterInfoService.INSTANCE);
+                                     EmptyClusterInfoService.INSTANCE,
+                                     EmptySnapshotsInfoService.INSTANCE);
     }
 
 }
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java
index 602c83b93fb..1346d3c8d92 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java
@@ -35,6 +35,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.MaxRetryAllocationDecider;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 import java.util.Collections;
@@ -56,7 +57,8 @@ public class MaxRetryAllocationDeciderTests extends ESAllocationTestCase {
         super.setUp();
         strategy = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
     }
 
     private ClusterState createInitialClusterState() {
@@ -176,7 +178,7 @@ public class MaxRetryAllocationDeciderTests extends ESAllocationTestCase {
             assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom" + i));
             // MaxRetryAllocationDecider#canForceAllocatePrimary should return YES decisions because canAllocate returns YES here
             assertEquals(Decision.YES, new MaxRetryAllocationDecider().canForceAllocatePrimary(
-                unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, 0)));
+                unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, null,0)));
         }
         // now we go and check that we are actually stick to unassigned on the next failure
         {
@@ -194,7 +196,7 @@ public class MaxRetryAllocationDeciderTests extends ESAllocationTestCase {
             assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom"));
             // MaxRetryAllocationDecider#canForceAllocatePrimary should return a NO decision because canAllocate returns NO here
             assertEquals(Decision.NO, new MaxRetryAllocationDecider().canForceAllocatePrimary(
-                unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, 0)));
+                unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, null,0)));
         }
 
         // change the settings and ensure we can do another round of allocation for that index.
@@ -216,7 +218,7 @@ public class MaxRetryAllocationDeciderTests extends ESAllocationTestCase {
         assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom"));
         // bumped up the max retry count, so canForceAllocatePrimary should return a YES decision
         assertEquals(Decision.YES, new MaxRetryAllocationDecider().canForceAllocatePrimary(
-            routingTable.index("idx").shard(0).shards().get(0), null, new RoutingAllocation(null, null, clusterState, null, 0)));
+            routingTable.index("idx").shard(0).shards().get(0), null, new RoutingAllocation(null, null, clusterState, null, null,0)));
 
         // now we start the shard
         clusterState = startShardsAndReroute(strategy, clusterState, routingTable.index("idx").shard(0).shards().get(0));
@@ -242,7 +244,7 @@ public class MaxRetryAllocationDeciderTests extends ESAllocationTestCase {
         assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("ZOOOMG"));
         // Counter reset, so MaxRetryAllocationDecider#canForceAllocatePrimary should return a YES decision
         assertEquals(Decision.YES, new MaxRetryAllocationDecider().canForceAllocatePrimary(
-            unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, 0)));
+            unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, null,0)));
     }
 
 }
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/NodeVersionAllocationDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/NodeVersionAllocationDeciderTests.java
index 7f2e6673cc5..d6845fc597c 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/NodeVersionAllocationDeciderTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/NodeVersionAllocationDeciderTests.java
@@ -50,12 +50,17 @@ import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.NodeVersionAllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.ReplicaAfterPrimaryActiveAllocationDecider;
 import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import org.elasticsearch.test.VersionUtils;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
@@ -338,7 +343,8 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
             Collections.singleton(new NodeVersionAllocationDecider()));
         AllocationService strategy = new MockAllocationService(
             allocationDeciders,
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
         state = strategy.reroute(state, new AllocationCommands(), true, false).getClusterState();
         // the two indices must stay as is, the replicas cannot move to oldNode2 because versions don't match
         assertThat(state.routingTable().index(shard2.getIndex()).shardsWithState(ShardRoutingState.RELOCATING).size(), equalTo(0));
@@ -353,7 +359,10 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
         final DiscoveryNode oldNode2 = new DiscoveryNode("oldNode2", buildNewFakeTransportAddress(), emptyMap(),
                 MASTER_DATA_ROLES, VersionUtils.getPreviousVersion());
 
-        int numberOfShards = randomIntBetween(1, 3);
+        final Snapshot snapshot = new Snapshot("rep1", new SnapshotId("snp1", UUIDs.randomBase64UUID()));
+        final IndexId indexId = new IndexId("test", UUIDs.randomBase64UUID(random()));
+
+        final int numberOfShards = randomIntBetween(1, 3);
         final IndexMetadata.Builder indexMetadata = IndexMetadata.builder("test").settings(settings(Version.CURRENT))
             .numberOfShards(numberOfShards).numberOfReplicas(randomIntBetween(0, 3));
         for (int i = 0; i < numberOfShards; i++) {
@@ -361,20 +370,26 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
         }
         Metadata metadata = Metadata.builder().put(indexMetadata).build();
 
+        final ImmutableOpenMap.Builder<InternalSnapshotsInfoService.SnapshotShard, Long> snapshotShardSizes =
+            ImmutableOpenMap.builder(numberOfShards);
+        final Index index = metadata.index("test").getIndex();
+        for (int i = 0; i < numberOfShards; i++) {
+            final ShardId shardId = new ShardId(index, i);
+            snapshotShardSizes.put(new InternalSnapshotsInfoService.SnapshotShard(snapshot, indexId, shardId), randomNonNegativeLong());
+        }
+
         ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
             .metadata(metadata)
             .routingTable(RoutingTable.builder().addAsRestore(metadata.index("test"),
-                new SnapshotRecoverySource(
-                    UUIDs.randomBase64UUID(),
-                    new Snapshot("rep1", new SnapshotId("snp1", UUIDs.randomBase64UUID())),
-                Version.CURRENT, new IndexId("test", UUIDs.randomBase64UUID(random())))).build())
+                new SnapshotRecoverySource(UUIDs.randomBase64UUID(), snapshot, Version.CURRENT, indexId)).build())
             .nodes(DiscoveryNodes.builder().add(newNode).add(oldNode1).add(oldNode2)).build();
         AllocationDeciders allocationDeciders = new AllocationDeciders(Arrays.asList(
             new ReplicaAfterPrimaryActiveAllocationDecider(),
             new NodeVersionAllocationDecider()));
         AllocationService strategy = new MockAllocationService(
             allocationDeciders,
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            () -> new SnapshotShardSizeInfo(snapshotShardSizes.build()));
         state = strategy.reroute(state, new AllocationCommands(), true, false).getClusterState();
 
         // Make sure that primary shards are only allocated on the new node
@@ -463,7 +478,7 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
         final ShardRouting replicaShard = clusterState.routingTable().shardRoutingTable(shardId).replicaShards().get(0);
 
         RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState,
-            null, 0);
+            null, null, 0);
         routingAllocation.debugDecision(true);
 
         final NodeVersionAllocationDecider allocationDecider = new NodeVersionAllocationDecider();
@@ -508,7 +523,7 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
         final ShardRouting startedPrimary = routingNodes.startShard(logger,
             routingNodes.initializeShard(primaryShard, "newNode", null, 0,
             routingChangesObserver), routingChangesObserver);
-        routingAllocation = new RoutingAllocation(null, routingNodes, clusterState, null, 0);
+        routingAllocation = new RoutingAllocation(null, routingNodes, clusterState, null, null,0);
         routingAllocation.debugDecision(true);
 
         decision = allocationDecider.canAllocate(replicaShard, oldNode, routingAllocation);
@@ -518,7 +533,7 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
 
         routingNodes.startShard(logger, routingNodes.relocateShard(startedPrimary,
             "oldNode", 0, routingChangesObserver).v2(), routingChangesObserver);
-        routingAllocation = new RoutingAllocation(null, routingNodes, clusterState, null, 0);
+        routingAllocation = new RoutingAllocation(null, routingNodes, clusterState, null, null,0);
         routingAllocation.debugDecision(true);
 
         decision = allocationDecider.canAllocate(replicaShard, newNode, routingAllocation);
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/RandomAllocationDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/RandomAllocationDeciderTests.java
index 95e3a4139e2..b8b635fbc45 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/RandomAllocationDeciderTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/RandomAllocationDeciderTests.java
@@ -39,6 +39,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.ReplicaAfterPrimaryA
 import org.elasticsearch.cluster.routing.allocation.decider.SameShardAllocationDecider;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.hamcrest.Matchers;
 
@@ -62,7 +63,8 @@ public class RandomAllocationDeciderTests extends ESAllocationTestCase {
                 new HashSet<>(Arrays.asList(new SameShardAllocationDecider(Settings.EMPTY,
                         new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)),
                     new ReplicaAfterPrimaryActiveAllocationDecider(), randomAllocationDecider))),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
         int indices = scaledRandomIntBetween(1, 20);
         Builder metaBuilder = Metadata.builder();
         int maxNumReplicas = 1;
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ResizeAllocationDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ResizeAllocationDeciderTests.java
index 4af5d00422b..92776fd53af 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ResizeAllocationDeciderTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ResizeAllocationDeciderTests.java
@@ -38,6 +38,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.ResizeAllocationDeci
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 import java.util.Collections;
@@ -56,7 +57,8 @@ public class ResizeAllocationDeciderTests extends ESAllocationTestCase {
         super.setUp();
         strategy = new AllocationService(new AllocationDeciders(
             Collections.singleton(new ResizeAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
     }
 
     private ClusterState createInitialClusterState(boolean startShards) {
@@ -104,7 +106,7 @@ public class ResizeAllocationDeciderTests extends ESAllocationTestCase {
     public void testNonResizeRouting() {
         ClusterState clusterState = createInitialClusterState(true);
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
-        RoutingAllocation routingAllocation = new RoutingAllocation(null, null, clusterState, null, 0);
+        RoutingAllocation routingAllocation = new RoutingAllocation(null, null, clusterState, null, null, 0);
         ShardRouting shardRouting = TestShardRouting.newShardRouting("non-resize", 0, null, true, ShardRoutingState.UNASSIGNED);
         assertEquals(Decision.ALWAYS, resizeAllocationDecider.canAllocate(shardRouting, routingAllocation));
         assertEquals(Decision.ALWAYS, resizeAllocationDecider.canAllocate(shardRouting, clusterState.getRoutingNodes().node("node1"),
@@ -128,7 +130,7 @@ public class ResizeAllocationDeciderTests extends ESAllocationTestCase {
         Index idx = clusterState.metadata().index("target").getIndex();
 
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
-        RoutingAllocation routingAllocation = new RoutingAllocation(null, null, clusterState, null, 0);
+        RoutingAllocation routingAllocation = new RoutingAllocation(null, null, clusterState, null, null, 0);
         ShardRouting shardRouting = TestShardRouting.newShardRouting(new ShardId(idx, 0), null, true, ShardRoutingState.UNASSIGNED,
             RecoverySource.LocalShardsRecoverySource.INSTANCE);
         assertEquals(Decision.ALWAYS, resizeAllocationDecider.canAllocate(shardRouting, routingAllocation));
@@ -156,7 +158,7 @@ public class ResizeAllocationDeciderTests extends ESAllocationTestCase {
 
 
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
-        RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, null, 0);
+        RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, null, null, 0);
         int shardId = randomIntBetween(0, 3);
         int sourceShardId = IndexMetadata.selectSplitShard(shardId, clusterState.metadata().index("source"), 4).id();
         ShardRouting shardRouting = TestShardRouting.newShardRouting(new ShardId(idx, shardId), null, true, ShardRoutingState.UNASSIGNED,
@@ -196,7 +198,7 @@ public class ResizeAllocationDeciderTests extends ESAllocationTestCase {
 
 
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
-        RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, null, 0);
+        RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, null, null, 0);
         int shardId = randomIntBetween(0, 3);
         int sourceShardId = IndexMetadata.selectSplitShard(shardId, clusterState.metadata().index("source"), 4).id();
         ShardRouting shardRouting = TestShardRouting.newShardRouting(new ShardId(idx, shardId), null, true, ShardRoutingState.UNASSIGNED,
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/SameShardRoutingTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/SameShardRoutingTests.java
index f70eb86a9fc..15463521f35 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/SameShardRoutingTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/SameShardRoutingTests.java
@@ -43,6 +43,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.SameShardAllocationD
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.Index;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 import java.util.Collections;
 
@@ -106,7 +107,7 @@ public class SameShardRoutingTests extends ESAllocationTestCase {
         ShardRouting primaryShard = clusterState.routingTable().index(index).shard(0).primaryShard();
         RoutingNode routingNode = clusterState.getRoutingNodes().node(primaryShard.currentNodeId());
         RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
-            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, System.nanoTime());
+            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY, System.nanoTime());
 
         // can't force allocate same shard copy to the same node
         ShardRouting newPrimary = TestShardRouting.newShardRouting(primaryShard.shardId(), null, true, ShardRoutingState.UNASSIGNED);
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ThrottlingAllocationTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ThrottlingAllocationTests.java
index 9e69d4d6be9..c05073f9e15 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ThrottlingAllocationTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ThrottlingAllocationTests.java
@@ -21,7 +21,6 @@ package org.elasticsearch.cluster.routing.allocation;
 
 import com.carrotsearch.hppc.IntHashSet;
 import com.carrotsearch.hppc.cursors.ObjectCursor;
-
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.Version;
@@ -48,8 +47,11 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 import java.util.ArrayList;
@@ -70,10 +72,11 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
     public void testPrimaryRecoveryThrottling() {
 
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
+        TestSnapshotsInfoService snapshotsInfoService = new TestSnapshotsInfoService();
         AllocationService strategy = createAllocationService(Settings.builder()
                 .put("cluster.routing.allocation.node_concurrent_recoveries", 3)
                 .put("cluster.routing.allocation.node_initial_primaries_recoveries", 3)
-                .build(), gatewayAllocator);
+                .build(), gatewayAllocator, snapshotsInfoService);
 
         logger.info("Building initial routing table");
 
@@ -81,7 +84,7 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(10).numberOfReplicas(1))
                 .build();
 
-        ClusterState clusterState = createRecoveryStateAndInitalizeAllocations(metadata, gatewayAllocator);
+        ClusterState clusterState = createRecoveryStateAndInitializeAllocations(metadata, gatewayAllocator, snapshotsInfoService);
 
         logger.info("start one node, do reroute, only 3 should initialize");
         clusterState = ClusterState.builder(clusterState).nodes(DiscoveryNodes.builder().add(newNode("node1"))).build();
@@ -122,12 +125,14 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
 
     public void testReplicaAndPrimaryRecoveryThrottling() {
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
+        TestSnapshotsInfoService snapshotsInfoService = new TestSnapshotsInfoService();
         AllocationService strategy = createAllocationService(Settings.builder()
                 .put("cluster.routing.allocation.node_concurrent_recoveries", 3)
                 .put("cluster.routing.allocation.concurrent_source_recoveries", 3)
                 .put("cluster.routing.allocation.node_initial_primaries_recoveries", 3)
                 .build(),
-            gatewayAllocator);
+            gatewayAllocator,
+            snapshotsInfoService);
 
         logger.info("Building initial routing table");
 
@@ -135,7 +140,7 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(5).numberOfReplicas(1))
                 .build();
 
-        ClusterState clusterState = createRecoveryStateAndInitalizeAllocations(metadata, gatewayAllocator);
+        ClusterState clusterState = createRecoveryStateAndInitializeAllocations(metadata, gatewayAllocator, snapshotsInfoService);
 
         logger.info("with one node, do reroute, only 3 should initialize");
         clusterState = strategy.reroute(clusterState, "reroute");
@@ -184,19 +189,20 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
 
     public void testThrottleIncomingAndOutgoing() {
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
+        TestSnapshotsInfoService snapshotsInfoService = new TestSnapshotsInfoService();
         Settings settings = Settings.builder()
             .put("cluster.routing.allocation.node_concurrent_recoveries", 5)
             .put("cluster.routing.allocation.node_initial_primaries_recoveries", 5)
             .put("cluster.routing.allocation.cluster_concurrent_rebalance", 5)
             .build();
-        AllocationService strategy = createAllocationService(settings, gatewayAllocator);
+        AllocationService strategy = createAllocationService(settings, gatewayAllocator, snapshotsInfoService);
         logger.info("Building initial routing table");
 
         Metadata metadata = Metadata.builder()
             .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(9).numberOfReplicas(0))
             .build();
 
-        ClusterState clusterState = createRecoveryStateAndInitalizeAllocations(metadata, gatewayAllocator);
+        ClusterState clusterState = createRecoveryStateAndInitializeAllocations(metadata, gatewayAllocator, snapshotsInfoService);
 
         logger.info("with one node, do reroute, only 5 should initialize");
         clusterState = strategy.reroute(clusterState, "reroute");
@@ -243,9 +249,10 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
 
     public void testOutgoingThrottlesAllocation() {
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
+        TestSnapshotsInfoService snapshotsInfoService = new TestSnapshotsInfoService();
         AllocationService strategy = createAllocationService(Settings.builder()
             .put("cluster.routing.allocation.node_concurrent_outgoing_recoveries", 1)
-            .build(), gatewayAllocator);
+            .build(), gatewayAllocator, snapshotsInfoService);
 
         logger.info("Building initial routing table");
 
@@ -253,7 +260,7 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
             .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2))
             .build();
 
-        ClusterState clusterState = createRecoveryStateAndInitalizeAllocations(metadata, gatewayAllocator);
+        ClusterState clusterState = createRecoveryStateAndInitializeAllocations(metadata, gatewayAllocator, snapshotsInfoService);
 
         logger.info("with one node, do reroute, only 1 should initialize");
         clusterState = strategy.reroute(clusterState, "reroute");
@@ -314,7 +321,7 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
                 assertEquals("reached the limit of outgoing shard recoveries [1] on the node [node1] which holds the primary, "
                         + "cluster setting [cluster.routing.allocation.node_concurrent_outgoing_recoveries=1] "
                         + "(can also be set via [cluster.routing.allocation.node_concurrent_recoveries])",
-                        decision.getExplanation());
+                    decision.getExplanation());
                 assertEquals(Decision.Type.THROTTLE, decision.type());
                 foundThrottledMessage = true;
             }
@@ -331,7 +338,11 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
         assertEquals(clusterState.getRoutingNodes().getOutgoingRecoveries("node2"), 0);
     }
 
-    private ClusterState createRecoveryStateAndInitalizeAllocations(Metadata metadata, TestGatewayAllocator gatewayAllocator) {
+    private ClusterState createRecoveryStateAndInitializeAllocations(
+        final Metadata metadata,
+        final TestGatewayAllocator gatewayAllocator,
+        final TestSnapshotsInfoService snapshotsInfoService
+        ) {
         DiscoveryNode node1 = newNode("node1");
         Metadata.Builder metadataBuilder = new Metadata.Builder(metadata);
         RoutingTable.Builder routingTableBuilder = RoutingTable.builder();
@@ -387,8 +398,12 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
             ImmutableOpenMap.Builder<ShardId, RestoreInProgress.ShardRestoreStatus> restoreShards = ImmutableOpenMap.builder();
             for (ShardRouting shard : routingTable.allShards()) {
                 if (shard.primary() && shard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
-                    ShardId shardId = shard.shardId();
+                    final ShardId shardId = shard.shardId();
                     restoreShards.put(shardId, new RestoreInProgress.ShardRestoreStatus(node1.getId(), RestoreInProgress.State.INIT));
+                    // Also set the snapshot shard size
+                    final SnapshotRecoverySource recoverySource = (SnapshotRecoverySource) shard.recoverySource();
+                    final long shardSize = randomNonNegativeLong();
+                    snapshotsInfoService.addSnapshotShardSize(recoverySource.snapshot(), recoverySource.index(), shardId, shardSize);
                 }
             }
 
@@ -421,4 +436,22 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
             gatewayAllocator.addKnownAllocation(started);
         }
     }
+
+    private static class TestSnapshotsInfoService implements SnapshotsInfoService {
+
+        private volatile ImmutableOpenMap<InternalSnapshotsInfoService.SnapshotShard, Long> snapshotShardSizes = ImmutableOpenMap.of();
+
+        synchronized void addSnapshotShardSize(Snapshot snapshot, IndexId index, ShardId shard, Long size) {
+            final ImmutableOpenMap.Builder<InternalSnapshotsInfoService.SnapshotShard, Long> newSnapshotShardSizes =
+                ImmutableOpenMap.builder(snapshotShardSizes);
+            boolean added = newSnapshotShardSizes.put(new InternalSnapshotsInfoService.SnapshotShard(snapshot, index, shard), size) == null;
+            assert added : "cannot add snapshot shard size twice";
+            this.snapshotShardSizes = newSnapshotShardSizes.build();
+        }
+
+        @Override
+        public SnapshotShardSizeInfo snapshotShardSizes() {
+            return new SnapshotShardSizeInfo(snapshotShardSizes);
+        }
+    }
 }
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/AllocationDecidersTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/AllocationDecidersTests.java
index 69b1e85b899..3816aef13a7 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/AllocationDecidersTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/AllocationDecidersTests.java
@@ -91,7 +91,7 @@ public class AllocationDecidersTests extends ESTestCase {
 
         ClusterState clusterState = ClusterState.builder(new ClusterName("test")).build();
         final RoutingAllocation allocation = new RoutingAllocation(deciders,
-            clusterState.getRoutingNodes(), clusterState, null, 0L);
+            clusterState.getRoutingNodes(), clusterState, null, null,0L);
 
         allocation.setDebugMode(mode);
         final UnassignedInfo unassignedInfo = new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "_message");
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java
index f49607aba16..3000336e88a 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java
@@ -51,6 +51,7 @@ import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 import java.util.Arrays;
@@ -109,7 +110,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
             return clusterInfo;
         };
         AllocationService strategy = new AllocationService(deciders,
-                new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
         Metadata metadata = Metadata.builder()
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(1))
@@ -188,7 +189,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                         makeDecider(diskSettings))));
 
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
         clusterState = strategy.reroute(clusterState, "reroute");
         logShardStates(clusterState);
@@ -216,7 +217,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                         makeDecider(diskSettings))));
 
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
         clusterState = strategy.reroute(clusterState, "reroute");
 
@@ -285,7 +286,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         };
 
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
         Metadata metadata = Metadata.builder()
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2))
@@ -331,7 +332,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
             return clusterInfo2;
         };
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
         clusterState = strategy.reroute(clusterState, "reroute");
         logShardStates(clusterState);
@@ -393,7 +394,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                         makeDecider(diskSettings))));
 
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
         clusterState = strategy.reroute(clusterState, "reroute");
         logShardStates(clusterState);
@@ -422,7 +423,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                         makeDecider(diskSettings))));
 
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
         clusterState = strategy.reroute(clusterState, "reroute");
 
@@ -517,7 +518,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         };
 
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
         Metadata metadata = Metadata.builder()
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
@@ -577,7 +578,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         };
 
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
         Metadata metadata = Metadata.builder()
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
@@ -671,7 +672,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         final ClusterInfoService cis = clusterInfoReference::get;
 
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-            new BalancedShardsAllocator(Settings.EMPTY), cis);
+            new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
         Metadata metadata = Metadata.builder()
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(1))
@@ -851,7 +852,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         );
         ClusterState clusterState = ClusterState.builder(baseClusterState).routingTable(builder.build()).build();
         RoutingAllocation routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo,
-                System.nanoTime());
+                null, System.nanoTime());
         routingAllocation.debugDecision(true);
         Decision decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
         assertThat(decision.type(), equalTo(Decision.Type.NO));
@@ -877,7 +878,8 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                         )
         );
         clusterState = ClusterState.builder(baseClusterState).routingTable(builder.build()).build();
-        routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo, System.nanoTime());
+        routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo, null,
+            System.nanoTime());
         routingAllocation.debugDecision(true);
         decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
         assertThat(decision.type(), equalTo(Decision.Type.YES));
@@ -907,7 +909,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                 diskThresholdDecider
         )));
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
         // Ensure that the reroute call doesn't alter the routing table, since the first primary is relocating away
         // and therefor we will have sufficient disk space on node1.
         ClusterState result = strategy.reroute(clusterState, "reroute");
@@ -979,7 +981,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         );
         ClusterState clusterState = ClusterState.builder(baseClusterState).routingTable(builder.build()).build();
         RoutingAllocation routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo,
-                System.nanoTime());
+                null, System.nanoTime());
         routingAllocation.debugDecision(true);
         Decision decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
 
@@ -999,7 +1001,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         )));
 
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
         ClusterState result = strategy.reroute(clusterState, "reroute");
 
         assertThat(result.routingTable().index("test").getShards().get(0).primaryShard().state(), equalTo(STARTED));
@@ -1032,7 +1034,8 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         );
 
         clusterState = ClusterState.builder(updateClusterState).routingTable(builder.build()).build();
-        routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo, System.nanoTime());
+        routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo, null,
+            System.nanoTime());
         routingAllocation.debugDecision(true);
         decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
         assertThat(decision.type(), equalTo(Decision.Type.YES));
@@ -1096,7 +1099,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
             diskThresholdDecider
         )));
         AllocationService strategy = new AllocationService(deciders,
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), cis);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
         ClusterState result = strategy.reroute(clusterState, "reroute");
 
         ShardRouting shardRouting = result.routingTable().index("test").getShards().get(0).primaryShard();
@@ -1117,7 +1120,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         clusterState = ClusterState.builder(clusterState).routingTable(forceAssignedRoutingTable).build();
 
         RoutingAllocation routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo,
-            System.nanoTime());
+            null, System.nanoTime());
         routingAllocation.debugDecision(true);
         Decision decision = diskThresholdDecider.canRemain(startedShard, clusterState.getRoutingNodes().node("data"), routingAllocation);
         assertThat(decision.type(), equalTo(Decision.Type.NO));
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java
index aad58f596fa..d0992a63f10 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java
@@ -106,7 +106,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         final ClusterInfo clusterInfo = new ClusterInfo(leastAvailableUsages.build(),
             mostAvailableUsage.build(), shardSizes.build(), ImmutableOpenMap.of(),  ImmutableOpenMap.of());
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
-            clusterState.getRoutingNodes(), clusterState, clusterInfo, System.nanoTime());
+            clusterState.getRoutingNodes(), clusterState, clusterInfo, null, System.nanoTime());
         allocation.debugDecision(true);
         Decision decision = decider.canAllocate(test_0, new RoutingNode("node_0", node_0), allocation);
         assertEquals(mostAvailableUsage.toString(), Decision.Type.YES, decision.type());
@@ -161,7 +161,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         ClusterInfo clusterInfo = new ClusterInfo(leastAvailableUsages.build(), mostAvailableUsage.build(),
             shardSizes.build(), ImmutableOpenMap.of(),  ImmutableOpenMap.of());
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
-            clusterState.getRoutingNodes(), clusterState, clusterInfo, System.nanoTime());
+            clusterState.getRoutingNodes(), clusterState, clusterInfo, null, System.nanoTime());
         allocation.debugDecision(true);
         Decision decision = decider.canAllocate(test_0, new RoutingNode("node_0", node_0), allocation);
         assertEquals(Decision.Type.NO, decision.type());
@@ -242,7 +242,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         final ClusterInfo clusterInfo = new ClusterInfo(leastAvailableUsages.build(), mostAvailableUsage.build(),
             shardSizes.build(), shardRoutingMap.build(), ImmutableOpenMap.of());
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
-            clusterState.getRoutingNodes(), clusterState, clusterInfo, System.nanoTime());
+            clusterState.getRoutingNodes(), clusterState, clusterInfo, null, System.nanoTime());
         allocation.debugDecision(true);
         Decision decision = decider.canRemain(test_0, new RoutingNode("node_0", node_0), allocation);
         assertEquals(Decision.Type.YES, decision.type());
@@ -296,7 +296,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         routingTableBuilder.addAsNew(metadata.index("other"));
         ClusterState clusterState = ClusterState.builder(org.elasticsearch.cluster.ClusterName.CLUSTER_NAME_SETTING
             .getDefault(Settings.EMPTY)).metadata(metadata).routingTable(routingTableBuilder.build()).build();
-        RoutingAllocation allocation = new RoutingAllocation(null, null, clusterState, info, 0);
+        RoutingAllocation allocation = new RoutingAllocation(null, null, clusterState, info, null, 0);
 
         final Index index = new Index("test", "1234");
         ShardRouting test_0 = ShardRouting.newUnassigned(new ShardId(index, 0), false, PeerRecoverySource.INSTANCE,
@@ -390,7 +390,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         clusterState = startShardsAndReroute(allocationService, clusterState,
             clusterState.getRoutingTable().index("test").shardsWithState(ShardRoutingState.UNASSIGNED));
 
-        RoutingAllocation allocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, info, 0);
+        RoutingAllocation allocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, info, null,0);
 
         final Index index = new Index("test", "1234");
         ShardRouting test_0 = ShardRouting.newUnassigned(new ShardId(index, 0), true,
@@ -435,14 +435,14 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
 
         allocationService.reroute(clusterState, "foo");
         RoutingAllocation allocationWithMissingSourceIndex = new RoutingAllocation(null,
-            clusterStateWithMissingSourceIndex.getRoutingNodes(), clusterStateWithMissingSourceIndex, info, 0);
+            clusterStateWithMissingSourceIndex.getRoutingNodes(), clusterStateWithMissingSourceIndex, info, null,0);
         assertEquals(42L, getExpectedShardSize(target, 42L, allocationWithMissingSourceIndex));
         assertEquals(42L, getExpectedShardSize(target2, 42L, allocationWithMissingSourceIndex));
     }
 
     private static long getExpectedShardSize(ShardRouting shardRouting, long defaultSize, RoutingAllocation allocation) {
         return DiskThresholdDecider.getExpectedShardSize(shardRouting, defaultSize,
-            allocation.clusterInfo(), allocation.metadata(), allocation.routingTable());
+            allocation.clusterInfo(), allocation.snapshotShardSizeInfo(), allocation.metadata(), allocation.routingTable());
     }
 
     public void testDiskUsageWithRelocations() {
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/EnableAllocationShortCircuitTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/EnableAllocationShortCircuitTests.java
index 9387ee40acf..6728a2e8892 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/EnableAllocationShortCircuitTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/EnableAllocationShortCircuitTests.java
@@ -37,6 +37,7 @@ import org.elasticsearch.cluster.routing.allocation.allocator.BalancedShardsAllo
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.plugins.ClusterPlugin;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 import java.util.ArrayList;
@@ -166,7 +167,8 @@ public class EnableAllocationShortCircuitTests extends ESAllocationTestCase {
                 Collections.singletonList(plugin)));
         return new MockAllocationService(
             new AllocationDeciders(deciders),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
     }
 
     private static class RebalanceShortCircuitPlugin implements ClusterPlugin {
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/FilterAllocationDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/FilterAllocationDeciderTests.java
index 6b867917a11..7bf67839247 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/FilterAllocationDeciderTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/FilterAllocationDeciderTests.java
@@ -37,6 +37,7 @@ import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.IndexScopedSettings;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 import java.util.Arrays;
@@ -58,7 +59,8 @@ public class FilterAllocationDeciderTests extends ESAllocationTestCase {
                 new SameShardAllocationDecider(Settings.EMPTY, clusterSettings),
                 new ReplicaAfterPrimaryActiveAllocationDecider()));
         AllocationService service = new AllocationService(allocationDeciders,
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
         ClusterState state = createInitialClusterState(service, Settings.builder().put("index.routing.allocation.initial_recovery._id",
             "node2").build());
         RoutingTable routingTable = state.routingTable();
@@ -73,7 +75,7 @@ public class FilterAllocationDeciderTests extends ESAllocationTestCase {
 
         // after failing the shard we are unassigned since the node is blacklisted and we can't initialize on the other node
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         Decision.Single decision = (Decision.Single) filterAllocationDecider.canAllocate(
             routingTable.index("idx").shard(0).primaryShard(),
@@ -124,7 +126,7 @@ public class FilterAllocationDeciderTests extends ESAllocationTestCase {
         assertEquals(routingTable.index("idx").shard(0).primaryShard().currentNodeId(), "node1");
 
         allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         decision = (Decision.Single) filterAllocationDecider.canAllocate(
             routingTable.index("idx").shard(0).shards().get(0),
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDeciderTests.java
index 559f2617f0f..8f8c053e943 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDeciderTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDeciderTests.java
@@ -191,7 +191,7 @@ public class RestoreInProgressAllocationDeciderTests extends ESAllocationTestCas
     private Decision executeAllocation(final ClusterState clusterState, final ShardRouting shardRouting) {
         final AllocationDecider decider = new RestoreInProgressAllocationDecider();
         final RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
-            clusterState.getRoutingNodes(), clusterState, null, 0L);
+            clusterState.getRoutingNodes(), clusterState, null, null, 0L);
         allocation.debugDecision(true);
 
         final Decision decision;
diff --git a/server/src/test/java/org/elasticsearch/gateway/GatewayServiceTests.java b/server/src/test/java/org/elasticsearch/gateway/GatewayServiceTests.java
index 436e0c85c96..1667d6b09b3 100644
--- a/server/src/test/java/org/elasticsearch/gateway/GatewayServiceTests.java
+++ b/server/src/test/java/org/elasticsearch/gateway/GatewayServiceTests.java
@@ -39,6 +39,7 @@ import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.hamcrest.Matchers;
@@ -60,7 +61,8 @@ public class GatewayServiceTests extends ESTestCase {
         final AllocationService allocationService = new AllocationService(new AllocationDeciders(new HashSet<>(
             Arrays.asList(new SameShardAllocationDecider(Settings.EMPTY, new ClusterSettings(Settings.EMPTY,
                 ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)), new ReplicaAfterPrimaryActiveAllocationDecider()))),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
         return new GatewayService(settings.build(), allocationService, clusterService, null, null, null);
     }
 
diff --git a/server/src/test/java/org/elasticsearch/gateway/PrimaryShardAllocatorTests.java b/server/src/test/java/org/elasticsearch/gateway/PrimaryShardAllocatorTests.java
index 2404facdda7..8555ec1ea14 100644
--- a/server/src/test/java/org/elasticsearch/gateway/PrimaryShardAllocatorTests.java
+++ b/server/src/test/java/org/elasticsearch/gateway/PrimaryShardAllocatorTests.java
@@ -44,6 +44,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.env.ShardLockObtainFailedException;
@@ -51,6 +52,7 @@ import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import org.junit.Before;
 
 import java.util.Arrays;
@@ -341,7 +343,7 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
      * deciders say yes, we allocate to that node.
      */
     public void testRestore() {
-        RoutingAllocation allocation = getRestoreRoutingAllocation(yesAllocationDeciders(), "allocId");
+        RoutingAllocation allocation = getRestoreRoutingAllocation(yesAllocationDeciders(), randomLong(), "allocId");
         testAllocator.addData(node1, "some allocId", randomBoolean());
         allocateAllUnassigned(allocation);
         assertThat(allocation.routingNodesChanged(), equalTo(true));
@@ -355,7 +357,7 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
      * deciders say throttle, we add it to ignored shards.
      */
     public void testRestoreThrottle() {
-        RoutingAllocation allocation = getRestoreRoutingAllocation(throttleAllocationDeciders(), "allocId");
+        RoutingAllocation allocation = getRestoreRoutingAllocation(throttleAllocationDeciders(), randomLong(), "allocId");
         testAllocator.addData(node1, "some allocId", randomBoolean());
         allocateAllUnassigned(allocation);
         assertThat(allocation.routingNodesChanged(), equalTo(true));
@@ -368,12 +370,15 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
      * deciders say no, we still allocate to that node.
      */
     public void testRestoreForcesAllocateIfShardAvailable() {
-        RoutingAllocation allocation = getRestoreRoutingAllocation(noAllocationDeciders(), "allocId");
+        final long shardSize = randomNonNegativeLong();
+        RoutingAllocation allocation = getRestoreRoutingAllocation(noAllocationDeciders(), shardSize, "allocId");
         testAllocator.addData(node1, "some allocId", randomBoolean());
         allocateAllUnassigned(allocation);
         assertThat(allocation.routingNodesChanged(), equalTo(true));
         assertThat(allocation.routingNodes().unassigned().ignored().isEmpty(), equalTo(true));
-        assertThat(allocation.routingNodes().shardsWithState(ShardRoutingState.INITIALIZING).size(), equalTo(1));
+        final List<ShardRouting> initializingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.INITIALIZING);
+        assertThat(initializingShards.size(), equalTo(1));
+        assertThat(initializingShards.get(0).getExpectedShardSize(), equalTo(shardSize));
         assertClusterHealthStatus(allocation, ClusterHealthStatus.YELLOW);
     }
 
@@ -382,8 +387,8 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
      * the unassigned list to be allocated later.
      */
     public void testRestoreDoesNotAssignIfNoShardAvailable() {
-        RoutingAllocation allocation = getRestoreRoutingAllocation(yesAllocationDeciders(), "allocId");
-        testAllocator.addData(node1, null, false);
+        RoutingAllocation allocation = getRestoreRoutingAllocation(yesAllocationDeciders(), randomNonNegativeLong(), "allocId");
+        testAllocator.addData(node1, null, randomBoolean());
         allocateAllUnassigned(allocation);
         assertThat(allocation.routingNodesChanged(), equalTo(false));
         assertThat(allocation.routingNodes().unassigned().ignored().isEmpty(), equalTo(true));
@@ -391,7 +396,22 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
         assertClusterHealthStatus(allocation, ClusterHealthStatus.YELLOW);
     }
 
-    private RoutingAllocation getRestoreRoutingAllocation(AllocationDeciders allocationDeciders, String... allocIds) {
+    /**
+     * Tests that when restoring from a snapshot and we don't know the shard size yet, the shard will remain in
+     * the unassigned list to be allocated later.
+     */
+    public void testRestoreDoesNotAssignIfShardSizeNotAvailable() {
+        RoutingAllocation allocation = getRestoreRoutingAllocation(yesAllocationDeciders(), null, "allocId");
+        testAllocator.addData(node1, null, false);
+        allocateAllUnassigned(allocation);
+        assertThat(allocation.routingNodesChanged(), equalTo(true));
+        assertThat(allocation.routingNodes().unassigned().ignored().isEmpty(), equalTo(false));
+        ShardRouting ignoredRouting = allocation.routingNodes().unassigned().ignored().get(0);
+        assertThat(ignoredRouting.unassignedInfo().getLastAllocationStatus(), equalTo(AllocationStatus.FETCHING_SHARD_DATA));
+        assertClusterHealthStatus(allocation, ClusterHealthStatus.YELLOW);
+    }
+
+    private RoutingAllocation getRestoreRoutingAllocation(AllocationDeciders allocationDeciders, Long shardSize, String... allocIds) {
         Metadata metadata = Metadata.builder()
             .put(IndexMetadata.builder(shardId.getIndexName()).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0)
                 .putInSyncAllocationIds(0, Sets.newHashSet(allocIds)))
@@ -407,7 +427,13 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
             .metadata(metadata)
             .routingTable(routingTable)
             .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
-        return new RoutingAllocation(allocationDeciders, new RoutingNodes(state, false), state, null, System.nanoTime());
+        return new RoutingAllocation(allocationDeciders, new RoutingNodes(state, false), state, null,
+            new SnapshotShardSizeInfo(ImmutableOpenMap.of()) {
+                @Override
+                public Long getShardSize(ShardRouting shardRouting) {
+                    return shardSize;
+                }
+            }, System.nanoTime());
     }
 
     private RoutingAllocation routingAllocationWithOnePrimaryNoReplicas(AllocationDeciders deciders, UnassignedInfo.Reason reason,
@@ -435,7 +461,7 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
                 .metadata(metadata)
                 .routingTable(routingTableBuilder.build())
                 .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
-        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, null, System.nanoTime());
+        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, null, null, System.nanoTime());
     }
 
     private void assertClusterHealthStatus(RoutingAllocation allocation, ClusterHealthStatus expectedStatus) {
diff --git a/server/src/test/java/org/elasticsearch/gateway/ReplicaShardAllocatorTests.java b/server/src/test/java/org/elasticsearch/gateway/ReplicaShardAllocatorTests.java
index 2432fb01673..a2b35f7a30a 100644
--- a/server/src/test/java/org/elasticsearch/gateway/ReplicaShardAllocatorTests.java
+++ b/server/src/test/java/org/elasticsearch/gateway/ReplicaShardAllocatorTests.java
@@ -54,6 +54,7 @@ import org.elasticsearch.index.store.Store;
 import org.elasticsearch.index.store.StoreFileMetadata;
 import org.elasticsearch.indices.store.TransportNodesListShardStoreMetadata;
 import org.elasticsearch.cluster.ESAllocationTestCase;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import org.junit.Before;
 
 import java.util.ArrayList;
@@ -473,7 +474,8 @@ public class ReplicaShardAllocatorTests extends ESAllocationTestCase {
                 .metadata(metadata)
                 .routingTable(routingTable)
                 .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
-        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, System.nanoTime());
+        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY,
+            System.nanoTime());
     }
 
     private RoutingAllocation onePrimaryOnNode1And1ReplicaRecovering(AllocationDeciders deciders, UnassignedInfo unassignedInfo) {
@@ -496,7 +498,8 @@ public class ReplicaShardAllocatorTests extends ESAllocationTestCase {
                 .metadata(metadata)
                 .routingTable(routingTable)
                 .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
-        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, System.nanoTime());
+        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY,
+            System.nanoTime());
     }
 
     private RoutingAllocation onePrimaryOnNode1And1ReplicaRecovering(AllocationDeciders deciders) {
diff --git a/server/src/test/java/org/elasticsearch/indices/cluster/ClusterStateChanges.java b/server/src/test/java/org/elasticsearch/indices/cluster/ClusterStateChanges.java
index f24d2fbfcbc..f9892ec35ee 100644
--- a/server/src/test/java/org/elasticsearch/indices/cluster/ClusterStateChanges.java
+++ b/server/src/test/java/org/elasticsearch/indices/cluster/ClusterStateChanges.java
@@ -91,6 +91,7 @@ import org.elasticsearch.index.shard.IndexEventListener;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.ShardLimitValidator;
 import org.elasticsearch.indices.SystemIndices;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.Transport;
@@ -146,7 +147,7 @@ public class ClusterStateChanges {
                 new ReplicaAfterPrimaryActiveAllocationDecider(),
                 new RandomAllocationDeciderTests.RandomAllocationDecider(getRandom())))),
             new TestGatewayAllocator(), new BalancedShardsAllocator(SETTINGS),
-            EmptyClusterInfoService.INSTANCE);
+            EmptyClusterInfoService.INSTANCE, EmptySnapshotsInfoService.INSTANCE);
         shardFailedClusterStateTaskExecutor
             = new ShardStateAction.ShardFailedClusterStateTaskExecutor(allocationService, null, () -> Priority.NORMAL, logger);
         shardStartedClusterStateTaskExecutor
diff --git a/server/src/test/java/org/elasticsearch/snapshots/InternalSnapshotsInfoServiceTests.java b/server/src/test/java/org/elasticsearch/snapshots/InternalSnapshotsInfoServiceTests.java
new file mode 100644
index 00000000000..268f3054693
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/snapshots/InternalSnapshotsInfoServiceTests.java
@@ -0,0 +1,350 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.snapshots;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.cluster.routing.IndexRoutingTable;
+import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
+import org.elasticsearch.cluster.routing.RecoverySource;
+import org.elasticsearch.cluster.routing.RerouteService;
+import org.elasticsearch.cluster.routing.RoutingTable;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.ShardRoutingState;
+import org.elasticsearch.cluster.routing.TestShardRouting;
+import org.elasticsearch.cluster.service.ClusterApplier;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Priority;
+import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.Index;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus;
+import org.elasticsearch.repositories.FilterRepository;
+import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.repositories.RepositoriesService;
+import org.elasticsearch.repositories.Repository;
+import org.elasticsearch.test.ClusterServiceUtils;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.threadpool.ThreadPoolStats;
+import org.junit.After;
+import org.junit.Before;
+
+import java.util.Collections;
+import java.util.Locale;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
+
+import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_CREATION_DATE;
+import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS;
+import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
+import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_VERSION_CREATED;
+import static org.elasticsearch.snapshots.InternalSnapshotsInfoService.INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING;
+import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.lessThan;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class InternalSnapshotsInfoServiceTests extends ESTestCase {
+
+    private TestThreadPool threadPool;
+    private ClusterService clusterService;
+    private RepositoriesService repositoriesService;
+    private RerouteService rerouteService;
+
+    @Before
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        threadPool = new TestThreadPool(getTestName());
+        clusterService = ClusterServiceUtils.createClusterService(threadPool);
+        repositoriesService = mock(RepositoriesService.class);
+        rerouteService = mock(RerouteService.class);
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            final ActionListener<ClusterState> listener = (ActionListener<ClusterState>) invocation.getArguments()[2];
+            listener.onResponse(clusterService.state());
+            return null;
+        }).when(rerouteService).reroute(anyString(), any(Priority.class), any());
+    }
+
+    @After
+    @Override
+    public void tearDown() throws Exception {
+        super.tearDown();
+        final boolean terminated = terminate(threadPool);
+        assert terminated;
+        clusterService.close();
+    }
+
+    public void testSnapshotShardSizes() throws Exception {
+        final int maxConcurrentFetches = randomIntBetween(1, 10);
+        final InternalSnapshotsInfoService snapshotsInfoService =
+            new InternalSnapshotsInfoService(Settings.builder()
+                .put(INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING.getKey(), maxConcurrentFetches)
+                .build(), clusterService, () -> repositoriesService, () -> rerouteService);
+
+        final int numberOfShards = randomIntBetween(1, 50);
+        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        final long[] expectedShardSizes = new long[numberOfShards];
+        for (int i = 0; i < expectedShardSizes.length; i++) {
+            expectedShardSizes[i] = randomNonNegativeLong();
+        }
+
+        final AtomicInteger getShardSnapshotStatusCount = new AtomicInteger(0);
+        final CountDownLatch latch = new CountDownLatch(1);
+        final Repository mockRepository = new FilterRepository(mock(Repository.class)) {
+            @Override
+            public IndexShardSnapshotStatus getShardSnapshotStatus(SnapshotId snapshotId, IndexId indexId, ShardId shardId) {
+                try {
+                    assertThat(indexId.getName(), equalTo(indexName));
+                    assertThat(shardId.id(), allOf(greaterThanOrEqualTo(0), lessThan(numberOfShards)));
+                    latch.await();
+                    getShardSnapshotStatusCount.incrementAndGet();
+                    return IndexShardSnapshotStatus.newDone(0L, 0L, 0, 0, 0L, expectedShardSizes[shardId.id()], null);
+                } catch (InterruptedException e) {
+                    throw new AssertionError(e);
+                }
+            }
+        };
+        when(repositoriesService.repository("_repo")).thenReturn(mockRepository);
+
+        applyClusterState("add-unassigned-shards", clusterState -> addUnassignedShards(clusterState, indexName, numberOfShards));
+        waitForMaxActiveGenericThreads(Math.min(numberOfShards, maxConcurrentFetches));
+
+        if (randomBoolean()) {
+            applyClusterState("reapply-last-cluster-state-to-check-deduplication-works",
+                state -> ClusterState.builder(state).incrementVersion().build());
+        }
+
+        assertThat(snapshotsInfoService.numberOfUnknownSnapshotShardSizes(), equalTo(numberOfShards));
+        assertThat(snapshotsInfoService.numberOfKnownSnapshotShardSizes(), equalTo(0));
+
+        latch.countDown();
+
+        assertBusy(() -> {
+            assertThat(snapshotsInfoService.numberOfKnownSnapshotShardSizes(), equalTo(numberOfShards));
+            assertThat(snapshotsInfoService.numberOfUnknownSnapshotShardSizes(), equalTo(0));
+            assertThat(snapshotsInfoService.numberOfFailedSnapshotShardSizes(), equalTo(0));
+        });
+        verify(rerouteService, times(numberOfShards)).reroute(anyString(), any(Priority.class), any());
+        assertThat(getShardSnapshotStatusCount.get(), equalTo(numberOfShards));
+
+        for (int i = 0; i < numberOfShards; i++) {
+            final ShardRouting shardRouting = clusterService.state().routingTable().index(indexName).shard(i).primaryShard();
+            assertThat(snapshotsInfoService.snapshotShardSizes().getShardSize(shardRouting), equalTo(expectedShardSizes[i]));
+        }
+    }
+
+    public void testErroneousSnapshotShardSizes() throws Exception {
+        final InternalSnapshotsInfoService snapshotsInfoService =
+            new InternalSnapshotsInfoService(Settings.builder()
+                .put(INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING.getKey(), randomIntBetween(1, 10))
+                .build(), clusterService, () -> repositoriesService, () -> rerouteService);
+
+        final Map<InternalSnapshotsInfoService.SnapshotShard, Boolean> results = new ConcurrentHashMap<>();
+        final Repository mockRepository = new FilterRepository(mock(Repository.class)) {
+            @Override
+            public IndexShardSnapshotStatus getShardSnapshotStatus(SnapshotId snapshotId, IndexId indexId, ShardId shardId) {
+                final InternalSnapshotsInfoService.SnapshotShard snapshotShard =
+                    new InternalSnapshotsInfoService.SnapshotShard(new Snapshot("_repo", snapshotId), indexId, shardId);
+                if (randomBoolean()) {
+                    results.put(snapshotShard, Boolean.FALSE);
+                    throw new SnapshotException(snapshotShard.snapshot(), "simulated");
+                } else {
+                    results.put(snapshotShard, Boolean.TRUE);
+                    return IndexShardSnapshotStatus.newDone(0L, 0L, 0, 0, 0L, randomNonNegativeLong(), null);
+                }
+            }
+        };
+        when(repositoriesService.repository("_repo")).thenReturn(mockRepository);
+
+        final int maxShardsToCreate = scaledRandomIntBetween(10, 500);
+        final Thread addSnapshotRestoreIndicesThread = new Thread(() -> {
+            int remainingShards = maxShardsToCreate;
+            while (remainingShards > 0) {
+                final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+                final int numberOfShards = randomIntBetween(1, remainingShards);
+                try {
+                    applyClusterState("add-more-unassigned-shards-for-" + indexName,
+                        clusterState -> addUnassignedShards(clusterState, indexName, numberOfShards));
+                } catch (Exception e) {
+                    throw new AssertionError(e);
+                } finally {
+                    remainingShards -= numberOfShards;
+                }
+            }
+        });
+        addSnapshotRestoreIndicesThread.start();
+        addSnapshotRestoreIndicesThread.join();
+
+        assertBusy(() -> {
+            assertThat(snapshotsInfoService.numberOfKnownSnapshotShardSizes(),
+                equalTo((int) results.values().stream().filter(result -> result.equals(Boolean.TRUE)).count()));
+            assertThat(snapshotsInfoService.numberOfFailedSnapshotShardSizes(),
+                equalTo((int) results.values().stream().filter(result -> result.equals(Boolean.FALSE)).count()));
+            assertThat(snapshotsInfoService.numberOfUnknownSnapshotShardSizes(), equalTo(0));
+        });
+    }
+
+    public void testNoLongerMaster() throws Exception {
+        final InternalSnapshotsInfoService snapshotsInfoService =
+            new InternalSnapshotsInfoService(Settings.EMPTY, clusterService, () -> repositoriesService, () -> rerouteService);
+
+        final Repository mockRepository = new FilterRepository(mock(Repository.class)) {
+            @Override
+            public IndexShardSnapshotStatus getShardSnapshotStatus(SnapshotId snapshotId, IndexId indexId, ShardId shardId) {
+                return IndexShardSnapshotStatus.newDone(0L, 0L, 0, 0, 0L, randomNonNegativeLong(), null);
+            }
+        };
+        when(repositoriesService.repository("_repo")).thenReturn(mockRepository);
+
+        for (int i = 0; i < randomIntBetween(1, 10); i++) {
+            final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+            final int nbShards =  randomIntBetween(1, 5);
+            applyClusterState("restore-indices-when-master-" + indexName,
+                clusterState -> addUnassignedShards(clusterState, indexName, nbShards));
+        }
+
+        applyClusterState("demote-current-master", this::demoteMasterNode);
+
+        for (int i = 0; i < randomIntBetween(1, 10); i++) {
+            final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+            final int nbShards =  randomIntBetween(1, 5);
+            applyClusterState("restore-indices-when-no-longer-master-" + indexName,
+                clusterState -> addUnassignedShards(clusterState, indexName, nbShards));
+        }
+
+        assertBusy(() -> {
+            assertThat(snapshotsInfoService.numberOfKnownSnapshotShardSizes(), equalTo(0));
+            assertThat(snapshotsInfoService.numberOfUnknownSnapshotShardSizes(), equalTo(0));
+            assertThat(snapshotsInfoService.numberOfFailedSnapshotShardSizes(), equalTo(0));
+        });
+    }
+
+    private void applyClusterState(final String reason, final Function<ClusterState, ClusterState> applier) {
+        PlainActionFuture.get(future -> clusterService.getClusterApplierService().onNewClusterState(reason,
+            () -> applier.apply(clusterService.state()),
+            new ClusterApplier.ClusterApplyListener() {
+                @Override
+                public void onSuccess(String source) {
+                    future.onResponse(source);
+                }
+
+                @Override
+                public void onFailure(String source, Exception e) {
+                    future.onFailure(e);
+                }
+            })
+        );
+    }
+
+    private void waitForMaxActiveGenericThreads(final int nbActive) throws Exception {
+        assertBusy(() -> {
+            final ThreadPoolStats threadPoolStats = clusterService.getClusterApplierService().threadPool().stats();
+            ThreadPoolStats.Stats generic = null;
+            for (ThreadPoolStats.Stats threadPoolStat : threadPoolStats) {
+                if (ThreadPool.Names.GENERIC.equals(threadPoolStat.getName())) {
+                    generic = threadPoolStat;
+                }
+            }
+            assertThat(generic, notNullValue());
+            assertThat(generic.getActive(), equalTo(nbActive));
+        }, 30L, TimeUnit.SECONDS);
+    }
+
+    private ClusterState addUnassignedShards(final ClusterState currentState, String indexName, int numberOfShards) {
+        assertThat(currentState.metadata().hasIndex(indexName), is(false));
+
+        final Metadata.Builder metadata = Metadata.builder(currentState.metadata())
+            .put(IndexMetadata.builder(indexName)
+                .settings(Settings.builder()
+                    .put(SETTING_VERSION_CREATED, Version.CURRENT)
+                    .put(SETTING_NUMBER_OF_SHARDS, numberOfShards)
+                    .put(SETTING_NUMBER_OF_REPLICAS, randomIntBetween(0, 1))
+                    .put(SETTING_CREATION_DATE, System.currentTimeMillis()))
+                .build(), true)
+            .generateClusterUuidIfNeeded();
+
+        final RecoverySource.SnapshotRecoverySource recoverySource = new RecoverySource.SnapshotRecoverySource(
+            UUIDs.randomBase64UUID(random()),
+            new Snapshot("_repo", new SnapshotId(randomAlphaOfLength(5), UUIDs.randomBase64UUID(random()))),
+            Version.CURRENT,
+            new IndexId(indexName, UUIDs.randomBase64UUID(random()))
+        );
+
+        final Index index = metadata.get(indexName).getIndex();
+        final IndexRoutingTable.Builder indexRoutingTable = IndexRoutingTable.builder(index);
+        for (int primary = 0; primary < numberOfShards; primary++) {
+            final ShardId shardId = new ShardId(index, primary);
+
+            final IndexShardRoutingTable.Builder indexShards = new IndexShardRoutingTable.Builder(shardId);
+            indexShards.addShard(TestShardRouting.newShardRouting(shardId, null, true, ShardRoutingState.UNASSIGNED, recoverySource));
+            for (int replica = 0; replica < metadata.get(indexName).getNumberOfReplicas(); replica++) {
+                indexShards.addShard(TestShardRouting.newShardRouting(shardId, null, false, ShardRoutingState.UNASSIGNED,
+                    RecoverySource.PeerRecoverySource.INSTANCE));
+            }
+            indexRoutingTable.addIndexShard(indexShards.build());
+        }
+
+        final RoutingTable.Builder routingTable = RoutingTable.builder(currentState.routingTable());
+        routingTable.add(indexRoutingTable.build());
+
+        return ClusterState.builder(currentState)
+            .routingTable(routingTable.build())
+            .metadata(metadata)
+            .build();
+    }
+
+    private ClusterState demoteMasterNode(final ClusterState currentState) {
+        final DiscoveryNode node = new DiscoveryNode("other", ESTestCase.buildNewFakeTransportAddress(), Collections.emptyMap(),
+            DiscoveryNodeRole.BUILT_IN_ROLES, Version.CURRENT);
+        assertThat(currentState.nodes().get(node.getId()), nullValue());
+
+        return ClusterState.builder(currentState)
+            .nodes(DiscoveryNodes.builder(currentState.nodes())
+                .add(node)
+                .masterNodeId(node.getId()))
+            .build();
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java
index 826ac539d70..5e2546c4c95 100644
--- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java
+++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java
@@ -65,7 +65,6 @@ import org.elasticsearch.action.admin.indices.mapping.put.TransportPutMappingAct
 import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresAction;
 import org.elasticsearch.action.admin.indices.shards.TransportIndicesShardStoresAction;
 import org.elasticsearch.action.bulk.BulkAction;
-import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.bulk.TransportBulkAction;
@@ -125,6 +124,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.routing.BatchedRerouteService;
+import org.elasticsearch.cluster.routing.RerouteService;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.cluster.routing.allocation.AllocationService;
@@ -152,6 +152,7 @@ import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.gateway.MetaStateService;
 import org.elasticsearch.gateway.TransportNodesListGatewayStartedShards;
 import org.elasticsearch.index.Index;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.analysis.AnalysisRegistry;
 import org.elasticsearch.index.seqno.GlobalCheckpointSyncAction;
 import org.elasticsearch.index.seqno.RetentionLeaseSyncer;
@@ -1390,6 +1391,8 @@ public class SnapshotResiliencyTests extends ESTestCase {
 
             private final AllocationService allocationService;
 
+            private final RerouteService rerouteService;
+
             private final NodeClient client;
 
             private final NodeEnvironment nodeEnv;
@@ -1490,7 +1493,12 @@ public class SnapshotResiliencyTests extends ESTestCase {
                 final NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(Collections.emptyList());
                 final ScriptService scriptService = new ScriptService(settings, emptyMap(), emptyMap());
                 client = new NodeClient(settings, threadPool);
-                allocationService = ESAllocationTestCase.createAllocationService(settings);
+                final SetOnce<RerouteService> rerouteServiceSetOnce = new SetOnce<>();
+                final SnapshotsInfoService snapshotsInfoService = new InternalSnapshotsInfoService(settings, clusterService,
+                    () -> repositoriesService, rerouteServiceSetOnce::get);
+                allocationService = ESAllocationTestCase.createAllocationService(settings, snapshotsInfoService);
+                rerouteService = new BatchedRerouteService(clusterService, allocationService::reroute);
+                rerouteServiceSetOnce.set(rerouteService);
                 final IndexScopedSettings indexScopedSettings =
                     new IndexScopedSettings(settings, IndexScopedSettings.BUILT_IN_INDEX_SETTINGS);
                 final BigArrays bigArrays = new BigArrays(new PageCacheRecycler(settings), null, "test");
@@ -1521,11 +1529,8 @@ public class SnapshotResiliencyTests extends ESTestCase {
                 final RecoverySettings recoverySettings = new RecoverySettings(settings, clusterSettings);
                 snapshotShardsService =
                         new SnapshotShardsService(settings, clusterService, repositoriesService, transportService, indicesService);
-                final ShardStateAction shardStateAction = new ShardStateAction(
-                    clusterService, transportService, allocationService,
-                    new BatchedRerouteService(clusterService, allocationService::reroute),
-                    threadPool
-                );
+                final ShardStateAction shardStateAction =
+                    new ShardStateAction(clusterService, transportService, allocationService, rerouteService, threadPool);
                 nodeConnectionsService =
                     new NodeConnectionsService(clusterService.getSettings(), threadPool, transportService);
                 final MetadataMappingService metadataMappingService = new MetadataMappingService(clusterService, indicesService);
@@ -1721,7 +1726,7 @@ public class SnapshotResiliencyTests extends ESTestCase {
                     hostsResolver -> nodes.values().stream().filter(n -> n.node.isMasterNode())
                         .map(n -> n.node.getAddress()).collect(Collectors.toList()),
                     clusterService.getClusterApplierService(), Collections.emptyList(), random(),
-                    new BatchedRerouteService(clusterService, allocationService::reroute), ElectionStrategy.DEFAULT_INSTANCE,
+                    rerouteService, ElectionStrategy.DEFAULT_INSTANCE,
                     () -> new StatusInfo(HEALTHY, "healthy-info"));
                 masterService.setClusterStatePublisher(coordinator);
                 coordinator.start();
diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java b/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java
index 4deffc749dd..30c05e75a04 100644
--- a/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java
+++ b/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java
@@ -22,6 +22,7 @@ package org.elasticsearch.cluster;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
@@ -34,9 +35,12 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.SameShardAllocationDecider;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.gateway.GatewayAllocator;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
@@ -56,6 +60,16 @@ public abstract class ESAllocationTestCase extends ESTestCase {
     private static final ClusterSettings EMPTY_CLUSTER_SETTINGS =
         new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
 
+    public static final SnapshotsInfoService SNAPSHOT_INFO_SERVICE_WITH_NO_SHARD_SIZES = () ->
+        new SnapshotShardSizeInfo(ImmutableOpenMap.of()) {
+            @Override
+            public Long getShardSize(ShardRouting shardRouting) {
+                assert shardRouting.recoverySource().getType() == RecoverySource.Type.SNAPSHOT :
+                    "Expecting a recovery source of type [SNAPSHOT] but got [" + shardRouting.recoverySource().getType() + ']';
+                throw new UnsupportedOperationException();
+            }
+    };
+
     public static MockAllocationService createAllocationService() {
         return createAllocationService(Settings.Builder.EMPTY_SETTINGS);
     }
@@ -71,19 +85,33 @@ public abstract class ESAllocationTestCase extends ESTestCase {
     public static MockAllocationService createAllocationService(Settings settings, ClusterSettings clusterSettings, Random random) {
         return new MockAllocationService(
                 randomAllocationDeciders(settings, clusterSettings, random),
-                new TestGatewayAllocator(), new BalancedShardsAllocator(settings), EmptyClusterInfoService.INSTANCE);
+                new TestGatewayAllocator(), new BalancedShardsAllocator(settings), EmptyClusterInfoService.INSTANCE,
+            SNAPSHOT_INFO_SERVICE_WITH_NO_SHARD_SIZES);
     }
 
     public static MockAllocationService createAllocationService(Settings settings, ClusterInfoService clusterInfoService) {
         return new MockAllocationService(
                 randomAllocationDeciders(settings, EMPTY_CLUSTER_SETTINGS, random()),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(settings), clusterInfoService);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(settings), clusterInfoService,
+            SNAPSHOT_INFO_SERVICE_WITH_NO_SHARD_SIZES);
     }
 
     public static MockAllocationService createAllocationService(Settings settings, GatewayAllocator gatewayAllocator) {
+        return createAllocationService(settings, gatewayAllocator, SNAPSHOT_INFO_SERVICE_WITH_NO_SHARD_SIZES);
+    }
+
+    public static MockAllocationService createAllocationService(Settings settings, SnapshotsInfoService snapshotsInfoService) {
+        return createAllocationService(settings, new TestGatewayAllocator(), snapshotsInfoService);
+    }
+
+    public static MockAllocationService createAllocationService(
+        Settings settings,
+        GatewayAllocator gatewayAllocator,
+        SnapshotsInfoService snapshotsInfoService
+    ) {
         return new MockAllocationService(
-                randomAllocationDeciders(settings, EMPTY_CLUSTER_SETTINGS, random()),
-                gatewayAllocator, new BalancedShardsAllocator(settings), EmptyClusterInfoService.INSTANCE);
+            randomAllocationDeciders(settings, EMPTY_CLUSTER_SETTINGS, random()),
+            gatewayAllocator, new BalancedShardsAllocator(settings), EmptyClusterInfoService.INSTANCE, snapshotsInfoService);
     }
 
     public static AllocationDeciders randomAllocationDeciders(Settings settings, ClusterSettings clusterSettings, Random random) {
@@ -231,8 +259,9 @@ public abstract class ESAllocationTestCase extends ESTestCase {
         private volatile long nanoTimeOverride = -1L;
 
         public MockAllocationService(AllocationDeciders allocationDeciders, GatewayAllocator gatewayAllocator,
-                                     ShardsAllocator shardsAllocator, ClusterInfoService clusterInfoService) {
-            super(allocationDeciders, gatewayAllocator, shardsAllocator, clusterInfoService);
+                                     ShardsAllocator shardsAllocator, ClusterInfoService clusterInfoService,
+                                     SnapshotsInfoService snapshotsInfoService) {
+            super(allocationDeciders, gatewayAllocator, shardsAllocator, clusterInfoService, snapshotsInfoService);
         }
 
         public void setNanoTimeOverride(long nanoTime) {
diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/MockInternalClusterInfoService.java b/test/framework/src/main/java/org/elasticsearch/cluster/MockInternalClusterInfoService.java
index 9ce0310edff..7e61a3d73ea 100644
--- a/test/framework/src/main/java/org/elasticsearch/cluster/MockInternalClusterInfoService.java
+++ b/test/framework/src/main/java/org/elasticsearch/cluster/MockInternalClusterInfoService.java
@@ -47,7 +47,8 @@ public class MockInternalClusterInfoService extends InternalClusterInfoService {
     @Nullable // if no fakery should take place
     private volatile BiFunction<DiscoveryNode, FsInfo.Path, FsInfo.Path> diskUsageFunction;
 
-    public MockInternalClusterInfoService(Settings settings, ClusterService clusterService, ThreadPool threadPool, NodeClient client) {
+    public MockInternalClusterInfoService(Settings settings, ClusterService clusterService,
+                                          ThreadPool threadPool, NodeClient client) {
         super(settings, clusterService, threadPool, client);
     }
 
diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java
index 363eb7acc8a..4d043c60195 100644
--- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java
+++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java
@@ -12,6 +12,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.index.IndexCommit;
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.Version;
@@ -19,6 +20,8 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
 import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
+import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse;
+import org.elasticsearch.action.admin.indices.stats.ShardStats;
 import org.elasticsearch.action.support.ListenerTimeouts;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.ThreadedActionListener;
@@ -31,6 +34,7 @@ import org.elasticsearch.cluster.metadata.MappingMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.metadata.RepositoryMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
@@ -438,8 +442,23 @@ public class CcrRepository extends AbstractLifecycleComponent implements Reposit
     }
 
     @Override
-    public IndexShardSnapshotStatus getShardSnapshotStatus(SnapshotId snapshotId, IndexId indexId, ShardId leaderShardId) {
-        throw new UnsupportedOperationException("Unsupported for repository of type: " + TYPE);
+    public IndexShardSnapshotStatus getShardSnapshotStatus(SnapshotId snapshotId, IndexId index, ShardId shardId) {
+        assert SNAPSHOT_ID.equals(snapshotId) : "RemoteClusterRepository only supports " + SNAPSHOT_ID + " as the SnapshotId";
+        final String leaderIndex = index.getName();
+        final IndicesStatsResponse response = getRemoteClusterClient().admin().indices().prepareStats(leaderIndex)
+            .clear().setStore(true)
+            .get(ccrSettings.getRecoveryActionTimeout());
+        for (ShardStats shardStats : response.getIndex(leaderIndex).getShards()) {
+            final ShardRouting shardRouting = shardStats.getShardRouting();
+            if (shardRouting.shardId().id() == shardId.getId()
+                && shardRouting.primary()
+                && shardRouting.active()) {
+                // we only care about the shard size here for shard allocation, populate the rest with dummy values
+                final long totalSize = shardStats.getStats().getStore().getSizeInBytes();
+                return IndexShardSnapshotStatus.newDone(0L, 0L, 1, 1, totalSize, totalSize, "");
+            }
+        }
+        throw new ElasticsearchException("Could not get shard stats for primary of index " + leaderIndex + " on leader cluster");
     }
 
     @Override
diff --git a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/allocation/CcrPrimaryFollowerAllocationDeciderTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/allocation/CcrPrimaryFollowerAllocationDeciderTests.java
index e585f9f44d2..9cd8d190a6f 100644
--- a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/allocation/CcrPrimaryFollowerAllocationDeciderTests.java
+++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/allocation/CcrPrimaryFollowerAllocationDeciderTests.java
@@ -52,6 +52,7 @@ import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import org.elasticsearch.test.VersionUtils;
 import org.elasticsearch.xpack.ccr.CcrSettings;
 
@@ -198,7 +199,7 @@ public class CcrPrimaryFollowerAllocationDeciderTests extends ESAllocationTestCa
     static Decision executeAllocation(ClusterState clusterState, ShardRouting shardRouting, DiscoveryNode node) {
         final AllocationDecider decider = new CcrPrimaryFollowerAllocationDecider();
         final RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(Collections.singletonList(decider)),
-            new RoutingNodes(clusterState), clusterState, ClusterInfo.EMPTY, System.nanoTime());
+            new RoutingNodes(clusterState), clusterState, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY, System.nanoTime());
         routingAllocation.debugDecision(true);
         return decider.canAllocate(shardRouting, new RoutingNode(node.getId(), node), routingAllocation);
     }
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/AllocationRoutedStep.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/AllocationRoutedStep.java
index 0f96faf43c0..461932dcacb 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/AllocationRoutedStep.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/AllocationRoutedStep.java
@@ -71,7 +71,7 @@ public class AllocationRoutedStep extends ClusterStateWaitStep {
         // All the allocation attributes are already set so just need to check
         // if the allocation has happened
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, clusterState.getRoutingNodes(), clusterState, null,
-            System.nanoTime());
+                null, System.nanoTime());
 
         int allocationPendingAllShards = 0;
 
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/SetSingleNodeAllocateStep.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/SetSingleNodeAllocateStep.java
index bc4bed96173..bb05c3707c7 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/SetSingleNodeAllocateStep.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/SetSingleNodeAllocateStep.java
@@ -65,7 +65,7 @@ public class SetSingleNodeAllocateStep extends AsyncActionStep {
     public void performAction(IndexMetadata indexMetadata, ClusterState clusterState, ClusterStateObserver observer, Listener listener) {
         final RoutingNodes routingNodes = clusterState.getRoutingNodes();
         RoutingAllocation allocation = new RoutingAllocation(ALLOCATION_DECIDERS, routingNodes, clusterState, null,
-                System.nanoTime());
+                null, System.nanoTime());
         List<String> validNodeIds = new ArrayList<>();
         String indexName = indexMetadata.getIndex().getName();
         final Map<ShardId, List<ShardRouting>> routingsByShardId = clusterState.getRoutingTable()
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderTests.java
index a7ea16023ba..8a9487b44db 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderTests.java
@@ -30,6 +30,7 @@ import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.xpack.core.DataTier;
 
@@ -58,7 +59,8 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
             new SameShardAllocationDecider(Settings.EMPTY, clusterSettings),
             new ReplicaAfterPrimaryActiveAllocationDecider()));
     private final AllocationService service = new AllocationService(allocationDeciders,
-        new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+        new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+        EmptySnapshotsInfoService.INSTANCE);
 
     private final ShardRouting shard = ShardRouting.newUnassigned(new ShardId("myindex", "myindex", 0), true,
         RecoverySource.EmptyStoreRecoverySource.INSTANCE, new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "index created"));
@@ -74,7 +76,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
     public void testClusterRequires() {
         ClusterState state = prepareState(service.reroute(ClusterState.EMPTY_STATE, "initial state"));
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         clusterSettings.applySettings(Settings.builder()
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_REQUIRE, "data_hot")
@@ -108,7 +110,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
     public void testClusterIncludes() {
         ClusterState state = prepareState(service.reroute(ClusterState.EMPTY_STATE, "initial state"));
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         clusterSettings.applySettings(Settings.builder()
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_INCLUDE, "data_warm,data_cold")
@@ -143,7 +145,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
     public void testClusterExcludes() {
         ClusterState state = prepareState(service.reroute(ClusterState.EMPTY_STATE, "initial state"));
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         clusterSettings.applySettings(Settings.builder()
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_EXCLUDE, "data_warm")
@@ -181,7 +183,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                 .put(DataTierAllocationDecider.INDEX_ROUTING_REQUIRE, "data_hot")
                 .build());
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         Decision d;
         RoutingNode node;
@@ -213,7 +215,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                 .put(DataTierAllocationDecider.INDEX_ROUTING_INCLUDE, "data_warm,data_cold")
                 .build());
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         Decision d;
         RoutingNode node;
@@ -247,7 +249,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                 .put(DataTierAllocationDecider.INDEX_ROUTING_EXCLUDE, "data_warm,data_cold")
                 .build());
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null,0);
         allocation.debugDecision(true);
         Decision d;
         RoutingNode node;
@@ -292,8 +294,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                         .build()))
                 .build())
             .build();
-        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, null, 0);
         allocation.debugDecision(true);
         Decision d;
         RoutingNode node;
@@ -328,7 +329,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                         .build()))
                 .build())
             .build();
-        allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, 0);
+        allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, null, 0);
         allocation.debugDecision(true);
 
         for (DiscoveryNode n : Arrays.asList(HOT_NODE, WARM_NODE)) {
@@ -376,8 +377,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                         .build()))
                 .build())
             .build();
-        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, null, 0);
         allocation.debugDecision(true);
         Decision d;
         RoutingNode node;
@@ -439,8 +439,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                         .build()))
                 .build())
             .build();
-        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, null, 0);
         allocation.debugDecision(true);
         Decision d;
         RoutingNode node;
@@ -502,8 +501,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                         .build()))
                 .build())
             .build();
-        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, null, 0);
         allocation.debugDecision(true);
         Decision d;
         RoutingNode node;
@@ -553,7 +551,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                 .put(DataTierAllocationDecider.INDEX_ROUTING_INCLUDE, "data_warm,data_cold")
                 .build());
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null,0);
         clusterSettings.applySettings(Settings.builder()
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_EXCLUDE, "data_cold")
             .build());
diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotAllocator.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotAllocator.java
index e5d4e9c5fea..c810eb975ed 100644
--- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotAllocator.java
+++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotAllocator.java
@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.searchablesnapshots;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.elasticsearch.cluster.routing.allocation.AllocationDecision;
 import org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator;
@@ -84,6 +85,11 @@ public class SearchableSnapshotAllocator implements ExistingShardsAllocator {
             allocation.metadata().getIndexSafe(shardRouting.index()).getSettings()
         ).equals(ALLOCATOR_NAME);
 
+        if (shardRouting.recoverySource().getType() == RecoverySource.Type.SNAPSHOT
+            && allocation.snapshotShardSizeInfo().getShardSize(shardRouting) == null) {
+            return AllocateUnassignedDecision.no(UnassignedInfo.AllocationStatus.FETCHING_SHARD_DATA, null);
+        }
+
         // let BalancedShardsAllocator take care of allocating this shard
         // TODO: once we have persistent cache, choose a node that has existing data
         return AllocateUnassignedDecision.NOT_TAKEN;