From 11445340936671a22749da44391ecc00418f55cb Mon Sep 17 00:00:00 2001
From: Jim Ferenczi
Date: Tue, 28 Jul 2020 13:40:47 +0200
Subject: [PATCH] Executes incremental reduce in the search thread pool
 (#58461) (#60275)

This change forks the execution of partial reduces in the coordinating
node to the search thread pool. It also ensures that partial reduces are
executed sequentially and asynchronously, both to limit the memory and
CPU that a single search request can use and to avoid blocking a network
thread. If a partial reduce fails with an exception, the search request
is cancelled and the reporting of the error is delayed to the start of
the fetch phase (when the final reduce is performed). This ensures that
we clean up the in-flight search requests before returning an error to
the user.

Closes #53411
Relates #51857
---
 .../action/RejectionActionIT.java             |  10 +-
 .../SearchProgressActionListenerIT.java       |   5 +-
 .../search/AbstractSearchAsyncAction.java     |   7 +-
 .../search/ArraySearchPhaseResults.java       |   4 +-
 .../search/CanMatchPreFilterSearchPhase.java  |   8 +-
 .../action/search/CountedCollector.java       |  12 +-
 .../action/search/DfsQueryPhase.java          |   8 +-
 .../action/search/FetchSearchPhase.java       |  11 +-
 .../search/QueryPhaseResultConsumer.java      | 418 ++++++++++++++++++
 .../SearchDfsQueryThenFetchAsyncAction.java   |   8 +-
 .../action/search/SearchPhaseController.java  | 385 ++++------------
 .../action/search/SearchPhaseResults.java     |   5 +-
 .../action/search/SearchProgressListener.java |  13 +-
 .../SearchQueryThenFetchAsyncAction.java      |   8 +-
 .../action/search/TransportSearchAction.java  |  28 +-
 .../search/query/QuerySearchResult.java       |  12 +
 .../search/suggest/SuggestPhase.java          |   4 +
 .../AbstractSearchAsyncActionTests.java       |   2 +-
 .../action/search/CountedCollectorTests.java  |  10 +-
 .../action/search/DfsQueryPhaseTests.java     |   6 +-
 .../action/search/FetchSearchPhaseTests.java  |  42 +-
 .../search/SearchPhaseControllerTests.java    | 320 +++++++++-----
 .../SearchQueryThenFetchAsyncActionTests.java |   8 +-
 .../snapshots/SnapshotResiliencyTests.java    |   2 +-
 24 files changed, 835 insertions(+), 501 deletions(-)
 create mode 100644 server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java
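A note before the diff (this sketch is not part of the patch): the new
QueryPhaseResultConsumer introduced below buffers shard-level results and,
when the buffer fills, snapshots it into a merge task on a queue; at most
one task runs on the executor at a time, and a finished task re-schedules
the next pending one. A minimal, self-contained sketch of that
sequential-asynchronous pattern, with hypothetical names and plain
integers standing in for shard results:

    import java.util.ArrayDeque;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    // Hypothetical sketch: buffer incoming results; when the buffer fills,
    // snapshot it into a queued batch and let the executor merge one batch
    // at a time, so partial reduces stay sequential but asynchronous.
    final class SequentialReducer {
        // even with multiple threads, the 'running' flag keeps merges sequential
        private final ExecutorService executor = Executors.newFixedThreadPool(2);
        private final ArrayDeque<List<Integer>> queue = new ArrayDeque<>();
        private final List<Integer> buffer = new ArrayList<>();
        private final int bufferSize = 5;
        private boolean running = false; // at most one in-flight merge
        private long reduced = 0;        // the running partial result

        synchronized void consume(int shardResult) {
            buffer.add(shardResult);
            if (buffer.size() >= bufferSize) {
                queue.add(new ArrayList<>(buffer)); // snapshot, like MergeTask
                buffer.clear();
                tryExecuteNext();
            }
        }

        private synchronized void tryExecuteNext() {
            if (running || queue.isEmpty()) {
                return; // a merge is in flight, or nothing is pending
            }
            running = true;
            List<Integer> batch = queue.poll();
            executor.execute(() -> {
                long partial = batch.stream().mapToLong(Integer::longValue).sum();
                synchronized (this) {
                    reduced += partial; // fold into the running partial result
                    running = false;
                }
                tryExecuteNext(); // pick up the next pending merge, if any
            });
        }

        synchronized long reducedSoFar() {
            return reduced;
        }
    }

The real consumer layers failure handling and cancellation on top of this
skeleton, as the new file further down shows. The diff follows:

diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/RejectionActionIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/RejectionActionIT.java
index e0ef29bf7f4..345c0155548 100644
--- a/server/src/internalClusterTest/java/org/elasticsearch/action/RejectionActionIT.java
+++ b/server/src/internalClusterTest/java/org/elasticsearch/action/RejectionActionIT.java
@@ -34,6 +34,8 @@ import java.util.Locale;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.CountDownLatch;
 
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 
 @ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 2)
@@ -85,8 +87,8 @@ public class RejectionActionIT extends ESIntegTestCase {
             if (response instanceof SearchResponse) {
                 SearchResponse searchResponse = (SearchResponse) response;
                 for (ShardSearchFailure failure : searchResponse.getShardFailures()) {
-                    assertTrue("got unexpected reason..."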
+ failure.reason(), - failure.reason().toLowerCase(Locale.ENGLISH).contains("rejected")); + assertThat(failure.reason().toLowerCase(Locale.ENGLISH), + anyOf(containsString("cancelled"), containsString("rejected"))); } } else { Exception t = (Exception) response; @@ -94,8 +96,8 @@ public class RejectionActionIT extends ESIntegTestCase { if (unwrap instanceof SearchPhaseExecutionException) { SearchPhaseExecutionException e = (SearchPhaseExecutionException) unwrap; for (ShardSearchFailure failure : e.shardFailures()) { - assertTrue("got unexpected reason..." + failure.reason(), - failure.reason().toLowerCase(Locale.ENGLISH).contains("rejected")); + assertThat(failure.reason().toLowerCase(Locale.ENGLISH), + anyOf(containsString("cancelled"), containsString("rejected"))); } } else if ((unwrap instanceof EsRejectedExecutionException) == false) { throw new AssertionError("unexpected failure", (Throwable) response); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java index 85d21c22ab7..de4d4e66f88 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java @@ -131,8 +131,8 @@ public class SearchProgressActionListenerIT extends ESSingleNodeTestCase { testCase((NodeClient) client(), request, sortShards, false); } - private static void testCase(NodeClient client, SearchRequest request, - List expectedShards, boolean hasFetchPhase) throws InterruptedException { + private void testCase(NodeClient client, SearchRequest request, + List expectedShards, boolean hasFetchPhase) throws InterruptedException { AtomicInteger numQueryResults = new AtomicInteger(); AtomicInteger numQueryFailures = new AtomicInteger(); AtomicInteger numFetchResults = new AtomicInteger(); @@ -204,7 +204,6 @@ public class SearchProgressActionListenerIT extends ESSingleNodeTestCase { } }, listener); latch.await(); - assertThat(shardsListener.get(), equalTo(expectedShards)); assertThat(numQueryResults.get(), equalTo(searchResponse.get().getSuccessfulShards())); assertThat(numQueryFailures.get(), equalTo(searchResponse.get().getFailedShards())); diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 66548adccb5..1c1b4ad9eee 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -470,12 +470,15 @@ abstract class AbstractSearchAsyncAction exten protected void onShardResult(Result result, SearchShardIterator shardIt) { assert result.getShardIndex() != -1 : "shard index is not set"; assert result.getSearchShardTarget() != null : "search shard target must not be null"; - successfulOps.incrementAndGet(); - results.consumeResult(result); hasShardResponse.set(true); if (logger.isTraceEnabled()) { logger.trace("got first-phase result from {}", result != null ? 
result.getSearchShardTarget() : null); } + results.consumeResult(result, () -> onShardResultConsumed(result, shardIt)); + } + + private void onShardResultConsumed(Result result, SearchShardIterator shardIt) { + successfulOps.incrementAndGet(); // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level // so its ok concurrency wise to miss potentially the shard failures being created because of another failure // in the #addShardFailure, because by definition, it will happen on *another* shardIndex diff --git a/server/src/main/java/org/elasticsearch/action/search/ArraySearchPhaseResults.java b/server/src/main/java/org/elasticsearch/action/search/ArraySearchPhaseResults.java index 4e84ef99f09..dc7cbbe0eff 100644 --- a/server/src/main/java/org/elasticsearch/action/search/ArraySearchPhaseResults.java +++ b/server/src/main/java/org/elasticsearch/action/search/ArraySearchPhaseResults.java @@ -39,9 +39,11 @@ class ArraySearchPhaseResults extends SearchPh return results.asList().stream(); } - void consumeResult(Result result) { + @Override + void consumeResult(Result result, Runnable next) { assert results.get(result.getShardIndex()) == null : "shardIndex: " + result.getShardIndex() + " is already set"; results.set(result.getShardIndex(), result); + next.run(); } boolean hasResult(int shardIndex) { diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index 337861d0dc5..f594ba80f8e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -159,8 +159,12 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction { - private final Consumer resultConsumer; + private final ArraySearchPhaseResults resultConsumer; private final CountDown counter; private final Runnable onFinish; private final SearchPhaseContext context; - CountedCollector(Consumer resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) { + CountedCollector(ArraySearchPhaseResults resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) { this.resultConsumer = resultConsumer; this.counter = new CountDown(expectedOps); this.onFinish = onFinish; @@ -58,11 +56,7 @@ final class CountedCollector { * Sets the result to the given array index and then runs {@link #countDown()} */ void onResult(R result) { - try { - resultConsumer.accept(result); - } finally { - countDown(); - } + resultConsumer.consumeResult(result, this::countDown); } /** diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index 8352469042a..a37fbb0e14d 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -30,6 +30,7 @@ import org.elasticsearch.transport.Transport; import java.io.IOException; import java.util.List; +import java.util.function.Consumer; import java.util.function.Function; /** @@ -51,10 +52,11 @@ final class DfsQueryPhase extends SearchPhase { DfsQueryPhase(AtomicArray dfsSearchResults, SearchPhaseController searchPhaseController, Function, SearchPhase> nextPhaseFactory, - SearchPhaseContext context) { + SearchPhaseContext context, Consumer 
onPartialMergeFailure) { super("dfs_query"); this.progressListener = context.getTask().getProgressListener(); - this.queryResult = searchPhaseController.newSearchPhaseResults(progressListener, context.getRequest(), context.getNumShards()); + this.queryResult = searchPhaseController.newSearchPhaseResults(context, progressListener, + context.getRequest(), context.getNumShards(), onPartialMergeFailure); this.searchPhaseController = searchPhaseController; this.dfsSearchResults = dfsSearchResults; this.nextPhaseFactory = nextPhaseFactory; @@ -68,7 +70,7 @@ final class DfsQueryPhase extends SearchPhase { // to free up memory early final List resultList = dfsSearchResults.asList(); final AggregatedDfs dfs = searchPhaseController.aggregateDfs(resultList); - final CountedCollector counter = new CountedCollector<>(queryResult::consumeResult, + final CountedCollector counter = new CountedCollector<>(queryResult, resultList.size(), () -> context.executeNextPhase(this, nextPhaseFactory.apply(queryResult)), context); for (final DfsSearchResult dfsResult : resultList) { diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 88cf1298700..4b2abcf271c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -36,7 +36,6 @@ import org.elasticsearch.search.internal.SearchContextId; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.transport.Transport; -import java.io.IOException; import java.util.List; import java.util.function.BiFunction; @@ -45,7 +44,7 @@ import java.util.function.BiFunction; * Then it reaches out to all relevant shards to fetch the topN hits. */ final class FetchSearchPhase extends SearchPhase { - private final AtomicArray fetchResults; + private final ArraySearchPhaseResults fetchResults; private final SearchPhaseController searchPhaseController; private final AtomicArray queryResults; private final BiFunction nextPhaseFactory; @@ -73,7 +72,7 @@ final class FetchSearchPhase extends SearchPhase { throw new IllegalStateException("number of shards must match the length of the query results but doesn't:" + context.getNumShards() + "!=" + resultConsumer.getNumShards()); } - this.fetchResults = new AtomicArray<>(resultConsumer.getNumShards()); + this.fetchResults = new ArraySearchPhaseResults<>(resultConsumer.getNumShards()); this.searchPhaseController = searchPhaseController; this.queryResults = resultConsumer.getAtomicArray(); this.nextPhaseFactory = nextPhaseFactory; @@ -102,7 +101,7 @@ final class FetchSearchPhase extends SearchPhase { }); } - private void innerRun() throws IOException { + private void innerRun() throws Exception { final int numShards = context.getNumShards(); final boolean isScrollSearch = context.getRequest().scroll() != null; final List phaseResults = queryResults.asList(); @@ -117,7 +116,7 @@ final class FetchSearchPhase extends SearchPhase { final boolean queryAndFetchOptimization = queryResults.length() == 1; final Runnable finishPhase = () -> moveToNextPhase(searchPhaseController, scrollId, reducedQueryPhase, queryAndFetchOptimization ? 
- queryResults : fetchResults); + queryResults : fetchResults.getAtomicArray()); if (queryAndFetchOptimization) { assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null : "phaseResults empty [" + phaseResults.isEmpty() + "], single result: " + phaseResults.get(0).fetchResult(); @@ -137,7 +136,7 @@ final class FetchSearchPhase extends SearchPhase { final ScoreDoc[] lastEmittedDocPerShard = isScrollSearch ? searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, numShards) : null; - final CountedCollector counter = new CountedCollector<>(r -> fetchResults.set(r.getShardIndex(), r), + final CountedCollector counter = new CountedCollector<>(fetchResults, docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not finishPhase, context); for (int i = 0; i < docIdsToLoad.length; i++) { diff --git a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java new file mode 100644 index 00000000000..b62b6968e37 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java @@ -0,0 +1,418 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.action.search; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.TopDocs; +import org.elasticsearch.action.search.SearchPhaseController.TopDocsStats; +import org.elasticsearch.common.io.stream.DelayableWriteable; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContextBuilder; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.query.QuerySearchResult; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.elasticsearch.action.search.SearchPhaseController.mergeTopDocs; +import static org.elasticsearch.action.search.SearchPhaseController.setShardIndex; + +/** + * A {@link ArraySearchPhaseResults} implementation that incrementally reduces aggregation results + * as shard results are consumed. 
+ * This implementation can be configured to batch up a certain amount of results and reduce + * them asynchronously in the provided {@link Executor} iff the buffer is exhausted. + */ +class QueryPhaseResultConsumer extends ArraySearchPhaseResults { + private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class); + + private final Executor executor; + private final SearchPhaseController controller; + private final SearchProgressListener progressListener; + private final ReduceContextBuilder aggReduceContextBuilder; + private final NamedWriteableRegistry namedWriteableRegistry; + + private final int topNSize; + private final boolean hasTopDocs; + private final boolean hasAggs; + private final boolean performFinalReduce; + + private final PendingMerges pendingMerges; + private final Consumer onPartialMergeFailure; + + private volatile long aggsMaxBufferSize; + private volatile long aggsCurrentBufferSize; + + /** + * Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results + * as shard results are consumed. + */ + QueryPhaseResultConsumer(Executor executor, + SearchPhaseController controller, + SearchProgressListener progressListener, + ReduceContextBuilder aggReduceContextBuilder, + NamedWriteableRegistry namedWriteableRegistry, + int expectedResultSize, + int bufferSize, + boolean hasTopDocs, + boolean hasAggs, + int trackTotalHitsUpTo, + int topNSize, + boolean performFinalReduce, + Consumer onPartialMergeFailure) { + super(expectedResultSize); + this.executor = executor; + this.controller = controller; + this.progressListener = progressListener; + this.aggReduceContextBuilder = aggReduceContextBuilder; + this.namedWriteableRegistry = namedWriteableRegistry; + this.topNSize = topNSize; + this.pendingMerges = new PendingMerges(bufferSize, trackTotalHitsUpTo); + this.hasTopDocs = hasTopDocs; + this.hasAggs = hasAggs; + this.performFinalReduce = performFinalReduce; + this.onPartialMergeFailure = onPartialMergeFailure; + } + + @Override + void consumeResult(SearchPhaseResult result, Runnable next) { + super.consumeResult(result, () -> {}); + QuerySearchResult querySearchResult = result.queryResult(); + progressListener.notifyQueryResult(querySearchResult.getShardIndex()); + pendingMerges.consume(querySearchResult, next); + } + + @Override + SearchPhaseController.ReducedQueryPhase reduce() throws Exception { + if (pendingMerges.hasPendingMerges()) { + throw new AssertionError("partial reduce in-flight"); + } else if (pendingMerges.hasFailure()) { + throw pendingMerges.getFailure(); + } + + logger.trace("aggs final reduction [{}] max [{}]", aggsCurrentBufferSize, aggsMaxBufferSize); + // ensure consistent ordering + pendingMerges.sortBuffer(); + final TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats(); + final List topDocsList = pendingMerges.consumeTopDocs(); + final List aggsList = pendingMerges.consumeAggs(); + SearchPhaseController.ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), aggsList, + topDocsList, topDocsStats, pendingMerges.numReducePhases, false, aggReduceContextBuilder, performFinalReduce); + progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()), + reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases); + return reducePhase; + } + + private MergeResult partialReduce(MergeTask task, + TopDocsStats topDocsStats, + MergeResult lastMerge, + int numReducePhases) { + final QuerySearchResult[] toConsume = 
task.consumeBuffer(); + if (toConsume == null) { + // the task is cancelled + return null; + } + // ensure consistent ordering + Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex)); + + for (QuerySearchResult result : toConsume) { + topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly()); + } + + final TopDocs newTopDocs; + if (hasTopDocs) { + List topDocsList = new ArrayList<>(); + if (lastMerge != null) { + topDocsList.add(lastMerge.reducedTopDocs); + } + for (QuerySearchResult result : toConsume) { + TopDocsAndMaxScore topDocs = result.consumeTopDocs(); + setShardIndex(topDocs.topDocs, result.getShardIndex()); + topDocsList.add(topDocs.topDocs); + } + newTopDocs = mergeTopDocs(topDocsList, + // we have to merge here in the same way we collect on a shard + topNSize, 0); + } else { + newTopDocs = null; + } + + final DelayableWriteable.Serialized newAggs; + if (hasAggs) { + List aggsList = new ArrayList<>(); + if (lastMerge != null) { + aggsList.add(lastMerge.reducedAggs.expand()); + } + for (QuerySearchResult result : toConsume) { + aggsList.add(result.consumeAggs().expand()); + } + InternalAggregations result = InternalAggregations.topLevelReduce(aggsList, + aggReduceContextBuilder.forPartialReduction()); + newAggs = DelayableWriteable.referencing(result).asSerialized(InternalAggregations::readFrom, namedWriteableRegistry); + long previousBufferSize = aggsCurrentBufferSize; + aggsCurrentBufferSize = newAggs.ramBytesUsed(); + aggsMaxBufferSize = Math.max(aggsCurrentBufferSize, aggsMaxBufferSize); + logger.trace("aggs partial reduction [{}->{}] max [{}]", + previousBufferSize, aggsCurrentBufferSize, aggsMaxBufferSize); + } else { + newAggs = null; + } + List processedShards = new ArrayList<>(task.emptyResults); + if (lastMerge != null) { + processedShards.addAll(lastMerge.processedShards); + } + for (QuerySearchResult result : toConsume) { + SearchShardTarget target = result.getSearchShardTarget(); + processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId())); + } + progressListener.onPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases); + return new MergeResult(processedShards, newTopDocs, newAggs); + } + + public int getNumReducePhases() { + return pendingMerges.numReducePhases; + } + + private class PendingMerges { + private final int bufferSize; + + private int index; + private final QuerySearchResult[] buffer; + private final List emptyResults = new ArrayList<>(); + + private final TopDocsStats topDocsStats; + private MergeResult mergeResult; + private final ArrayDeque queue = new ArrayDeque<>(); + private final AtomicReference runningTask = new AtomicReference<>(); + private final AtomicReference failure = new AtomicReference<>(); + + private boolean hasPartialReduce; + private int numReducePhases; + + PendingMerges(int bufferSize, int trackTotalHitsUpTo) { + this.bufferSize = bufferSize; + this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo); + this.buffer = new QuerySearchResult[bufferSize]; + } + + public boolean hasFailure() { + return failure.get() != null; + } + + public synchronized boolean hasPendingMerges() { + return queue.isEmpty() == false || runningTask.get() != null; + } + + public synchronized void sortBuffer() { + if (index > 0) { + Arrays.sort(buffer, 0, index, Comparator.comparingInt(QuerySearchResult::getShardIndex)); + } + } + + public void consume(QuerySearchResult result, Runnable next) { + boolean executeNextImmediately = true; + synchronized 
(this) {
+                if (hasFailure() || result.isNull()) {
+                    result.consumeAll();
+                    if (result.isNull()) {
+                        SearchShardTarget target = result.getSearchShardTarget();
+                        emptyResults.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
+                    }
+                } else {
+                    // add one if a partial merge is pending
+                    int size = index + (hasPartialReduce ? 1 : 0);
+                    if (size >= bufferSize) {
+                        hasPartialReduce = true;
+                        executeNextImmediately = false;
+                        QuerySearchResult[] clone = new QuerySearchResult[index];
+                        System.arraycopy(buffer, 0, clone, 0, index);
+                        MergeTask task = new MergeTask(clone, new ArrayList<>(emptyResults), next);
+                        Arrays.fill(buffer, null);
+                        emptyResults.clear();
+                        index = 0;
+                        queue.add(task);
+                        tryExecuteNext();
+                    }
+                    buffer[index++] = result;
+                }
+            }
+            if (executeNextImmediately) {
+                next.run();
+            }
+        }
+
+        private void onMergeFailure(Exception exc) {
+            synchronized (this) {
+                if (failure.get() != null) {
+                    return;
+                }
+                failure.compareAndSet(null, exc);
+                MergeTask task = runningTask.get();
+                if (task != null) {
+                    runningTask.compareAndSet(task, null);
+                    task.cancel();
+                }
+                queue.stream().forEach(MergeTask::cancel);
+                queue.clear();
+                mergeResult = null;
+            }
+            onPartialMergeFailure.accept(exc);
+        }
+
+        private void onAfterMerge(MergeTask task, MergeResult newResult) {
+            synchronized (this) {
+                runningTask.compareAndSet(task, null);
+                mergeResult = newResult;
+            }
+            task.consumeListener();
+        }
+
+        private void tryExecuteNext() {
+            final MergeTask task;
+            synchronized (this) {
+                if (queue.isEmpty()
+                        || failure.get() != null
+                        || runningTask.get() != null) {
+                    return;
+                }
+                task = queue.poll();
+                runningTask.compareAndSet(null, task);
+            }
+            executor.execute(new AbstractRunnable() {
+                @Override
+                protected void doRun() {
+                    final MergeResult newMerge;
+                    try {
+                        newMerge = partialReduce(task, topDocsStats, mergeResult, ++numReducePhases);
+                    } catch (Exception t) {
+                        onMergeFailure(t);
+                        return;
+                    }
+                    onAfterMerge(task, newMerge);
+                    tryExecuteNext();
+                }
+
+                @Override
+                public void onFailure(Exception exc) {
+                    onMergeFailure(exc);
+                }
+            });
+        }
+
+        public TopDocsStats consumeTopDocsStats() {
+            for (int i = 0; i < index; i++) {
+                QuerySearchResult result = buffer[i];
+                topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
+            }
+            return topDocsStats;
+        }
+
+        public List<TopDocs> consumeTopDocs() {
+            if (hasTopDocs == false) {
+                return Collections.emptyList();
+            }
+            List<TopDocs> topDocsList = new ArrayList<>();
+            if (mergeResult != null) {
+                topDocsList.add(mergeResult.reducedTopDocs);
+            }
+            for (int i = 0; i < index; i++) {
+                QuerySearchResult result = buffer[i];
+                TopDocsAndMaxScore topDocs = result.consumeTopDocs();
+                setShardIndex(topDocs.topDocs, result.getShardIndex());
+                topDocsList.add(topDocs.topDocs);
+            }
+            return topDocsList;
+        }
+
+        public List<InternalAggregations> consumeAggs() {
+            if (hasAggs == false) {
+                return Collections.emptyList();
+            }
+            List<InternalAggregations> aggsList = new ArrayList<>();
+            if (mergeResult != null) {
+                aggsList.add(mergeResult.reducedAggs.expand());
+            }
+            for (int i = 0; i < index; i++) {
+                QuerySearchResult result = buffer[i];
+                aggsList.add(result.consumeAggs().expand());
+            }
+            return aggsList;
+        }
+
+        public Exception getFailure() {
+            return failure.get();
+        }
+    }
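An aside on correctness (not part of the patch): partialReduce above feeds
the previous MergeResult's reducedTopDocs back into the next merge, so each
round merges "last partial result + newly buffered shard results". That is
lossless for top hits because merging top-N lists can be applied
incrementally. A small standalone Lucene sketch (hypothetical scores and
shard layout) of that property:

    import org.apache.lucene.search.ScoreDoc;
    import org.apache.lucene.search.TopDocs;
    import org.apache.lucene.search.TotalHits;

    // Hypothetical demo: a one-shot merge of three shards yields the same
    // top-N as first merging two shards (a "partial reduce") and then
    // folding the third shard into that partial result.
    final class TopDocsMergeDemo {
        static TopDocs shard(int shardIndex, float... scores) {
            ScoreDoc[] docs = new ScoreDoc[scores.length];
            for (int i = 0; i < scores.length; i++) {
                docs[i] = new ScoreDoc(i, scores[i], shardIndex);
            }
            return new TopDocs(new TotalHits(scores.length, TotalHits.Relation.EQUAL_TO), docs);
        }

        public static void main(String[] args) {
            TopDocs a = shard(0, 3.0f, 1.0f);
            TopDocs b = shard(1, 2.5f, 0.5f);
            TopDocs c = shard(2, 2.7f);
            TopDocs oneShot = TopDocs.merge(3, new TopDocs[] {a, b, c});
            TopDocs partial = TopDocs.merge(3, new TopDocs[] {a, b});
            TopDocs incremental = TopDocs.merge(3, new TopDocs[] {partial, c});
            for (int i = 0; i < oneShot.scoreDocs.length; i++) {
                assert oneShot.scoreDocs[i].score == incremental.scoreDocs[i].score;
            }
        }
    }

The new file resumes with the holder that carries a partial reduce from one
round to the next:

+
+    private static class MergeResult {
+        private final List<SearchShard> processedShards;
+        private final TopDocs reducedTopDocs;
+        private final DelayableWriteable.Serialized<InternalAggregations> reducedAggs;
+
+        private MergeResult(List<SearchShard> processedShards, TopDocs reducedTopDocs,
+                            DelayableWriteable.Serialized<InternalAggregations> reducedAggs) {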
this.processedShards = processedShards; + this.reducedTopDocs = reducedTopDocs; + this.reducedAggs = reducedAggs; + } + } + + private static class MergeTask { + private final List emptyResults; + private QuerySearchResult[] buffer; + private Runnable next; + + private MergeTask(QuerySearchResult[] buffer, List emptyResults, Runnable next) { + this.buffer = buffer; + this.emptyResults = emptyResults; + this.next = next; + } + + public synchronized QuerySearchResult[] consumeBuffer() { + QuerySearchResult[] toRet = buffer; + buffer = null; + return toRet; + } + + public synchronized void consumeListener() { + if (next != null) { + next.run(); + next = null; + } + } + + public synchronized void cancel() { + consumeBuffer(); + consumeListener(); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 0eecfce9e1e..6c55e6fe332 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -33,10 +33,12 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; import java.util.function.BiFunction; +import java.util.function.Consumer; final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction { private final SearchPhaseController searchPhaseController; + private final Consumer onPartialMergeFailure; SearchDfsQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService, final BiFunction nodeIdToConnection, @@ -46,12 +48,14 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction final SearchRequest request, final ActionListener listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, - final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters) { + final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters, + Consumer onPartialMergeFailure) { super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, executor, request, listener, shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()), request.getMaxConcurrentShardRequests(), clusters); this.searchPhaseController = searchPhaseController; + this.onPartialMergeFailure = onPartialMergeFailure; SearchProgressListener progressListener = task.getProgressListener(); SearchSourceBuilder sourceBuilder = request.source(); progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts), @@ -68,6 +72,6 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction @Override protected SearchPhase getNextPhase(final SearchPhaseResults results, final SearchPhaseContext context) { return new DfsQueryPhase(results.getAtomicArray(), searchPhaseController, (queryResults) -> - new FetchSearchPhase(queryResults, searchPhaseController, context, clusterState()), context); + new FetchSearchPhase(queryResults, searchPhaseController, context, clusterState()), context, onPartialMergeFailure); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index f4342592fed..e90939e6a0c 100644 --- 
a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -22,8 +22,6 @@ package org.elasticsearch.action.search; import com.carrotsearch.hppc.IntArrayList; import com.carrotsearch.hppc.ObjectObjectHashMap; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.index.Term; import org.apache.lucene.search.CollectionStatistics; import org.apache.lucene.search.FieldDoc; @@ -37,7 +35,6 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits.Relation; import org.apache.lucene.search.grouping.CollapseTopFieldDocs; import org.elasticsearch.common.collect.HppcMaps; -import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.search.DocValueFormat; @@ -45,7 +42,6 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; -import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext; import org.elasticsearch.search.aggregations.InternalAggregations; @@ -63,18 +59,18 @@ import org.elasticsearch.search.suggest.Suggest.Suggestion; import org.elasticsearch.search.suggest.completion.CompletionSuggestion; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.Executor; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.IntFunction; import java.util.stream.Collectors; public final class SearchPhaseController { - private static final Logger logger = LogManager.getLogger(SearchPhaseController.class); private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0]; private final NamedWriteableRegistry namedWriteableRegistry; @@ -151,77 +147,50 @@ public final class SearchPhaseController { * * @param ignoreFrom Whether to ignore the from and sort all hits in each shard result. * Enabled only for scroll search, because that only retrieves hits of length 'size' in the query phase. - * @param results the search phase results to obtain the sort docs from - * @param bufferedTopDocs the pre-consumed buffered top docs - * @param topDocsStats the top docs stats to fill + * @param topDocs the buffered top docs * @param from the offset into the search results top docs * @param size the number of hits to return from the merged top docs */ - static SortedTopDocs sortDocs(boolean ignoreFrom, Collection results, - final Collection bufferedTopDocs, final TopDocsStats topDocsStats, int from, int size, + static SortedTopDocs sortDocs(boolean ignoreFrom, final Collection topDocs, int from, int size, List reducedCompletionSuggestions) { - if (results.isEmpty()) { + if (topDocs.isEmpty() && reducedCompletionSuggestions.isEmpty()) { return SortedTopDocs.EMPTY; } - final Collection topDocs = bufferedTopDocs == null ? 
new ArrayList<>() : bufferedTopDocs; - for (SearchPhaseResult sortedResult : results) { // TODO we can move this loop into the reduce call to only loop over this once - /* We loop over all results once, group together the completion suggestions if there are any and collect relevant - * top docs results. Each top docs gets it's shard index set on all top docs to simplify top docs merging down the road - * this allowed to remove a single shared optimization code here since now we don't materialized a dense array of - * top docs anymore but instead only pass relevant results / top docs to the merge method*/ - QuerySearchResult queryResult = sortedResult.queryResult(); - if (queryResult.hasConsumedTopDocs() == false) { // already consumed? - final TopDocsAndMaxScore td = queryResult.consumeTopDocs(); - assert td != null; - topDocsStats.add(td, queryResult.searchTimedOut(), queryResult.terminatedEarly()); - // make sure we set the shard index before we add it - the consumer didn't do that yet - if (td.topDocs.scoreDocs.length > 0) { - setShardIndex(td.topDocs, queryResult.getShardIndex()); - topDocs.add(td.topDocs); + final TopDocs mergedTopDocs = mergeTopDocs(topDocs, size, ignoreFrom ? 0 : from); + final ScoreDoc[] mergedScoreDocs = mergedTopDocs == null ? EMPTY_DOCS : mergedTopDocs.scoreDocs; + ScoreDoc[] scoreDocs = mergedScoreDocs; + if (reducedCompletionSuggestions.isEmpty() == false) { + int numSuggestDocs = 0; + for (CompletionSuggestion completionSuggestion : reducedCompletionSuggestions) { + assert completionSuggestion != null; + numSuggestDocs += completionSuggestion.getOptions().size(); + } + scoreDocs = new ScoreDoc[mergedScoreDocs.length + numSuggestDocs]; + System.arraycopy(mergedScoreDocs, 0, scoreDocs, 0, mergedScoreDocs.length); + int offset = mergedScoreDocs.length; + for (CompletionSuggestion completionSuggestion : reducedCompletionSuggestions) { + for (CompletionSuggestion.Entry.Option option : completionSuggestion.getOptions()) { + scoreDocs[offset++] = option.getDoc(); } } } - final boolean hasHits = (reducedCompletionSuggestions.isEmpty() && topDocs.isEmpty()) == false; - if (hasHits) { - final TopDocs mergedTopDocs = mergeTopDocs(topDocs, size, ignoreFrom ? 0 : from); - final ScoreDoc[] mergedScoreDocs = mergedTopDocs == null ? 
EMPTY_DOCS : mergedTopDocs.scoreDocs; - ScoreDoc[] scoreDocs = mergedScoreDocs; - if (reducedCompletionSuggestions.isEmpty() == false) { - int numSuggestDocs = 0; - for (CompletionSuggestion completionSuggestion : reducedCompletionSuggestions) { - assert completionSuggestion != null; - numSuggestDocs += completionSuggestion.getOptions().size(); - } - scoreDocs = new ScoreDoc[mergedScoreDocs.length + numSuggestDocs]; - System.arraycopy(mergedScoreDocs, 0, scoreDocs, 0, mergedScoreDocs.length); - int offset = mergedScoreDocs.length; - for (CompletionSuggestion completionSuggestion : reducedCompletionSuggestions) { - for (CompletionSuggestion.Entry.Option option : completionSuggestion.getOptions()) { - scoreDocs[offset++] = option.getDoc(); - } - } + boolean isSortedByField = false; + SortField[] sortFields = null; + String collapseField = null; + Object[] collapseValues = null; + if (mergedTopDocs instanceof TopFieldDocs) { + TopFieldDocs fieldDocs = (TopFieldDocs) mergedTopDocs; + sortFields = fieldDocs.fields; + if (fieldDocs instanceof CollapseTopFieldDocs) { + isSortedByField = (fieldDocs.fields.length == 1 && fieldDocs.fields[0].getType() == SortField.Type.SCORE) == false; + CollapseTopFieldDocs collapseTopFieldDocs = (CollapseTopFieldDocs) fieldDocs; + collapseField = collapseTopFieldDocs.field; + collapseValues = collapseTopFieldDocs.collapseValues; + } else { + isSortedByField = true; } - boolean isSortedByField = false; - SortField[] sortFields = null; - String collapseField = null; - Object[] collapseValues = null; - if (mergedTopDocs instanceof TopFieldDocs) { - TopFieldDocs fieldDocs = (TopFieldDocs) mergedTopDocs; - sortFields = fieldDocs.fields; - if (fieldDocs instanceof CollapseTopFieldDocs) { - isSortedByField = (fieldDocs.fields.length == 1 && fieldDocs.fields[0].getType() == SortField.Type.SCORE) == false; - CollapseTopFieldDocs collapseTopFieldDocs = (CollapseTopFieldDocs) fieldDocs; - collapseField = collapseTopFieldDocs.field; - collapseValues = collapseTopFieldDocs.collapseValues; - } else { - isSortedByField = true; - } - } - return new SortedTopDocs(scoreDocs, isSortedByField, sortFields, collapseField, collapseValues); - } else { - // no relevant docs - return SortedTopDocs.EMPTY; } + return new SortedTopDocs(scoreDocs, isSortedByField, sortFields, collapseField, collapseValues); } static TopDocs mergeTopDocs(Collection results, int topN, int from) { @@ -251,7 +220,7 @@ public final class SearchPhaseController { return mergedTopDocs; } - private static void setShardIndex(TopDocs topDocs, int shardIndex) { + static void setShardIndex(TopDocs topDocs, int shardIndex) { assert topDocs.scoreDocs.length == 0 || topDocs.scoreDocs[0].shardIndex == -1 : "shardIndex is already set"; for (ScoreDoc doc : topDocs.scoreDocs) { doc.shardIndex = shardIndex; @@ -409,38 +378,38 @@ public final class SearchPhaseController { throw new UnsupportedOperationException("Scroll requests don't have aggs"); } }; - return reducedQueryPhase(queryResults, true, SearchContext.TRACK_TOTAL_HITS_ACCURATE, aggReduceContextBuilder, true); + final TopDocsStats topDocsStats = new TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE); + final List topDocs = new ArrayList<>(); + for (SearchPhaseResult sortedResult : queryResults) { + QuerySearchResult queryResult = sortedResult.queryResult(); + final TopDocsAndMaxScore td = queryResult.consumeTopDocs(); + assert td != null; + topDocsStats.add(td, queryResult.searchTimedOut(), queryResult.terminatedEarly()); + // make sure we set the shard index before we 
add it - the consumer didn't do that yet + if (td.topDocs.scoreDocs.length > 0) { + setShardIndex(td.topDocs, queryResult.getShardIndex()); + topDocs.add(td.topDocs); + } + } + return reducedQueryPhase(queryResults, Collections.emptyList(), topDocs, topDocsStats, + 0, true, aggReduceContextBuilder, true); } /** * Reduces the given query results and consumes all aggregations and profile results. * @param queryResults a list of non-null query shard results - */ - public ReducedQueryPhase reducedQueryPhase(Collection queryResults, - boolean isScrollRequest, int trackTotalHitsUpTo, - InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, - boolean performFinalReduce) { - return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(trackTotalHitsUpTo), - 0, isScrollRequest, aggReduceContextBuilder, performFinalReduce); - } - - /** - * Reduces the given query results and consumes all aggregations and profile results. - * @param queryResults a list of non-null query shard results - * @param bufferedAggs a list of pre-collected / buffered aggregations. if this list is non-null all aggregations have been consumed - * from all non-null query results. - * @param bufferedTopDocs a list of pre-collected / buffered top docs. if this list is non-null all top docs have been consumed - * from all non-null query results. + * @param bufferedAggs a list of pre-collected aggregations. + * @param bufferedTopDocs a list of pre-collected top docs. * @param numReducePhases the number of non-final reduce phases applied to the query results. * @see QuerySearchResult#consumeAggs() * @see QuerySearchResult#consumeProfileResult() */ - private ReducedQueryPhase reducedQueryPhase(Collection queryResults, - List> bufferedAggs, - List bufferedTopDocs, - TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest, - InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, - boolean performFinalReduce) { + ReducedQueryPhase reducedQueryPhase(Collection queryResults, + List bufferedAggs, + List bufferedTopDocs, + TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest, + InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, + boolean performFinalReduce) { assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases; numReducePhases++; // increment for this phase if (queryResults.isEmpty()) { // early terminate we have nothing to reduce @@ -460,22 +429,6 @@ public final class SearchPhaseController { final QuerySearchResult firstResult = queryResults.stream().findFirst().get().queryResult(); final boolean hasSuggest = firstResult.suggest() != null; final boolean hasProfileResults = firstResult.hasProfileResults(); - final boolean consumeAggs; - final List> aggregationsList; - if (bufferedAggs != null) { - consumeAggs = false; - // we already have results from intermediate reduces and just need to perform the final reduce - assert firstResult.hasAggs() : "firstResult has no aggs but we got non null buffered aggs?"; - aggregationsList = bufferedAggs; - } else if (firstResult.hasAggs()) { - // the number of shards was less than the buffer size so we reduce agg results directly - aggregationsList = new ArrayList<>(queryResults.size()); - consumeAggs = true; - } else { - // no aggregations - aggregationsList = Collections.emptyList(); - consumeAggs = false; - } // count the total (we use the query result provider here, since we might not get any hits (we scrolled past them)) final Map> groupedSuggestions = hasSuggest ? 
new HashMap<>() : Collections.emptyMap(); @@ -499,8 +452,8 @@ public final class SearchPhaseController { } } } - if (consumeAggs) { - aggregationsList.add(result.consumeAggs()); + if (bufferedTopDocs.isEmpty() == false) { + assert result.hasConsumedTopDocs() : "firstResult has no aggs but we got non null buffered aggs?"; } if (hasProfileResults) { String key = result.getSearchShardTarget().toString(); @@ -516,31 +469,19 @@ public final class SearchPhaseController { reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions)); reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class); } - final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, aggregationsList); + final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, bufferedAggs); final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults); - final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size, - reducedCompletionSuggestions); + final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, bufferedTopDocs, from, size, reducedCompletionSuggestions); final TotalHits totalHits = topDocsStats.getTotalHits(); return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.getMaxScore(), topDocsStats.timedOut, topDocsStats.terminatedEarly, reducedSuggest, aggregations, shardResults, sortedTopDocs, firstResult.sortValueFormats(), numReducePhases, size, from, false); } - private static InternalAggregations reduceAggs( - InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, - boolean performFinalReduce, - List> aggregationsList - ) { - /* - * Parse the aggregations, clearing the list as we go so bits backing - * the DelayedWriteable can be collected immediately. - */ - List toReduce = new ArrayList<>(aggregationsList.size()); - for (int i = 0; i < aggregationsList.size(); i++) { - toReduce.add(aggregationsList.get(i).expand()); - aggregationsList.set(i, null); - } - return aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(toReduce, + private static InternalAggregations reduceAggs(InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, + boolean performFinalReduce, + List toReduce) { + return toReduce.isEmpty() ? null : InternalAggregations.topLevelReduce(toReduce, performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction()); } @@ -617,199 +558,23 @@ public final class SearchPhaseController { } } - /** - * A {@link ArraySearchPhaseResults} implementation - * that incrementally reduces aggregation results as shard results are consumed. - * This implementation can be configured to batch up a certain amount of results and only reduce them - * iff the buffer is exhausted. 
- */ - static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults { - private final NamedWriteableRegistry namedWriteableRegistry; - private final SearchShardTarget[] processedShards; - private final DelayableWriteable.Serialized[] aggsBuffer; - private final TopDocs[] topDocsBuffer; - private final boolean hasAggs; - private final boolean hasTopDocs; - private final int bufferSize; - private int index; - private final SearchPhaseController controller; - private final SearchProgressListener progressListener; - private int numReducePhases = 0; - private final TopDocsStats topDocsStats; - private final int topNSize; - private final InternalAggregation.ReduceContextBuilder aggReduceContextBuilder; - private final boolean performFinalReduce; - private long aggsCurrentBufferSize; - private long aggsMaxBufferSize; - - /** - * Creates a new {@link QueryPhaseResultConsumer} - * @param progressListener a progress listener to be notified when a successful response is received - * and when a partial or final reduce has completed. - * @param controller a controller instance to reduce the query response objects - * @param expectedResultSize the expected number of query results. Corresponds to the number of shards queried - * @param bufferSize the size of the reduce buffer. if the buffer size is smaller than the number of expected results - * the buffer is used to incrementally reduce aggregation results before all shards responded. - */ - private QueryPhaseResultConsumer(NamedWriteableRegistry namedWriteableRegistry, SearchProgressListener progressListener, - SearchPhaseController controller, - int expectedResultSize, int bufferSize, boolean hasTopDocs, boolean hasAggs, - int trackTotalHitsUpTo, int topNSize, - InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, - boolean performFinalReduce) { - super(expectedResultSize); - this.namedWriteableRegistry = namedWriteableRegistry; - if (expectedResultSize != 1 && bufferSize < 2) { - throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result"); - } - if (expectedResultSize <= bufferSize) { - throw new IllegalArgumentException("buffer size must be less than the expected result size"); - } - if (hasAggs == false && hasTopDocs == false) { - throw new IllegalArgumentException("either aggs or top docs must be present"); - } - this.controller = controller; - this.progressListener = progressListener; - this.processedShards = new SearchShardTarget[expectedResultSize]; - // no need to buffer anything if we have less expected results. in this case we don't consume any results ahead of time. - @SuppressWarnings("unchecked") - DelayableWriteable.Serialized[] aggsBuffer = new DelayableWriteable.Serialized[hasAggs ? bufferSize : 0]; - this.aggsBuffer = aggsBuffer; - this.topDocsBuffer = new TopDocs[hasTopDocs ? 
bufferSize : 0]; - this.hasTopDocs = hasTopDocs; - this.hasAggs = hasAggs; - this.bufferSize = bufferSize; - this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo); - this.topNSize = topNSize; - this.aggReduceContextBuilder = aggReduceContextBuilder; - this.performFinalReduce = performFinalReduce; - } - - @Override - public void consumeResult(SearchPhaseResult result) { - super.consumeResult(result); - QuerySearchResult queryResult = result.queryResult(); - consumeInternal(queryResult); - progressListener.notifyQueryResult(queryResult.getShardIndex()); - } - - private synchronized void consumeInternal(QuerySearchResult querySearchResult) { - if (querySearchResult.isNull() == false) { - if (index == bufferSize) { - DelayableWriteable.Serialized reducedAggs = null; - if (hasAggs) { - List aggs = new ArrayList<>(aggsBuffer.length); - for (int i = 0; i < aggsBuffer.length; i++) { - aggs.add(aggsBuffer[i].expand()); - aggsBuffer[i] = null; // null the buffer so it can be GCed now. - } - InternalAggregations reduced = - InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forPartialReduction()); - reducedAggs = aggsBuffer[0] = DelayableWriteable.referencing(reduced) - .asSerialized(InternalAggregations::readFrom, namedWriteableRegistry); - long previousBufferSize = aggsCurrentBufferSize; - aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize); - aggsCurrentBufferSize = aggsBuffer[0].ramBytesUsed(); - logger.trace("aggs partial reduction [{}->{}] max [{}]", - previousBufferSize, aggsCurrentBufferSize, aggsMaxBufferSize); - } - if (hasTopDocs) { - TopDocs reducedTopDocs = mergeTopDocs(Arrays.asList(topDocsBuffer), - // we have to merge here in the same way we collect on a shard - topNSize, 0); - Arrays.fill(topDocsBuffer, null); - topDocsBuffer[0] = reducedTopDocs; - } - numReducePhases++; - index = 1; - if (hasAggs || hasTopDocs) { - progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards), - topDocsStats.getTotalHits(), reducedAggs, numReducePhases); - } - } - final int i = index++; - if (hasAggs) { - aggsBuffer[i] = querySearchResult.consumeAggs().asSerialized(InternalAggregations::readFrom, namedWriteableRegistry); - aggsCurrentBufferSize += aggsBuffer[i].ramBytesUsed(); - } - if (hasTopDocs) { - final TopDocsAndMaxScore topDocs = querySearchResult.consumeTopDocs(); // can't be null - topDocsStats.add(topDocs, querySearchResult.searchTimedOut(), querySearchResult.terminatedEarly()); - setShardIndex(topDocs.topDocs, querySearchResult.getShardIndex()); - topDocsBuffer[i] = topDocs.topDocs; - } - } - processedShards[querySearchResult.getShardIndex()] = querySearchResult.getSearchShardTarget(); - } - - private synchronized List> getRemainingAggs() { - return hasAggs ? Arrays.asList((DelayableWriteable[]) aggsBuffer).subList(0, index) : null; - } - - private synchronized List getRemainingTopDocs() { - return hasTopDocs ? 
Arrays.asList(topDocsBuffer).subList(0, index) : null; - } - - @Override - public ReducedQueryPhase reduce() { - aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize); - logger.trace("aggs final reduction [{}] max [{}]", aggsCurrentBufferSize, aggsMaxBufferSize); - ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), - getRemainingAggs(), getRemainingTopDocs(), topDocsStats, numReducePhases, false, - aggReduceContextBuilder, performFinalReduce); - progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()), - reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases); - return reducePhase; - } - - /** - * Returns the number of buffered results - */ - int getNumBuffered() { - return index; - } - - int getNumReducePhases() { return numReducePhases; } - } - /** * Returns a new ArraySearchPhaseResults instance. This might return an instance that reduces search responses incrementally. */ - ArraySearchPhaseResults newSearchPhaseResults(SearchProgressListener listener, + ArraySearchPhaseResults newSearchPhaseResults(Executor executor, + SearchProgressListener listener, SearchRequest request, - int numShards) { + int numShards, + Consumer onPartialMergeFailure) { SearchSourceBuilder source = request.source(); - boolean isScrollRequest = request.scroll() != null; final boolean hasAggs = source != null && source.aggregations() != null; final boolean hasTopDocs = source == null || source.size() != 0; final int trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo(); InternalAggregation.ReduceContextBuilder aggReduceContextBuilder = requestToAggReduceContextBuilder.apply(request); - if (isScrollRequest == false && (hasAggs || hasTopDocs)) { - // no incremental reduce if scroll is used - we only hit a single shard or sometimes more... - if (request.getBatchedReduceSize() < numShards) { - int topNSize = getTopDocsSize(request); - // only use this if there are aggs and if there are more shards than we should reduce at once - return new QueryPhaseResultConsumer(namedWriteableRegistry, listener, this, numShards, request.getBatchedReduceSize(), - hasTopDocs, hasAggs, trackTotalHitsUpTo, topNSize, aggReduceContextBuilder, request.isFinalReduce()); - } - } - return new ArraySearchPhaseResults(numShards) { - @Override - void consumeResult(SearchPhaseResult result) { - super.consumeResult(result); - listener.notifyQueryResult(result.queryResult().getShardIndex()); - } - - @Override - ReducedQueryPhase reduce() { - List resultList = results.asList(); - final ReducedQueryPhase reducePhase = - reducedQueryPhase(resultList, isScrollRequest, trackTotalHitsUpTo, aggReduceContextBuilder, request.isFinalReduce()); - listener.notifyFinalReduce(SearchProgressListener.buildSearchShards(resultList), - reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases); - return reducePhase; - } - }; + int topNSize = getTopDocsSize(request); + int bufferSize = (hasAggs || hasTopDocs) ? 
Math.min(request.getBatchedReduceSize(), numShards) : numShards; + return new QueryPhaseResultConsumer(executor, this, listener, aggReduceContextBuilder, namedWriteableRegistry, + numShards, bufferSize, hasTopDocs, hasAggs, trackTotalHitsUpTo, topNSize, request.isFinalReduce(), onPartialMergeFailure); } static final class TopDocsStats { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseResults.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseResults.java index e81cf4b74e2..ca9ccfead74 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseResults.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseResults.java @@ -48,8 +48,9 @@ abstract class SearchPhaseResults { /** * Consumes a single shard result * @param result the shards result + * @param next a {@link Runnable} that is executed when the response has been fully consumed */ - abstract void consumeResult(Result result); + abstract void consumeResult(Result result, Runnable next); /** * Returns true iff a result if present for the given shard ID. @@ -65,7 +66,7 @@ abstract class SearchPhaseResults { /** * Reduces the collected results */ - SearchPhaseController.ReducedQueryPhase reduce() { + SearchPhaseController.ReducedQueryPhase reduce() throws Exception { throw new UnsupportedOperationException("reduce is not supported"); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java index 1a3596bd725..8be49972faa 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java @@ -30,7 +30,6 @@ import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.InternalAggregations; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Objects; @@ -174,7 +173,7 @@ abstract class SearchProgressListener { } static List buildSearchShards(List results) { - List lst = results.stream() + List lst = results.stream() .filter(Objects::nonNull) .map(SearchPhaseResult::getSearchShardTarget) .map(e -> new SearchShard(e.getClusterAlias(), e.getShardId())) @@ -182,16 +181,8 @@ abstract class SearchProgressListener { return Collections.unmodifiableList(lst); } - static List buildSearchShards(SearchShardTarget[] results) { - List lst = Arrays.stream(results) - .filter(Objects::nonNull) - .map(e -> new SearchShard(e.getClusterAlias(), e.getShardId())) - .collect(Collectors.toList()); - return Collections.unmodifiableList(lst); - } - static List buildSearchShards(GroupShardsIterator its) { - List lst = StreamSupport.stream(its.spliterator(), false) + List lst = StreamSupport.stream(its.spliterator(), false) .map(e -> new SearchShard(e.getClusterAlias(), e.shardId())) .collect(Collectors.toList()); return Collections.unmodifiableList(lst); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index e8e864ddd1b..26f4afaa1dc 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -38,6 +38,7 @@ import java.util.Map; import 
diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index e8e864ddd1b..26f4afaa1dc 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -38,6 +38,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; import java.util.function.BiFunction; +import java.util.function.Consumer; import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize; @@ -59,11 +60,12 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPhaseResult> { SearchQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService, final BiFunction<String, String, Transport.Connection> nodeIdToConnection, final Map<String, AliasFilter> aliasFilter, final Map<String, Float> concreteIndexBoosts, final Map<String, Set<String>> indexRoutings, final SearchPhaseController searchPhaseController, final Executor executor, final SearchRequest request, final ActionListener<SearchResponse> listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, - ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters) { + ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters, + Consumer onPartialMergeFailure) { super("query", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, executor, request, listener, shardsIts, timeProvider, clusterState, task, - searchPhaseController.newSearchPhaseResults(task.getProgressListener(), request, shardsIts.size()), - request.getMaxConcurrentShardRequests(), clusters); + searchPhaseController.newSearchPhaseResults(executor, task.getProgressListener(), + request, shardsIts.size(), onPartialMergeFailure), request.getMaxConcurrentShardRequests(), clusters); this.topDocsSize = getTopDocsSize(request); this.trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo(); this.searchPhaseController = searchPhaseController; diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index a0b07ce4212..4ab0bab1187 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -21,6 +21,7 @@ package org.elasticsearch.action.search; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse; @@ -28,6 +29,8 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; +import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; @@ -57,6 +60,7 @@ import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.profile.ProfileShardResult; import org.elasticsearch.search.profile.SearchProfileShardResults; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.RemoteClusterAware; import org.elasticsearch.transport.RemoteClusterService; @@ -81,6 +85,7 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.LongSupplier; +import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN; import static org.elasticsearch.action.search.SearchType.DFS_QUERY_THEN_FETCH; import static org.elasticsearch.action.search.SearchType.QUERY_THEN_FETCH; import static 
org.elasticsearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort; @@ -91,6 +96,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest, SearchResponse> { public static final Setting<Long> SHARD_COUNT_LIMIT_SETTING = Setting.longSetting( "action.search.shard_count.limit", Long.MAX_VALUE, 1L, Property.Dynamic, Property.NodeScope); + private final NodeClient client; private final ThreadPool threadPool; private final ClusterService clusterService; private final SearchTransportService searchTransportService; @@ -100,11 +106,12 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest, SearchResponse> { @Inject - public TransportSearchAction(ThreadPool threadPool, TransportService transportService, SearchService searchService, + public TransportSearchAction(NodeClient client, ThreadPool threadPool, TransportService transportService, SearchService searchService, SearchTransportService searchTransportService, SearchPhaseController searchPhaseController, ClusterService clusterService, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver) { super(SearchAction.NAME, transportService, actionFilters, (Writeable.Reader<SearchRequest>) SearchRequest::new); + this.client = client; this.threadPool = threadPool; this.searchPhaseController = searchPhaseController; this.searchTransportService = searchTransportService; @@ -618,12 +625,12 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest, SearchResponse> { case DFS_QUERY_THEN_FETCH: searchAsyncAction = new SearchDfsQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup, aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, searchRequest, listener, - shardIterators, timeProvider, clusterState, task, clusters); + shardIterators, timeProvider, clusterState, task, clusters, exc -> cancelTask(task, exc)); break; case QUERY_THEN_FETCH: searchAsyncAction = new SearchQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup, aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, searchRequest, listener, - shardIterators, timeProvider, clusterState, task, clusters); + shardIterators, timeProvider, clusterState, task, clusters, exc -> cancelTask(task, exc)); break; default: throw new IllegalStateException("Unknown search type: [" + searchRequest.searchType() + "]"); @@ -632,6 +639,15 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest, SearchResponse> { return searchAsyncAction; } + private void cancelTask(SearchTask task, Exception exc) { + String errorMsg = exc.getMessage() != null ? exc.getMessage() : ""; + CancelTasksRequest req = new CancelTasksRequest() + .setTaskId(new TaskId(client.getLocalNodeId(), task.getId())) + .setReason("Fatal failure during search: " + errorMsg); + // force the origin to execute the cancellation as a system user + new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {})); + } + private static void failIfOverShardCountLimit(ClusterService clusterService, int shardCount) { final long shardCountLimit = clusterService.getClusterSettings().get(SHARD_COUNT_LIMIT_SETTING); if (shardCount > shardCountLimit) { diff --git a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java index 6309893c95e..2ab392c24eb 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java @@ -236,6 +236,18 @@ public final class QuerySearchResult extends SearchPhaseResult { return hasProfileResults; } + public void consumeAll() { + if (hasProfileResults()) { + consumeProfileResult(); + } + if (hasConsumedTopDocs() == false) { + consumeTopDocs(); + } + if (hasAggs()) { + consumeAggs(); + } + } + /** * Sets the finalized profiling results for this query * @param shardResults The finalized profile diff --git a/server/src/main/java/org/elasticsearch/search/suggest/SuggestPhase.java b/server/src/main/java/org/elasticsearch/search/suggest/SuggestPhase.java index 89b1f089581..58df2c190d5 100644 --- a/server/src/main/java/org/elasticsearch/search/suggest/SuggestPhase.java +++ b/server/src/main/java/org/elasticsearch/search/suggest/SuggestPhase.java @@ -65,5 +65,9 @@ public class SuggestPhase implements SearchPhase { throw new ElasticsearchException("I/O exception during suggest phase", e); } } + + static class SortedHits { + int[] docs; + } } diff --git a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java index b8957b7fef2..b3787bfe1b6 100644 --- a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java @@ -282,7 +282,7 @@ public class AbstractSearchAsyncActionTests extends ESTestCase { 
nodeLookups.add(Tuple.tuple(resultClusterAlias, resultNodeId)); phaseResult.setSearchShardTarget(new SearchShardTarget(resultNodeId, resultShardId, resultClusterAlias, OriginalIndices.NONE)); phaseResult.setShardIndex(i); - phaseResults.consumeResult(phaseResult); + phaseResults.consumeResult(phaseResult, () -> {}); } return phaseResults; } diff --git a/server/src/test/java/org/elasticsearch/action/search/CountedCollectorTests.java b/server/src/test/java/org/elasticsearch/action/search/CountedCollectorTests.java index 8d20289ca1a..836f65d4a71 100644 --- a/server/src/test/java/org/elasticsearch/action/search/CountedCollectorTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/CountedCollectorTests.java @@ -35,10 +35,10 @@ import java.util.concurrent.Executor; public class CountedCollectorTests extends ESTestCase { public void testCollect() throws InterruptedException { - AtomicArray results = new AtomicArray<>(randomIntBetween(1, 100)); + ArraySearchPhaseResults consumer = new ArraySearchPhaseResults<>(randomIntBetween(1, 100)); List state = new ArrayList<>(); - int numResultsExpected = randomIntBetween(1, results.length()); - MockSearchPhaseContext context = new MockSearchPhaseContext(results.length()); + int numResultsExpected = randomIntBetween(1, consumer.getAtomicArray().length()); + MockSearchPhaseContext context = new MockSearchPhaseContext(consumer.getAtomicArray().length()); CountDownLatch latch = new CountDownLatch(1); boolean maybeFork = randomBoolean(); Executor executor = (runnable) -> { @@ -49,7 +49,7 @@ public class CountedCollectorTests extends ESTestCase { runnable.run(); } }; - CountedCollector collector = new CountedCollector<>(r -> results.set(r.getShardIndex(), r), numResultsExpected, + CountedCollector collector = new CountedCollector<>(consumer, numResultsExpected, latch::countDown, context); for (int i = 0; i < numResultsExpected; i++) { int shardID = i; @@ -78,7 +78,7 @@ public class CountedCollectorTests extends ESTestCase { } latch.await(); assertEquals(numResultsExpected, state.size()); - + AtomicArray results = consumer.getAtomicArray(); for (int i = 0; i < numResultsExpected; i++) { switch (state.get(i)) { case 0: diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index aab6f0b32ba..dd7ca786c77 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -95,7 +95,7 @@ public class DfsQueryPhaseTests extends ESTestCase { public void run() throws IOException { responseRef.set(response.results); } - }, mockSearchPhaseContext); + }, mockSearchPhaseContext, exc -> {}); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -149,7 +149,7 @@ public class DfsQueryPhaseTests extends ESTestCase { public void run() throws IOException { responseRef.set(response.results); } - }, mockSearchPhaseContext); + }, mockSearchPhaseContext, exc -> {}); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -206,7 +206,7 @@ public class DfsQueryPhaseTests extends ESTestCase { public void run() throws IOException { responseRef.set(response.results); } - }, mockSearchPhaseContext); + }, mockSearchPhaseContext, exc -> {}); assertEquals("dfs_query", phase.getName()); expectThrows(UncheckedIOException.class, phase::run); 
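
Every DfsQueryPhase built in the tests above now takes a trailing Consumer for partial-merge failures, passed here as the no-op exc -> {}. A test that instead needs to assert the failure propagated could use a capturing consumer along these lines (a sketch; the variable names are illustrative, not from this patch):

    // Sketch: record the first partial-merge failure rather than ignoring it.
    AtomicReference<Exception> partialFailure = new AtomicReference<>(); // java.util.concurrent.atomic
    Consumer<Exception> onPartialMergeFailure =
        exc -> partialFailure.compareAndSet(null, exc); // keep only the first failure
    // later: assertNotNull(partialFailure.get()) once a partial reduce has failed
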
assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); // phase execution will clean up on the contexts diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java index 0cb81b91266..2efe8197403 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchHit; @@ -52,7 +53,8 @@ public class FetchSearchPhaseTests extends ESTestCase { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); - ArraySearchPhaseResults results = controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 1); + ArraySearchPhaseResults results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + NOOP, mockSearchPhaseContext.getRequest(), 1, exc -> {}); boolean hasHits = randomBoolean(); final int numHits; if (hasHits) { @@ -66,7 +68,7 @@ public class FetchSearchPhaseTests extends ESTestCase { fetchResult.hits(new SearchHits(new SearchHit[] {new SearchHit(42)}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F)); QueryFetchSearchResult fetchSearchResult = new QueryFetchSearchResult(queryResult, fetchResult); fetchSearchResult.setShardIndex(0); - results.consumeResult(fetchSearchResult); + results.consumeResult(fetchSearchResult, () -> {}); numHits = 1; } else { numHits = 0; @@ -95,7 +97,8 @@ public class FetchSearchPhaseTests extends ESTestCase { MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); - ArraySearchPhaseResults results = controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2); + ArraySearchPhaseResults results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = randomIntBetween(2, 10); final SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult(ctx1, @@ -104,7 +107,7 @@ public class FetchSearchPhaseTests extends ESTestCase { new ScoreDoc[] {new ScoreDoc(42, 1.0F)}), 2.0F), new DocValueFormat[0]); queryResult.size(resultSetSize); // the size of the result set queryResult.setShardIndex(0); - results.consumeResult(queryResult); + results.consumeResult(queryResult, () -> {}); final SearchContextId ctx2 = new SearchContextId(UUIDs.randomBase64UUID(), 312); queryResult = new QuerySearchResult(ctx2, @@ -113,7 +116,7 @@ public class FetchSearchPhaseTests extends ESTestCase { new ScoreDoc[] {new ScoreDoc(84, 2.0F)}), 2.0F), new DocValueFormat[0]); queryResult.size(resultSetSize); queryResult.setShardIndex(1); - results.consumeResult(queryResult); + results.consumeResult(queryResult, () -> {}); 
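
These fetch-phase tests all construct their consumer the same way; because EsExecutors.newDirectExecutorService() runs every partial reduce inline on the calling thread, the empty () -> {} completion callbacks need no latch. A condensed sketch of that construction, reusing the names from the hunks above (the inline comments are editorial, not part of the patch):

    // Direct executor: partial reduces execute synchronously on the caller's thread.
    ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(
        EsExecutors.newDirectExecutorService(), // no forking to a search pool
        NOOP,                                   // progress reporting not under test
        mockSearchPhaseContext.getRequest(),
        2,                                      // number of shards
        exc -> {});                             // partial-merge failures ignored here
    results.consumeResult(queryResult, () -> {}); // callback has already run when this returns
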
mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) { @Override @@ -155,8 +158,8 @@ public class FetchSearchPhaseTests extends ESTestCase { MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); - ArraySearchPhaseResults results = - controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2); + ArraySearchPhaseResults results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = randomIntBetween(2, 10); SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult(ctx1, @@ -165,7 +168,7 @@ public class FetchSearchPhaseTests extends ESTestCase { new ScoreDoc[] {new ScoreDoc(42, 1.0F)}), 2.0F), new DocValueFormat[0]); queryResult.size(resultSetSize); // the size of the result set queryResult.setShardIndex(0); - results.consumeResult(queryResult); + results.consumeResult(queryResult, () -> {}); SearchContextId ctx2 = new SearchContextId(UUIDs.randomBase64UUID(), 321); queryResult = new QuerySearchResult(ctx2, @@ -174,7 +177,7 @@ public class FetchSearchPhaseTests extends ESTestCase { new ScoreDoc[] {new ScoreDoc(84, 2.0F)}), 2.0F), new DocValueFormat[0]); queryResult.size(resultSetSize); queryResult.setShardIndex(1); - results.consumeResult(queryResult); + results.consumeResult(queryResult, () -> {}); mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) { @Override @@ -220,8 +223,8 @@ public class FetchSearchPhaseTests extends ESTestCase { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(numHits); - ArraySearchPhaseResults results = controller.newSearchPhaseResults(NOOP, - mockSearchPhaseContext.getRequest(), numHits); + ArraySearchPhaseResults results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), NOOP, + mockSearchPhaseContext.getRequest(), numHits, exc -> {}); for (int i = 0; i < numHits; i++) { QuerySearchResult queryResult = new QuerySearchResult(new SearchContextId("", i), new SearchShardTarget("node1", new ShardId("test", "na", 0), null, OriginalIndices.NONE)); @@ -229,7 +232,7 @@ public class FetchSearchPhaseTests extends ESTestCase { new ScoreDoc[] {new ScoreDoc(i+1, i)}), i), new DocValueFormat[0]); queryResult.size(resultSetSize); // the size of the result set queryResult.setShardIndex(i); - results.consumeResult(queryResult); + results.consumeResult(queryResult, () -> {}); } mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) { @Override @@ -278,7 +281,8 @@ public class FetchSearchPhaseTests extends ESTestCase { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); ArraySearchPhaseResults results = - controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2); + controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = randomIntBetween(2, 10); QuerySearchResult queryResult = new QuerySearchResult(new SearchContextId("", 123), new 
SearchShardTarget("node1", new ShardId("test", "na", 0), null, OriginalIndices.NONE)); @@ -286,7 +290,7 @@ public class FetchSearchPhaseTests extends ESTestCase { new ScoreDoc[] {new ScoreDoc(42, 1.0F)}), 2.0F), new DocValueFormat[0]); queryResult.size(resultSetSize); // the size of the result set queryResult.setShardIndex(0); - results.consumeResult(queryResult); + results.consumeResult(queryResult, () -> {}); queryResult = new QuerySearchResult(new SearchContextId("", 321), new SearchShardTarget("node2", new ShardId("test", "na", 1), null, OriginalIndices.NONE)); @@ -294,7 +298,7 @@ public class FetchSearchPhaseTests extends ESTestCase { new ScoreDoc[] {new ScoreDoc(84, 2.0F)}), 2.0F), new DocValueFormat[0]); queryResult.size(resultSetSize); queryResult.setShardIndex(1); - results.consumeResult(queryResult); + results.consumeResult(queryResult, () -> {}); AtomicInteger numFetches = new AtomicInteger(0); mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) { @Override @@ -334,8 +338,8 @@ public class FetchSearchPhaseTests extends ESTestCase { MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); - ArraySearchPhaseResults results = - controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2); + ArraySearchPhaseResults results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = 1; SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult(ctx1, @@ -344,7 +348,7 @@ public class FetchSearchPhaseTests extends ESTestCase { new ScoreDoc[] {new ScoreDoc(42, 1.0F)}), 2.0F), new DocValueFormat[0]); queryResult.size(resultSetSize); // the size of the result set queryResult.setShardIndex(0); - results.consumeResult(queryResult); + results.consumeResult(queryResult, () -> {}); SearchContextId ctx2 = new SearchContextId(UUIDs.randomBase64UUID(), 321); queryResult = new QuerySearchResult(ctx2, @@ -353,7 +357,7 @@ public class FetchSearchPhaseTests extends ESTestCase { new ScoreDoc[] {new ScoreDoc(84, 2.0F)}), 2.0F), new DocValueFormat[0]); queryResult.size(resultSetSize); queryResult.setShardIndex(1); - results.consumeResult(queryResult); + results.consumeResult(queryResult, () -> {}); mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) { @Override diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index 7b93b9aa25e..3441a3edf0f 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -35,12 +35,17 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.text.Text; import 
org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchHit; @@ -67,16 +72,21 @@ import org.elasticsearch.search.suggest.phrase.PhraseSuggestion; import org.elasticsearch.search.suggest.term.TermSuggestion; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.InternalAggregationTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; import org.junit.Before; +import java.io.IOException; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -86,20 +96,25 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; import static org.elasticsearch.action.search.SearchProgressListener.NOOP; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; -import static org.hamcrest.Matchers.not; public class SearchPhaseControllerTests extends ESTestCase { + private ThreadPool threadPool; + private EsThreadPoolExecutor fixedExecutor; private SearchPhaseController searchPhaseController; private List reductions; @Override protected NamedWriteableRegistry writableRegistry() { - return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables()); + List entries = + new ArrayList<>(new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables()); + entries.add(new NamedWriteableRegistry.Entry(InternalAggregation.class, "throwing", InternalThrowing::new)); + return new NamedWriteableRegistry(entries); } @Before @@ -119,6 +134,15 @@ public class SearchPhaseControllerTests extends ESTestCase { BigArrays.NON_RECYCLING_INSTANCE, null, b -> {}, PipelineTree.EMPTY); }; }); + threadPool = new TestThreadPool(SearchPhaseControllerTests.class.getName()); + fixedExecutor = EsExecutors.newFixed("test", 1, 10, + EsExecutors.daemonThreadFactory("test"), threadPool.getThreadContext()); + } + + @After + public void cleanup() { + fixedExecutor.shutdownNow(); + terminate(threadPool); } public void testSortDocs() { @@ -141,9 +165,15 @@ public class SearchPhaseControllerTests extends ESTestCase { int suggestionSize = suggestion.getEntries().get(0).getOptions().size(); accumulatedLength += suggestionSize; } - ScoreDoc[] sortedDocs = SearchPhaseController.sortDocs(true, results.asList(), null, - new SearchPhaseController.TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE), from, size, - reducedCompletionSuggestions).scoreDocs; + List topDocsList = new ArrayList<>(); + for 
(SearchPhaseResult result : results.asList()) { + QuerySearchResult queryResult = result.queryResult(); + TopDocs topDocs = queryResult.consumeTopDocs().topDocs; + SearchPhaseController.setShardIndex(topDocs, result.getShardIndex()); + topDocsList.add(topDocs); + } + ScoreDoc[] sortedDocs = SearchPhaseController.sortDocs(true, topDocsList, + from, size, reducedCompletionSuggestions).scoreDocs; assertThat(sortedDocs.length, equalTo(accumulatedLength)); } @@ -161,14 +191,26 @@ public class SearchPhaseControllerTests extends ESTestCase { from = first.get().queryResult().from(); size = first.get().queryResult().size(); } - SearchPhaseController.TopDocsStats topDocsStats = new SearchPhaseController.TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE); - ScoreDoc[] sortedDocs = SearchPhaseController.sortDocs(ignoreFrom, results.asList(), null, topDocsStats, from, size, + List topDocsList = new ArrayList<>(); + for (SearchPhaseResult result : results.asList()) { + QuerySearchResult queryResult = result.queryResult(); + TopDocs topDocs = queryResult.consumeTopDocs().topDocs; + topDocsList.add(topDocs); + SearchPhaseController.setShardIndex(topDocs, result.getShardIndex()); + } + ScoreDoc[] sortedDocs = SearchPhaseController.sortDocs(ignoreFrom, topDocsList, from, size, Collections.emptyList()).scoreDocs; results = generateSeededQueryResults(randomSeed, nShards, Collections.emptyList(), queryResultSize, useConstantScore); - SearchPhaseController.TopDocsStats topDocsStats2 = new SearchPhaseController.TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE); - ScoreDoc[] sortedDocs2 = SearchPhaseController.sortDocs(ignoreFrom, results.asList(), null, topDocsStats2, from, size, + topDocsList = new ArrayList<>(); + for (SearchPhaseResult result : results.asList()) { + QuerySearchResult queryResult = result.queryResult(); + TopDocs topDocs = queryResult.consumeTopDocs().topDocs; + topDocsList.add(topDocs); + SearchPhaseController.setShardIndex(topDocs, result.getShardIndex()); + } + ScoreDoc[] sortedDocs2 = SearchPhaseController.sortDocs(ignoreFrom, topDocsList, from, size, Collections.emptyList()).scoreDocs; assertEquals(sortedDocs.length, sortedDocs2.length); for (int i = 0; i < sortedDocs.length; i++) { @@ -176,10 +218,6 @@ public class SearchPhaseControllerTests extends ESTestCase { assertEquals(sortedDocs[i].shardIndex, sortedDocs2[i].shardIndex); assertEquals(sortedDocs[i].score, sortedDocs2[i].score, 0.0f); } - assertEquals(topDocsStats.getMaxScore(), topDocsStats2.getMaxScore(), 0.0f); - assertEquals(topDocsStats.getTotalHits().value, topDocsStats2.getTotalHits().value); - assertEquals(topDocsStats.getTotalHits().relation, topDocsStats2.getTotalHits().relation); - assertEquals(topDocsStats.fetchHits, topDocsStats2.fetchHits); } private AtomicArray generateSeededQueryResults(long seed, int nShards, @@ -200,9 +238,10 @@ public class SearchPhaseControllerTests extends ESTestCase { int nShards = randomIntBetween(1, 20); int queryResultSize = randomBoolean() ? 
0 : randomIntBetween(1, nShards * 2); AtomicArray queryResults = generateQueryResults(nShards, suggestions, queryResultSize, false); - for (int trackTotalHits : new int[] {SearchContext.TRACK_TOTAL_HITS_DISABLED, SearchContext.TRACK_TOTAL_HITS_ACCURATE}) { - SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase( - queryResults.asList(), false, trackTotalHits, InternalAggregationTestCase.emptyReduceContextBuilder(), true); + for (int trackTotalHits : new int[] { SearchContext.TRACK_TOTAL_HITS_DISABLED, SearchContext.TRACK_TOTAL_HITS_ACCURATE }) { + SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(queryResults.asList(), + new ArrayList<>(), new ArrayList<>(), new SearchPhaseController.TopDocsStats(trackTotalHits), + 0, true, InternalAggregationTestCase.emptyReduceContextBuilder(), true); AtomicArray fetchResults = generateFetchResults(nShards, reducedQueryPhase.sortedTopDocs.scoreDocs, reducedQueryPhase.suggest); InternalSearchResponse mergedResponse = searchPhaseController.merge(false, @@ -246,9 +285,8 @@ public class SearchPhaseControllerTests extends ESTestCase { * Generate random query results received from the provided number of shards, including the provided * number of search hits and randomly generated completion suggestions based on the name and size of the provided ones. * Note that shardIndex is already set to the generated completion suggestions to simulate what - * {@link SearchPhaseController#reducedQueryPhase(Collection, boolean, int, InternalAggregation.ReduceContextBuilder, boolean)} does, - * meaning that the returned query results can be fed directly to - * {@link SearchPhaseController#sortDocs(boolean, Collection, Collection, SearchPhaseController.TopDocsStats, int, int, List)} + * {@link SearchPhaseController#reducedQueryPhase} does, + * meaning that the returned query results can be fed directly to {@link SearchPhaseController#sortDocs} */ private static AtomicArray generateQueryResults(int nShards, List suggestions, int searchHitsSize, boolean useConstantScore) { @@ -364,22 +402,24 @@ public class SearchPhaseControllerTests extends ESTestCase { Strings.EMPTY_ARRAY, "remote", 0, randomBoolean()); } - public void testConsumer() { + public void testConsumer() throws Exception { consumerTestCase(0); } - public void testConsumerWithEmptyResponse() { + public void testConsumerWithEmptyResponse() throws Exception { consumerTestCase(randomIntBetween(1, 5)); } - private void consumerTestCase(int numEmptyResponses) { + private void consumerTestCase(int numEmptyResponses) throws Exception { + long beforeCompletedTasks = fixedExecutor.getCompletedTaskCount(); int numShards = 3 + numEmptyResponses; int bufferSize = randomIntBetween(2, 3); + CountDownLatch latch = new CountDownLatch(numShards); SearchRequest request = randomSearchRequest(); request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo"))); request.setBatchedReduceSize(bufferSize); - ArraySearchPhaseResults consumer = - searchPhaseController.newSearchPhaseResults(NOOP, request, 3+numEmptyResponses); + ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, + NOOP, request, 3+numEmptyResponses, exc -> {}); if (numEmptyResponses == 0) { assertEquals(0, reductions.size()); } @@ -388,7 +428,7 @@ public class SearchPhaseControllerTests extends ESTestCase { int shardId = 2 + numEmptyResponses; empty.setShardIndex(2+numEmptyResponses); empty.setSearchShardTarget(new 
SearchShardTarget("node", new ShardId("a", "b", shardId), null, OriginalIndices.NONE)); - consumer.consumeResult(empty); + consumer.consumeResult(empty, latch::countDown); numEmptyResponses --; } @@ -399,7 +439,7 @@ public class SearchPhaseControllerTests extends ESTestCase { InternalAggregations aggs = InternalAggregations.from(singletonList(new InternalMax("test", 1.0D, DocValueFormat.RAW, emptyMap()))); result.aggregations(aggs); result.setShardIndex(0); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); result = new QuerySearchResult(new SearchContextId(UUIDs.randomBase64UUID(), 1), new SearchShardTarget("node", new ShardId("a", "b", 0), null, OriginalIndices.NONE)); @@ -408,7 +448,7 @@ public class SearchPhaseControllerTests extends ESTestCase { aggs = InternalAggregations.from(singletonList(new InternalMax("test", 3.0D, DocValueFormat.RAW, emptyMap()))); result.aggregations(aggs); result.setShardIndex(2); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); result = new QuerySearchResult(new SearchContextId(UUIDs.randomBase64UUID(), 1), new SearchShardTarget("node", new ShardId("a", "b", 0), null, OriginalIndices.NONE)); @@ -417,40 +457,38 @@ public class SearchPhaseControllerTests extends ESTestCase { aggs = InternalAggregations.from(singletonList(new InternalMax("test", 2.0D, DocValueFormat.RAW, emptyMap()))); result.aggregations(aggs); result.setShardIndex(1); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); while (numEmptyResponses > 0) { result = QuerySearchResult.nullInstance(); int shardId = 2 + numEmptyResponses; result.setShardIndex(shardId); result.setSearchShardTarget(new SearchShardTarget("node", new ShardId("a", "b", shardId), null, OriginalIndices.NONE)); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); numEmptyResponses--; } - + latch.await(); final int numTotalReducePhases; if (numShards > bufferSize) { - assertThat(consumer, instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class)); if (bufferSize == 2) { - assertEquals(1, ((SearchPhaseController.QueryPhaseResultConsumer) consumer).getNumReducePhases()); - assertEquals(2, ((SearchPhaseController.QueryPhaseResultConsumer) consumer).getNumBuffered()); + assertEquals(1, ((QueryPhaseResultConsumer) consumer).getNumReducePhases()); assertEquals(1, reductions.size()); assertEquals(false, reductions.get(0)); numTotalReducePhases = 2; } else { - assertEquals(0, ((SearchPhaseController.QueryPhaseResultConsumer) consumer).getNumReducePhases()); - assertEquals(3, ((SearchPhaseController.QueryPhaseResultConsumer) consumer).getNumBuffered()); + assertEquals(0, ((QueryPhaseResultConsumer) consumer).getNumReducePhases()); assertEquals(0, reductions.size()); numTotalReducePhases = 1; } } else { - assertThat(consumer, not(instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class))); assertEquals(0, reductions.size()); numTotalReducePhases = 1; } SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + int numCompletedTasks = (int) (fixedExecutor.getCompletedTaskCount() - beforeCompletedTasks); + assertEquals(numCompletedTasks, reduce.numReducePhases-1); assertEquals(numTotalReducePhases, reduce.numReducePhases); assertEquals(numTotalReducePhases, reductions.size()); assertAggReduction(request); @@ -462,17 +500,18 @@ public class SearchPhaseControllerTests extends ESTestCase { assertNull(reduce.sortedTopDocs.collapseValues); } - public void testConsumerConcurrently() 
throws InterruptedException { + public void testConsumerConcurrently() throws Exception { int expectedNumResults = randomIntBetween(1, 100); int bufferSize = randomIntBetween(2, 200); SearchRequest request = randomSearchRequest(); request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo"))); request.setBatchedReduceSize(bufferSize); - ArraySearchPhaseResults consumer = - searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults); + ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, + NOOP, request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); Thread[] threads = new Thread[expectedNumResults]; + CountDownLatch latch = new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { int id = i; threads[i] = new Thread(() -> { @@ -488,7 +527,7 @@ public class SearchPhaseControllerTests extends ESTestCase { result.aggregations(aggs); result.setShardIndex(id); result.size(1); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); }); threads[i].start(); @@ -496,6 +535,8 @@ public class SearchPhaseControllerTests extends ESTestCase { for (int i = 0; i < expectedNumResults; i++) { threads[i].join(); } + latch.await(); + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); assertAggReduction(request); InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0); @@ -510,15 +551,16 @@ public class SearchPhaseControllerTests extends ESTestCase { assertNull(reduce.sortedTopDocs.collapseValues); } - public void testConsumerOnlyAggs() { + public void testConsumerOnlyAggs() throws Exception { int expectedNumResults = randomIntBetween(1, 100); int bufferSize = randomIntBetween(2, 200); SearchRequest request = randomSearchRequest(); request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0)); request.setBatchedReduceSize(bufferSize); - ArraySearchPhaseResults consumer = - searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults); + ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, + NOOP, request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); + CountDownLatch latch = new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { int number = randomIntBetween(1, 1000); max.updateAndGet(prev -> Math.max(prev, number)); @@ -531,8 +573,10 @@ public class SearchPhaseControllerTests extends ESTestCase { result.aggregations(aggs); result.setShardIndex(i); result.size(1); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); } + latch.await(); + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); assertAggReduction(request); InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0); @@ -546,7 +590,7 @@ public class SearchPhaseControllerTests extends ESTestCase { assertNull(reduce.sortedTopDocs.collapseValues); } - public void testConsumerOnlyHits() { + public void testConsumerOnlyHits() throws Exception { int expectedNumResults = randomIntBetween(1, 100); int bufferSize = randomIntBetween(2, 200); SearchRequest request = randomSearchRequest(); @@ -554,9 +598,10 @@ public class SearchPhaseControllerTests extends ESTestCase { request.source(new SearchSourceBuilder().size(randomIntBetween(1, 10))); } request.setBatchedReduceSize(bufferSize); - ArraySearchPhaseResults consumer = - 
searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults); + ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, + NOOP, request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); + CountDownLatch latch = new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { int number = randomIntBetween(1, 1000); max.updateAndGet(prev -> Math.max(prev, number)); @@ -566,8 +611,9 @@ public class SearchPhaseControllerTests extends ESTestCase { new ScoreDoc[] {new ScoreDoc(0, number)}), number), new DocValueFormat[0]); result.setShardIndex(i); result.size(1); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); } + latch.await(); SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); assertAggReduction(request); assertEquals(1, reduce.sortedTopDocs.scoreDocs.length); @@ -591,47 +637,14 @@ public class SearchPhaseControllerTests extends ESTestCase { } } - public void testNewSearchPhaseResults() { - for (int i = 0; i < 10; i++) { - int expectedNumResults = randomIntBetween(1, 10); - int bufferSize = randomIntBetween(2, 10); - SearchRequest request = new SearchRequest(); - final boolean hasAggs; - if ((hasAggs = randomBoolean())) { - request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo"))); - } - final boolean hasTopDocs; - if ((hasTopDocs = randomBoolean())) { - if (request.source() != null) { - request.source().size(randomIntBetween(1, 100)); - } // no source means size = 10 - } else { - if (request.source() == null) { - request.source(new SearchSourceBuilder().size(0)); - } else { - request.source().size(0); - } - } - request.setBatchedReduceSize(bufferSize); - ArraySearchPhaseResults consumer - = searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults); - if ((hasAggs || hasTopDocs) && expectedNumResults > bufferSize) { - assertThat("expectedNumResults: " + expectedNumResults + " bufferSize: " + bufferSize, - consumer, instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class)); - } else { - assertThat("expectedNumResults: " + expectedNumResults + " bufferSize: " + bufferSize, - consumer, not(instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class))); - } - } - } - - public void testReduceTopNWithFromOffset() { + public void testReduceTopNWithFromOffset() throws Exception { SearchRequest request = new SearchRequest(); request.source(new SearchSourceBuilder().size(5).from(5)); request.setBatchedReduceSize(randomIntBetween(2, 4)); - ArraySearchPhaseResults consumer = - searchPhaseController.newSearchPhaseResults(NOOP, request, 4); + ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, + NOOP, request, 4, exc -> {}); int score = 100; + CountDownLatch latch = new CountDownLatch(4); for (int i = 0; i < 4; i++) { QuerySearchResult result = new QuerySearchResult(new SearchContextId(UUIDs.randomBase64UUID(), i), new SearchShardTarget("node", new ShardId("a", "b", i), null, OriginalIndices.NONE)); @@ -644,8 +657,9 @@ public class SearchPhaseControllerTests extends ESTestCase { result.setShardIndex(i); result.size(5); result.from(5); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); } + latch.await(); // 4*3 results = 12 we get result 5 to 10 here with from=5 and size=5 SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); ScoreDoc[] scoreDocs = reduce.sortedTopDocs.scoreDocs; @@ -659,17 +673,18 @@ 
public class SearchPhaseControllerTests extends ESTestCase { assertEquals(91.0f, scoreDocs[4].score, 0.0f); } - public void testConsumerSortByField() { + public void testConsumerSortByField() throws Exception { int expectedNumResults = randomIntBetween(1, 100); int bufferSize = randomIntBetween(2, 200); SearchRequest request = randomSearchRequest(); int size = randomIntBetween(1, 10); request.setBatchedReduceSize(bufferSize); - ArraySearchPhaseResults consumer = - searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults); + ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, + NOOP, request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); SortField[] sortFields = {new SortField("field", SortField.Type.INT, true)}; DocValueFormat[] docValueFormats = {DocValueFormat.RAW}; + CountDownLatch latch = new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { int number = randomIntBetween(1, 1000); max.updateAndGet(prev -> Math.max(prev, number)); @@ -680,8 +695,9 @@ public class SearchPhaseControllerTests extends ESTestCase { result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats); result.setShardIndex(i); result.size(size); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); } + latch.await(); SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); assertAggReduction(request); assertEquals(Math.min(expectedNumResults, size), reduce.sortedTopDocs.scoreDocs.length); @@ -695,20 +711,21 @@ public class SearchPhaseControllerTests extends ESTestCase { assertNull(reduce.sortedTopDocs.collapseValues); } - public void testConsumerFieldCollapsing() { + public void testConsumerFieldCollapsing() throws Exception { int expectedNumResults = randomIntBetween(30, 100); int bufferSize = randomIntBetween(2, 200); SearchRequest request = randomSearchRequest(); int size = randomIntBetween(5, 10); request.setBatchedReduceSize(bufferSize); - ArraySearchPhaseResults consumer = - searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults); + ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, + NOOP, request, expectedNumResults, exc -> {}); SortField[] sortFields = {new SortField("field", SortField.Type.STRING)}; BytesRef a = new BytesRef("a"); BytesRef b = new BytesRef("b"); BytesRef c = new BytesRef("c"); Object[] collapseValues = new Object[]{a, b, c}; DocValueFormat[] docValueFormats = {DocValueFormat.RAW}; + CountDownLatch latch = new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { Object[] values = {randomFrom(collapseValues)}; FieldDoc[] fieldDocs = {new FieldDoc(0, Float.NaN, values)}; @@ -718,8 +735,9 @@ public class SearchPhaseControllerTests extends ESTestCase { result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats); result.setShardIndex(i); result.size(size); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); } + latch.await(); SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); assertAggReduction(request); assertEquals(3, reduce.sortedTopDocs.scoreDocs.length); @@ -735,16 +753,17 @@ public class SearchPhaseControllerTests extends ESTestCase { assertArrayEquals(collapseValues, reduce.sortedTopDocs.collapseValues); } - public void testConsumerSuggestions() { + public void testConsumerSuggestions() throws Exception { int expectedNumResults = randomIntBetween(1, 
100); int bufferSize = randomIntBetween(2, 200); SearchRequest request = randomSearchRequest(); request.setBatchedReduceSize(bufferSize); - ArraySearchPhaseResults consumer = - searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults); + ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, + NOOP, request, expectedNumResults, exc -> {}); int maxScoreTerm = -1; int maxScorePhrase = -1; int maxScoreCompletion = -1; + CountDownLatch latch = new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { QuerySearchResult result = new QuerySearchResult(new SearchContextId(UUIDs.randomBase64UUID(), i), new SearchShardTarget("node", new ShardId("a", "b", i), null, OriginalIndices.NONE)); @@ -792,8 +811,9 @@ public class SearchPhaseControllerTests extends ESTestCase { result.topDocs(new TopDocsAndMaxScore(Lucene.EMPTY_TOP_DOCS, Float.NaN), new DocValueFormat[0]); result.setShardIndex(i); result.size(0); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); } + latch.await(); SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); assertEquals(3, reduce.suggest.size()); { @@ -827,7 +847,7 @@ public class SearchPhaseControllerTests extends ESTestCase { assertNull(reduce.sortedTopDocs.collapseValues); } - public void testProgressListener() throws InterruptedException { + public void testProgressListener() throws Exception { int expectedNumResults = randomIntBetween(10, 100); for (int bufferSize : new int[] {expectedNumResults, expectedNumResults/2, expectedNumResults/4, 2}) { SearchRequest request = randomSearchRequest(); @@ -861,13 +881,14 @@ public class SearchPhaseControllerTests extends ESTestCase { public void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { totalHitsListener.set(totalHits); finalAggsListener.set(aggs); - numReduceListener.incrementAndGet(); + assertEquals(numReduceListener.incrementAndGet(), reducePhase); } }; - ArraySearchPhaseResults consumer = - searchPhaseController.newSearchPhaseResults(progressListener, request, expectedNumResults); + ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, + progressListener, request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); Thread[] threads = new Thread[expectedNumResults]; + CountDownLatch latch = new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { int id = i; threads[i] = new Thread(() -> { @@ -883,13 +904,14 @@ public class SearchPhaseControllerTests extends ESTestCase { result.aggregations(aggs); result.setShardIndex(id); result.size(1); - consumer.consumeResult(result); + consumer.consumeResult(result, latch::countDown); }); threads[i].start(); } for (int i = 0; i < expectedNumResults; i++) { threads[i].join(); } + latch.await(); SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); assertAggReduction(request); InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0); @@ -912,4 +934,94 @@ public class SearchPhaseControllerTests extends ESTestCase { } } + public void testPartialMergeFailure() throws InterruptedException { + int expectedNumResults = randomIntBetween(20, 200); + int bufferSize = randomIntBetween(2, expectedNumResults - 1); + SearchRequest request = new SearchRequest(); + + request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0)); + request.setBatchedReduceSize(bufferSize); + 
AtomicBoolean hasConsumedFailure = new AtomicBoolean(); + ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, + NOOP, request, expectedNumResults, exc -> hasConsumedFailure.set(true)); + CountDownLatch latch = new CountDownLatch(expectedNumResults); + Thread[] threads = new Thread[expectedNumResults]; + int failedIndex = randomIntBetween(0, expectedNumResults-1); + for (int i = 0; i < expectedNumResults; i++) { + final int index = i; + threads[index] = new Thread(() -> { + QuerySearchResult result = new QuerySearchResult(new SearchContextId(UUIDs.randomBase64UUID(), index), + new SearchShardTarget("node", new ShardId("a", "b", index), null, OriginalIndices.NONE)); + result.topDocs(new TopDocsAndMaxScore( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), Float.NaN), + new DocValueFormat[0]); + InternalAggregations aggs = InternalAggregations.from( + Collections.singletonList(new InternalThrowing("test", (failedIndex == index), Collections.emptyMap()))); + result.aggregations(aggs); + result.setShardIndex(index); + result.size(1); + consumer.consumeResult(result, latch::countDown); + }); + threads[index].start(); + } + for (int i = 0; i < expectedNumResults; i++) { + threads[i].join(); + } + latch.await(); + IllegalStateException exc = expectThrows(IllegalStateException.class, () -> consumer.reduce()); + if (exc.getMessage().contains("partial reduce")) { + assertTrue(hasConsumedFailure.get()); + } else { + assertThat(exc.getMessage(), containsString("final reduce")); + } + } + + private static class InternalThrowing extends InternalAggregation { + private final boolean shouldThrow; + + protected InternalThrowing(String name, boolean shouldThrow, Map metadata) { + super(name, metadata); + this.shouldThrow = shouldThrow; + } + + protected InternalThrowing(StreamInput in) throws IOException { + super(in); + this.shouldThrow = in.readBoolean(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeBoolean(shouldThrow); + } + + @Override + public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + if (aggregations.stream() + .map(agg -> (InternalThrowing) agg) + .anyMatch(agg -> agg.shouldThrow)) { + if (reduceContext.isFinalReduce()) { + throw new IllegalStateException("final reduce"); + } else { + throw new IllegalStateException("partial reduce"); + } + } + return new InternalThrowing(name, false, metadata); + } + + @Override + public Object getProperty(List path) { + return null; + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + throw new IllegalStateException("not implemented"); + } + + @Override + public String getWriteableName() { + return "throwing"; + } + } + } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 2f0e91b630e..c9bb2f592eb 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -57,15 +57,15 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; public class SearchQueryThenFetchAsyncActionTests extends ESTestCase { - public void testBottomFieldSort() throws InterruptedException { + public void 
testBottomFieldSort() throws Exception { testCase(false); } - public void testScrollDisableBottomFieldSort() throws InterruptedException { + public void testScrollDisableBottomFieldSort() throws Exception { testCase(true); } - private void testCase(boolean withScroll) throws InterruptedException { + private void testCase(boolean withScroll) throws Exception { final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider(0, System.nanoTime(), System::nanoTime); @@ -131,7 +131,7 @@ public class SearchQueryThenFetchAsyncActionTests extends ESTestCase { Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)), Collections.emptyMap(), Collections.emptyMap(), controller, EsExecutors.newDirectExecutorService(), searchRequest, null, shardsIter, timeProvider, null, task, - SearchResponse.Clusters.EMPTY) { + SearchResponse.Clusters.EMPTY, exc -> {}) { @Override protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return new SearchPhase("test") { diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 88af0989873..09794814369 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -1610,7 +1610,7 @@ public class SnapshotResiliencyTests extends ESTestCase { SearchPhaseController searchPhaseController = new SearchPhaseController( writableRegistry(), searchService::aggReduceContextBuilder); actions.put(SearchAction.INSTANCE, - new TransportSearchAction(threadPool, transportService, searchService, + new TransportSearchAction(client, threadPool, transportService, searchService, searchTransportService, searchPhaseController, clusterService, actionFilters, indexNameExpressionResolver)); actions.put(RestoreSnapshotAction.INSTANCE,