diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java index 95282e358e1..d3e8c069601 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java @@ -52,9 +52,13 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; +import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentSet; + public class SearchAsyncActionTests extends ESTestCase { public void testSkipSearchShards() throws InterruptedException { @@ -139,7 +143,7 @@ public class SearchAsyncActionTests extends ESTestCase { protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return new SearchPhase("test") { @Override - public void run() throws IOException { + public void run() { latch.countDown(); } }; @@ -260,7 +264,6 @@ public class SearchAsyncActionTests extends ESTestCase { SearchRequest request = new SearchRequest(); request.allowPartialSearchResults(true); request.setMaxConcurrentShardRequests(randomIntBetween(1, 100)); - CountDownLatch latch = new CountDownLatch(1); AtomicReference response = new AtomicReference<>(); ActionListener responseListener = new ActionListener() { @Override @@ -277,7 +280,7 @@ public class SearchAsyncActionTests extends ESTestCase { DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT); DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT); - Map> nodeToContextMap = new HashMap<>(); + Map> nodeToContextMap = newConcurrentMap(); AtomicInteger contextIdGenerator = new AtomicInteger(0); GroupShardsIterator shardsIter = getShardsIter("idx", new OriginalIndices(new String[]{"idx"}, IndicesOptions.strictExpandOpenAndForbidClosed()), @@ -296,6 +299,8 @@ public class SearchAsyncActionTests extends ESTestCase { lookup.put(replicaNode.getId(), new MockConnection(replicaNode)); Map aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)); final ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors())); + final CountDownLatch latch = new CountDownLatch(1); + final AtomicBoolean latchTriggered = new AtomicBoolean(); AbstractSearchAsyncAction asyncAction = new AbstractSearchAsyncAction( "test", @@ -326,7 +331,7 @@ public class SearchAsyncActionTests extends ESTestCase { Transport.Connection connection = getConnection(null, shard.currentNodeId()); TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(contextIdGenerator.incrementAndGet(), connection.getNode()); - Set ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> new HashSet<>()); + Set ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> newConcurrentSet()); ids.add(testSearchPhaseResult.getRequestId()); if (randomBoolean()) { listener.onResponse(testSearchPhaseResult); @@ -339,15 +344,15 @@ public class SearchAsyncActionTests extends ESTestCase { protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return new SearchPhase("test") { @Override - public void run() throws IOException { + public void run() { for (int i = 0; i < results.getNumShards(); i++) { TestSearchPhaseResult result = results.getAtomicArray().get(i); assertEquals(result.node.getId(), result.getSearchShardTarget().getNodeId()); sendReleaseSearchContext(result.getRequestId(), new MockConnection(result.node), OriginalIndices.NONE); } responseListener.onResponse(response); - if (latch.getCount() == 0) { - throw new AssertionError("Running a search phase after the latch has reached 0 !!!!"); + if (latchTriggered.compareAndSet(false, true) == false) { + throw new AssertionError("latch triggered twice"); } latch.countDown(); }