diff --git a/server/src/test/java/org/elasticsearch/snapshots/InternalSnapshotsInfoServiceTests.java b/server/src/test/java/org/elasticsearch/snapshots/InternalSnapshotsInfoServiceTests.java index 2a9bd2c106d..06dcf16391a 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/InternalSnapshotsInfoServiceTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/InternalSnapshotsInfoServiceTests.java @@ -20,7 +20,6 @@ package org.elasticsearch.snapshots; import org.elasticsearch.Version; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -38,7 +37,6 @@ import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.routing.TestShardRouting; import org.elasticsearch.cluster.service.ClusterApplier; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.Priority; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.Index; @@ -78,12 +76,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyString; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class InternalSnapshotsInfoServiceTests extends ESTestCase { @@ -100,13 +93,7 @@ public class InternalSnapshotsInfoServiceTests extends ESTestCase { threadPool = new TestThreadPool(getTestName()); clusterService = ClusterServiceUtils.createClusterService(threadPool); repositoriesService = mock(RepositoriesService.class); - rerouteService = mock(RerouteService.class); - doAnswer(invocation -> { - @SuppressWarnings("unchecked") - final ActionListener listener = (ActionListener) invocation.getArguments()[2]; - listener.onResponse(clusterService.state()); - return null; - }).when(rerouteService).reroute(anyString(), any(Priority.class), any()); + rerouteService = (reason, priority, listener) -> listener.onResponse(clusterService.state()); } @After @@ -120,12 +107,20 @@ public class InternalSnapshotsInfoServiceTests extends ESTestCase { public void testSnapshotShardSizes() throws Exception { final int maxConcurrentFetches = randomIntBetween(1, 10); + + final int numberOfShards = randomIntBetween(1, 50); + final CountDownLatch rerouteLatch = new CountDownLatch(numberOfShards); + final RerouteService rerouteService = (reason, priority, listener) -> { + listener.onResponse(clusterService.state()); + assertThat(rerouteLatch.getCount(), greaterThanOrEqualTo(0L)); + rerouteLatch.countDown(); + }; + final InternalSnapshotsInfoService snapshotsInfoService = new InternalSnapshotsInfoService(Settings.builder() .put(INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING.getKey(), maxConcurrentFetches) .build(), clusterService, () -> repositoriesService, () -> rerouteService); - final int numberOfShards = randomIntBetween(1, 50); final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); final long[] expectedShardSizes = new long[numberOfShards]; for (int i = 0; i < expectedShardSizes.length; i++) { @@ -163,12 +158,10 @@ public class InternalSnapshotsInfoServiceTests extends ESTestCase { latch.countDown(); - assertBusy(() -> { - assertThat(snapshotsInfoService.numberOfKnownSnapshotShardSizes(), equalTo(numberOfShards)); - assertThat(snapshotsInfoService.numberOfUnknownSnapshotShardSizes(), equalTo(0)); - assertThat(snapshotsInfoService.numberOfFailedSnapshotShardSizes(), equalTo(0)); - }); - verify(rerouteService, times(numberOfShards)).reroute(anyString(), any(Priority.class), any()); + assertTrue(rerouteLatch.await(30L, TimeUnit.SECONDS)); + assertThat(snapshotsInfoService.numberOfKnownSnapshotShardSizes(), equalTo(numberOfShards)); + assertThat(snapshotsInfoService.numberOfUnknownSnapshotShardSizes(), equalTo(0)); + assertThat(snapshotsInfoService.numberOfFailedSnapshotShardSizes(), equalTo(0)); assertThat(getShardSnapshotStatusCount.get(), equalTo(numberOfShards)); final SnapshotShardSizeInfo snapshotShardSizeInfo = snapshotsInfoService.snapshotShardSizes();