diff --git a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java index 4e269c85ee8..860cf46645f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java @@ -192,7 +192,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults new InternalAggregation.ReduceContextBuilder() { + @Override + public InternalAggregation.ReduceContext forPartialReduction() { + return InternalAggregation.ReduceContext.forPartialReduction( + BigArrays.NON_RECYCLING_INSTANCE, null, () -> PipelineAggregator.PipelineTree.EMPTY); + } + + public InternalAggregation.ReduceContext forFinalReduction() { + return InternalAggregation.ReduceContext.forFinalReduction( + BigArrays.NON_RECYCLING_INSTANCE, null, b -> {}, PipelineAggregator.PipelineTree.EMPTY); + }; + }); + threadPool = new TestThreadPool(SearchPhaseControllerTests.class.getName()); + executor = EsExecutors.newFixed("test", 1, 10, EsExecutors.daemonThreadFactory("test"), threadPool.getThreadContext()); + } + + @After + public void cleanup() { + executor.shutdownNow(); + terminate(threadPool); + } + + public void testProgressListenerExceptionsAreCaught() throws Exception { + + ThrowingSearchProgressListener searchProgressListener = new ThrowingSearchProgressListener(); + + List searchShards = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + searchShards.add(new SearchShard(null, new ShardId("index", "uuid", i))); + } + searchProgressListener.notifyListShards(searchShards, Collections.emptyList(), SearchResponse.Clusters.EMPTY, false); + + SearchRequest searchRequest = new SearchRequest("index"); + searchRequest.setBatchedReduceSize(2); + AtomicReference onPartialMergeFailure = new AtomicReference<>(); + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, searchPhaseController, + searchProgressListener, writableRegistry(), 10, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + curr.addSuppressed(prev); + return curr; + })); + + CountDownLatch partialReduceLatch = new CountDownLatch(10); + + for (int i = 0; i < 10; i++) { + SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), + null, OriginalIndices.NONE); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(i); + queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown); + } + + assertEquals(10, searchProgressListener.onQueryResult.get()); + assertTrue(partialReduceLatch.await(10, TimeUnit.SECONDS)); + assertNull(onPartialMergeFailure.get()); + assertEquals(8, searchProgressListener.onPartialReduce.get()); + + queryPhaseResultConsumer.reduce(); + assertEquals(1, searchProgressListener.onFinalReduce.get()); + } + + private static class ThrowingSearchProgressListener extends SearchProgressListener { + private final AtomicInteger onQueryResult = new AtomicInteger(0); + private final AtomicInteger onPartialReduce = new AtomicInteger(0); + private final AtomicInteger onFinalReduce = new AtomicInteger(0); + + @Override + protected void onListShards(List shards, List skippedShards, SearchResponse.Clusters clusters, + boolean fetchPhase) { + throw new UnsupportedOperationException(); + } + + @Override + protected void onQueryResult(int shardIndex) { + onQueryResult.incrementAndGet(); + throw new UnsupportedOperationException(); + } + + @Override + protected void onPartialReduce(List shards, TotalHits totalHits, + DelayableWriteable.Serialized aggs, int reducePhase) { + onPartialReduce.incrementAndGet(); + throw new UnsupportedOperationException(); + } + + @Override + protected void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + onFinalReduce.incrementAndGet(); + throw new UnsupportedOperationException(); + } + } +}