Executes incremental reduce in the search thread pool (#58461) (#60275)

This change forks the execution of partial
reduces in the coordinating node to the search thread pool.
It also ensures that partial reduces are executed sequentially
and asynchronously in order to limit the memory and cpu that a
single search request can use but also to avoid blocking a
network thread.
If a partial reduce fails with an exception, the search
request is cancelled and the reporting of the error is
delayed to the start of the fetch phase (when the final
reduce is performed). This ensures that we cleanup the
in-flight search requests before returning an error to
the user.

Closes #53411
Relates #51857
This commit is contained in:
Jim Ferenczi 2020-07-28 13:40:47 +02:00 committed by GitHub
parent d39622e17e
commit 1144534093
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 835 additions and 501 deletions

View File

@ -34,6 +34,8 @@ import java.util.Locale;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
@ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 2)
@ -85,8 +87,8 @@ public class RejectionActionIT extends ESIntegTestCase {
if (response instanceof SearchResponse) {
SearchResponse searchResponse = (SearchResponse) response;
for (ShardSearchFailure failure : searchResponse.getShardFailures()) {
assertTrue("got unexpected reason..." + failure.reason(),
failure.reason().toLowerCase(Locale.ENGLISH).contains("rejected"));
assertThat(failure.reason().toLowerCase(Locale.ENGLISH),
anyOf(containsString("cancelled"), containsString("rejected")));
}
} else {
Exception t = (Exception) response;
@ -94,8 +96,8 @@ public class RejectionActionIT extends ESIntegTestCase {
if (unwrap instanceof SearchPhaseExecutionException) {
SearchPhaseExecutionException e = (SearchPhaseExecutionException) unwrap;
for (ShardSearchFailure failure : e.shardFailures()) {
assertTrue("got unexpected reason..." + failure.reason(),
failure.reason().toLowerCase(Locale.ENGLISH).contains("rejected"));
assertThat(failure.reason().toLowerCase(Locale.ENGLISH),
anyOf(containsString("cancelled"), containsString("rejected")));
}
} else if ((unwrap instanceof EsRejectedExecutionException) == false) {
throw new AssertionError("unexpected failure", (Throwable) response);

View File

@ -131,8 +131,8 @@ public class SearchProgressActionListenerIT extends ESSingleNodeTestCase {
testCase((NodeClient) client(), request, sortShards, false);
}
private static void testCase(NodeClient client, SearchRequest request,
List<SearchShard> expectedShards, boolean hasFetchPhase) throws InterruptedException {
private void testCase(NodeClient client, SearchRequest request,
List<SearchShard> expectedShards, boolean hasFetchPhase) throws InterruptedException {
AtomicInteger numQueryResults = new AtomicInteger();
AtomicInteger numQueryFailures = new AtomicInteger();
AtomicInteger numFetchResults = new AtomicInteger();
@ -204,7 +204,6 @@ public class SearchProgressActionListenerIT extends ESSingleNodeTestCase {
}
}, listener);
latch.await();
assertThat(shardsListener.get(), equalTo(expectedShards));
assertThat(numQueryResults.get(), equalTo(searchResponse.get().getSuccessfulShards()));
assertThat(numQueryFailures.get(), equalTo(searchResponse.get().getFailedShards()));

View File

@ -470,12 +470,15 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
protected void onShardResult(Result result, SearchShardIterator shardIt) {
assert result.getShardIndex() != -1 : "shard index is not set";
assert result.getSearchShardTarget() != null : "search shard target must not be null";
successfulOps.incrementAndGet();
results.consumeResult(result);
hasShardResponse.set(true);
if (logger.isTraceEnabled()) {
logger.trace("got first-phase result from {}", result != null ? result.getSearchShardTarget() : null);
}
results.consumeResult(result, () -> onShardResultConsumed(result, shardIt));
}
private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
successfulOps.incrementAndGet();
// clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level
// so its ok concurrency wise to miss potentially the shard failures being created because of another failure
// in the #addShardFailure, because by definition, it will happen on *another* shardIndex

View File

@ -39,9 +39,11 @@ class ArraySearchPhaseResults<Result extends SearchPhaseResult> extends SearchPh
return results.asList().stream();
}
void consumeResult(Result result) {
@Override
void consumeResult(Result result, Runnable next) {
assert results.get(result.getShardIndex()) == null : "shardIndex: " + result.getShardIndex() + " is already set";
results.set(result.getShardIndex(), result);
next.run();
}
boolean hasResult(int shardIndex) {

View File

@ -159,8 +159,12 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
}
@Override
void consumeResult(CanMatchResponse result) {
consumeResult(result.getShardIndex(), result.canMatch(), result.estimatedMinAndMax());
void consumeResult(CanMatchResponse result, Runnable next) {
try {
consumeResult(result.getShardIndex(), result.canMatch(), result.estimatedMinAndMax());
} finally {
next.run();
}
}
@Override

View File

@ -23,20 +23,18 @@ import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import java.util.function.Consumer;
/**
* This is a simple base class to simplify fan out to shards and collect their results. Each results passed to
* {@link #onResult(SearchPhaseResult)} will be set to the provided result array
* where the given index is used to set the result on the array.
*/
final class CountedCollector<R extends SearchPhaseResult> {
private final Consumer<R> resultConsumer;
private final ArraySearchPhaseResults<R> resultConsumer;
private final CountDown counter;
private final Runnable onFinish;
private final SearchPhaseContext context;
CountedCollector(Consumer<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
CountedCollector(ArraySearchPhaseResults<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
this.resultConsumer = resultConsumer;
this.counter = new CountDown(expectedOps);
this.onFinish = onFinish;
@ -58,11 +56,7 @@ final class CountedCollector<R extends SearchPhaseResult> {
* Sets the result to the given array index and then runs {@link #countDown()}
*/
void onResult(R result) {
try {
resultConsumer.accept(result);
} finally {
countDown();
}
resultConsumer.consumeResult(result, this::countDown);
}
/**

View File

@ -30,6 +30,7 @@ import org.elasticsearch.transport.Transport;
import java.io.IOException;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
/**
@ -51,10 +52,11 @@ final class DfsQueryPhase extends SearchPhase {
DfsQueryPhase(AtomicArray<DfsSearchResult> dfsSearchResults,
SearchPhaseController searchPhaseController,
Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
SearchPhaseContext context) {
SearchPhaseContext context, Consumer<Exception> onPartialMergeFailure) {
super("dfs_query");
this.progressListener = context.getTask().getProgressListener();
this.queryResult = searchPhaseController.newSearchPhaseResults(progressListener, context.getRequest(), context.getNumShards());
this.queryResult = searchPhaseController.newSearchPhaseResults(context, progressListener,
context.getRequest(), context.getNumShards(), onPartialMergeFailure);
this.searchPhaseController = searchPhaseController;
this.dfsSearchResults = dfsSearchResults;
this.nextPhaseFactory = nextPhaseFactory;
@ -68,7 +70,7 @@ final class DfsQueryPhase extends SearchPhase {
// to free up memory early
final List<DfsSearchResult> resultList = dfsSearchResults.asList();
final AggregatedDfs dfs = searchPhaseController.aggregateDfs(resultList);
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(queryResult::consumeResult,
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(queryResult,
resultList.size(),
() -> context.executeNextPhase(this, nextPhaseFactory.apply(queryResult)), context);
for (final DfsSearchResult dfsResult : resultList) {

View File

@ -36,7 +36,6 @@ import org.elasticsearch.search.internal.SearchContextId;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.transport.Transport;
import java.io.IOException;
import java.util.List;
import java.util.function.BiFunction;
@ -45,7 +44,7 @@ import java.util.function.BiFunction;
* Then it reaches out to all relevant shards to fetch the topN hits.
*/
final class FetchSearchPhase extends SearchPhase {
private final AtomicArray<FetchSearchResult> fetchResults;
private final ArraySearchPhaseResults<FetchSearchResult> fetchResults;
private final SearchPhaseController searchPhaseController;
private final AtomicArray<SearchPhaseResult> queryResults;
private final BiFunction<InternalSearchResponse, String, SearchPhase> nextPhaseFactory;
@ -73,7 +72,7 @@ final class FetchSearchPhase extends SearchPhase {
throw new IllegalStateException("number of shards must match the length of the query results but doesn't:"
+ context.getNumShards() + "!=" + resultConsumer.getNumShards());
}
this.fetchResults = new AtomicArray<>(resultConsumer.getNumShards());
this.fetchResults = new ArraySearchPhaseResults<>(resultConsumer.getNumShards());
this.searchPhaseController = searchPhaseController;
this.queryResults = resultConsumer.getAtomicArray();
this.nextPhaseFactory = nextPhaseFactory;
@ -102,7 +101,7 @@ final class FetchSearchPhase extends SearchPhase {
});
}
private void innerRun() throws IOException {
private void innerRun() throws Exception {
final int numShards = context.getNumShards();
final boolean isScrollSearch = context.getRequest().scroll() != null;
final List<SearchPhaseResult> phaseResults = queryResults.asList();
@ -117,7 +116,7 @@ final class FetchSearchPhase extends SearchPhase {
final boolean queryAndFetchOptimization = queryResults.length() == 1;
final Runnable finishPhase = ()
-> moveToNextPhase(searchPhaseController, scrollId, reducedQueryPhase, queryAndFetchOptimization ?
queryResults : fetchResults);
queryResults : fetchResults.getAtomicArray());
if (queryAndFetchOptimization) {
assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null : "phaseResults empty [" + phaseResults.isEmpty()
+ "], single result: " + phaseResults.get(0).fetchResult();
@ -137,7 +136,7 @@ final class FetchSearchPhase extends SearchPhase {
final ScoreDoc[] lastEmittedDocPerShard = isScrollSearch ?
searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, numShards)
: null;
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(r -> fetchResults.set(r.getShardIndex(), r),
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(fetchResults,
docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
finishPhase, context);
for (int i = 0; i < docIdsToLoad.length; i++) {

View File

@ -0,0 +1,418 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.action.search;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.TopDocs;
import org.elasticsearch.action.search.SearchPhaseController.TopDocsStats;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContextBuilder;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.query.QuerySearchResult;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import static org.elasticsearch.action.search.SearchPhaseController.mergeTopDocs;
import static org.elasticsearch.action.search.SearchPhaseController.setShardIndex;
/**
* A {@link ArraySearchPhaseResults} implementation that incrementally reduces aggregation results
* as shard results are consumed.
* This implementation can be configured to batch up a certain amount of results and reduce
* them asynchronously in the provided {@link Executor} iff the buffer is exhausted.
*/
class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class);
private final Executor executor;
private final SearchPhaseController controller;
private final SearchProgressListener progressListener;
private final ReduceContextBuilder aggReduceContextBuilder;
private final NamedWriteableRegistry namedWriteableRegistry;
private final int topNSize;
private final boolean hasTopDocs;
private final boolean hasAggs;
private final boolean performFinalReduce;
private final PendingMerges pendingMerges;
private final Consumer<Exception> onPartialMergeFailure;
private volatile long aggsMaxBufferSize;
private volatile long aggsCurrentBufferSize;
/**
* Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
* as shard results are consumed.
*/
QueryPhaseResultConsumer(Executor executor,
SearchPhaseController controller,
SearchProgressListener progressListener,
ReduceContextBuilder aggReduceContextBuilder,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
int bufferSize,
boolean hasTopDocs,
boolean hasAggs,
int trackTotalHitsUpTo,
int topNSize,
boolean performFinalReduce,
Consumer<Exception> onPartialMergeFailure) {
super(expectedResultSize);
this.executor = executor;
this.controller = controller;
this.progressListener = progressListener;
this.aggReduceContextBuilder = aggReduceContextBuilder;
this.namedWriteableRegistry = namedWriteableRegistry;
this.topNSize = topNSize;
this.pendingMerges = new PendingMerges(bufferSize, trackTotalHitsUpTo);
this.hasTopDocs = hasTopDocs;
this.hasAggs = hasAggs;
this.performFinalReduce = performFinalReduce;
this.onPartialMergeFailure = onPartialMergeFailure;
}
@Override
void consumeResult(SearchPhaseResult result, Runnable next) {
super.consumeResult(result, () -> {});
QuerySearchResult querySearchResult = result.queryResult();
progressListener.notifyQueryResult(querySearchResult.getShardIndex());
pendingMerges.consume(querySearchResult, next);
}
@Override
SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
if (pendingMerges.hasPendingMerges()) {
throw new AssertionError("partial reduce in-flight");
} else if (pendingMerges.hasFailure()) {
throw pendingMerges.getFailure();
}
logger.trace("aggs final reduction [{}] max [{}]", aggsCurrentBufferSize, aggsMaxBufferSize);
// ensure consistent ordering
pendingMerges.sortBuffer();
final TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats();
final List<TopDocs> topDocsList = pendingMerges.consumeTopDocs();
final List<InternalAggregations> aggsList = pendingMerges.consumeAggs();
SearchPhaseController.ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), aggsList,
topDocsList, topDocsStats, pendingMerges.numReducePhases, false, aggReduceContextBuilder, performFinalReduce);
progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()),
reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
return reducePhase;
}
private MergeResult partialReduce(MergeTask task,
TopDocsStats topDocsStats,
MergeResult lastMerge,
int numReducePhases) {
final QuerySearchResult[] toConsume = task.consumeBuffer();
if (toConsume == null) {
// the task is cancelled
return null;
}
// ensure consistent ordering
Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex));
for (QuerySearchResult result : toConsume) {
topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
}
final TopDocs newTopDocs;
if (hasTopDocs) {
List<TopDocs> topDocsList = new ArrayList<>();
if (lastMerge != null) {
topDocsList.add(lastMerge.reducedTopDocs);
}
for (QuerySearchResult result : toConsume) {
TopDocsAndMaxScore topDocs = result.consumeTopDocs();
setShardIndex(topDocs.topDocs, result.getShardIndex());
topDocsList.add(topDocs.topDocs);
}
newTopDocs = mergeTopDocs(topDocsList,
// we have to merge here in the same way we collect on a shard
topNSize, 0);
} else {
newTopDocs = null;
}
final DelayableWriteable.Serialized<InternalAggregations> newAggs;
if (hasAggs) {
List<InternalAggregations> aggsList = new ArrayList<>();
if (lastMerge != null) {
aggsList.add(lastMerge.reducedAggs.expand());
}
for (QuerySearchResult result : toConsume) {
aggsList.add(result.consumeAggs().expand());
}
InternalAggregations result = InternalAggregations.topLevelReduce(aggsList,
aggReduceContextBuilder.forPartialReduction());
newAggs = DelayableWriteable.referencing(result).asSerialized(InternalAggregations::readFrom, namedWriteableRegistry);
long previousBufferSize = aggsCurrentBufferSize;
aggsCurrentBufferSize = newAggs.ramBytesUsed();
aggsMaxBufferSize = Math.max(aggsCurrentBufferSize, aggsMaxBufferSize);
logger.trace("aggs partial reduction [{}->{}] max [{}]",
previousBufferSize, aggsCurrentBufferSize, aggsMaxBufferSize);
} else {
newAggs = null;
}
List<SearchShard> processedShards = new ArrayList<>(task.emptyResults);
if (lastMerge != null) {
processedShards.addAll(lastMerge.processedShards);
}
for (QuerySearchResult result : toConsume) {
SearchShardTarget target = result.getSearchShardTarget();
processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
}
progressListener.onPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
return new MergeResult(processedShards, newTopDocs, newAggs);
}
public int getNumReducePhases() {
return pendingMerges.numReducePhases;
}
private class PendingMerges {
private final int bufferSize;
private int index;
private final QuerySearchResult[] buffer;
private final List<SearchShard> emptyResults = new ArrayList<>();
private final TopDocsStats topDocsStats;
private MergeResult mergeResult;
private final ArrayDeque<MergeTask> queue = new ArrayDeque<>();
private final AtomicReference<MergeTask> runningTask = new AtomicReference<>();
private final AtomicReference<Exception> failure = new AtomicReference<>();
private boolean hasPartialReduce;
private int numReducePhases;
PendingMerges(int bufferSize, int trackTotalHitsUpTo) {
this.bufferSize = bufferSize;
this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo);
this.buffer = new QuerySearchResult[bufferSize];
}
public boolean hasFailure() {
return failure.get() != null;
}
public synchronized boolean hasPendingMerges() {
return queue.isEmpty() == false || runningTask.get() != null;
}
public synchronized void sortBuffer() {
if (index > 0) {
Arrays.sort(buffer, 0, index, Comparator.comparingInt(QuerySearchResult::getShardIndex));
}
}
public void consume(QuerySearchResult result, Runnable next) {
boolean executeNextImmediately = true;
synchronized (this) {
if (hasFailure() || result.isNull()) {
result.consumeAll();
if (result.isNull()) {
SearchShardTarget target = result.getSearchShardTarget();
emptyResults.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
}
} else {
// add one if a partial merge is pending
int size = index + (hasPartialReduce ? 1 : 0);
if (size >= bufferSize) {
hasPartialReduce = true;
executeNextImmediately = false;
QuerySearchResult[] clone = new QuerySearchResult[index];
System.arraycopy(buffer, 0, clone, 0, index);
MergeTask task = new MergeTask(clone, new ArrayList<>(emptyResults), next);
Arrays.fill(buffer, null);
emptyResults.clear();
index = 0;
queue.add(task);
tryExecuteNext();
}
buffer[index++] = result;
}
}
if (executeNextImmediately) {
next.run();
}
}
private void onMergeFailure(Exception exc) {
synchronized (this) {
if (failure.get() != null) {
return;
}
failure.compareAndSet(null, exc);
MergeTask task = runningTask.get();
if (task != null) {
runningTask.compareAndSet(task, null);
task.cancel();
}
queue.stream().forEach(MergeTask::cancel);
queue.clear();
mergeResult = null;
}
onPartialMergeFailure.accept(exc);
}
private void onAfterMerge(MergeTask task, MergeResult newResult) {
synchronized (this) {
runningTask.compareAndSet(task, null);
mergeResult = newResult;
}
task.consumeListener();
}
private void tryExecuteNext() {
final MergeTask task;
synchronized (this) {
if (queue.isEmpty()
|| failure.get() != null
|| runningTask.get() != null) {
return;
}
task = queue.poll();
runningTask.compareAndSet(null, task);
}
executor.execute(new AbstractRunnable() {
@Override
protected void doRun() {
final MergeResult newMerge;
try {
newMerge = partialReduce(task, topDocsStats, mergeResult, ++numReducePhases);
} catch (Exception t) {
onMergeFailure(t);
return;
}
onAfterMerge(task, newMerge);
tryExecuteNext();
}
@Override
public void onFailure(Exception exc) {
onMergeFailure(exc);
}
});
}
public TopDocsStats consumeTopDocsStats() {
for (int i = 0; i < index; i++) {
QuerySearchResult result = buffer[i];
topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
}
return topDocsStats;
}
public List<TopDocs> consumeTopDocs() {
if (hasTopDocs == false) {
return Collections.emptyList();
}
List<TopDocs> topDocsList = new ArrayList<>();
if (mergeResult != null) {
topDocsList.add(mergeResult.reducedTopDocs);
}
for (int i = 0; i < index; i++) {
QuerySearchResult result = buffer[i];
TopDocsAndMaxScore topDocs = result.consumeTopDocs();
setShardIndex(topDocs.topDocs, result.getShardIndex());
topDocsList.add(topDocs.topDocs);
}
return topDocsList;
}
public List<InternalAggregations> consumeAggs() {
if (hasAggs == false) {
return Collections.emptyList();
}
List<InternalAggregations> aggsList = new ArrayList<>();
if (mergeResult != null) {
aggsList.add(mergeResult.reducedAggs.expand());
}
for (int i = 0; i < index; i++) {
QuerySearchResult result = buffer[i];
aggsList.add(result.consumeAggs().expand());
}
return aggsList;
}
public Exception getFailure() {
return failure.get();
}
}
private static class MergeResult {
private final List<SearchShard> processedShards;
private final TopDocs reducedTopDocs;
private final DelayableWriteable.Serialized<InternalAggregations> reducedAggs;
private MergeResult(List<SearchShard> processedShards, TopDocs reducedTopDocs,
DelayableWriteable.Serialized<InternalAggregations> reducedAggs) {
this.processedShards = processedShards;
this.reducedTopDocs = reducedTopDocs;
this.reducedAggs = reducedAggs;
}
}
private static class MergeTask {
private final List<SearchShard> emptyResults;
private QuerySearchResult[] buffer;
private Runnable next;
private MergeTask(QuerySearchResult[] buffer, List<SearchShard> emptyResults, Runnable next) {
this.buffer = buffer;
this.emptyResults = emptyResults;
this.next = next;
}
public synchronized QuerySearchResult[] consumeBuffer() {
QuerySearchResult[] toRet = buffer;
buffer = null;
return toRet;
}
public synchronized void consumeListener() {
if (next != null) {
next.run();
next = null;
}
}
public synchronized void cancel() {
consumeBuffer();
consumeListener();
}
}
}

View File

@ -33,10 +33,12 @@ import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;
import java.util.function.Consumer;
final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<DfsSearchResult> {
private final SearchPhaseController searchPhaseController;
private final Consumer<Exception> onPartialMergeFailure;
SearchDfsQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService,
final BiFunction<String, String, Transport.Connection> nodeIdToConnection,
@ -46,12 +48,14 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
final SearchRequest request, final ActionListener<SearchResponse> listener,
final GroupShardsIterator<SearchShardIterator> shardsIts,
final TransportSearchAction.SearchTimeProvider timeProvider,
final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters) {
final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters,
Consumer<Exception> onPartialMergeFailure) {
super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings,
executor, request, listener,
shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()),
request.getMaxConcurrentShardRequests(), clusters);
this.searchPhaseController = searchPhaseController;
this.onPartialMergeFailure = onPartialMergeFailure;
SearchProgressListener progressListener = task.getProgressListener();
SearchSourceBuilder sourceBuilder = request.source();
progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts),
@ -68,6 +72,6 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
@Override
protected SearchPhase getNextPhase(final SearchPhaseResults<DfsSearchResult> results, final SearchPhaseContext context) {
return new DfsQueryPhase(results.getAtomicArray(), searchPhaseController, (queryResults) ->
new FetchSearchPhase(queryResults, searchPhaseController, context, clusterState()), context);
new FetchSearchPhase(queryResults, searchPhaseController, context, clusterState()), context, onPartialMergeFailure);
}
}

View File

@ -22,8 +22,6 @@ package org.elasticsearch.action.search;
import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.ObjectObjectHashMap;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.FieldDoc;
@ -37,7 +35,6 @@ import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.TotalHits.Relation;
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
import org.elasticsearch.common.collect.HppcMaps;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.search.DocValueFormat;
@ -45,7 +42,6 @@ import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregations;
@ -63,18 +59,18 @@ import org.elasticsearch.search.suggest.Suggest.Suggestion;
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
public final class SearchPhaseController {
private static final Logger logger = LogManager.getLogger(SearchPhaseController.class);
private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];
private final NamedWriteableRegistry namedWriteableRegistry;
@ -151,77 +147,50 @@ public final class SearchPhaseController {
*
* @param ignoreFrom Whether to ignore the from and sort all hits in each shard result.
* Enabled only for scroll search, because that only retrieves hits of length 'size' in the query phase.
* @param results the search phase results to obtain the sort docs from
* @param bufferedTopDocs the pre-consumed buffered top docs
* @param topDocsStats the top docs stats to fill
* @param topDocs the buffered top docs
* @param from the offset into the search results top docs
* @param size the number of hits to return from the merged top docs
*/
static SortedTopDocs sortDocs(boolean ignoreFrom, Collection<? extends SearchPhaseResult> results,
final Collection<TopDocs> bufferedTopDocs, final TopDocsStats topDocsStats, int from, int size,
static SortedTopDocs sortDocs(boolean ignoreFrom, final Collection<TopDocs> topDocs, int from, int size,
List<CompletionSuggestion> reducedCompletionSuggestions) {
if (results.isEmpty()) {
if (topDocs.isEmpty() && reducedCompletionSuggestions.isEmpty()) {
return SortedTopDocs.EMPTY;
}
final Collection<TopDocs> topDocs = bufferedTopDocs == null ? new ArrayList<>() : bufferedTopDocs;
for (SearchPhaseResult sortedResult : results) { // TODO we can move this loop into the reduce call to only loop over this once
/* We loop over all results once, group together the completion suggestions if there are any and collect relevant
* top docs results. Each top docs gets it's shard index set on all top docs to simplify top docs merging down the road
* this allowed to remove a single shared optimization code here since now we don't materialized a dense array of
* top docs anymore but instead only pass relevant results / top docs to the merge method*/
QuerySearchResult queryResult = sortedResult.queryResult();
if (queryResult.hasConsumedTopDocs() == false) { // already consumed?
final TopDocsAndMaxScore td = queryResult.consumeTopDocs();
assert td != null;
topDocsStats.add(td, queryResult.searchTimedOut(), queryResult.terminatedEarly());
// make sure we set the shard index before we add it - the consumer didn't do that yet
if (td.topDocs.scoreDocs.length > 0) {
setShardIndex(td.topDocs, queryResult.getShardIndex());
topDocs.add(td.topDocs);
final TopDocs mergedTopDocs = mergeTopDocs(topDocs, size, ignoreFrom ? 0 : from);
final ScoreDoc[] mergedScoreDocs = mergedTopDocs == null ? EMPTY_DOCS : mergedTopDocs.scoreDocs;
ScoreDoc[] scoreDocs = mergedScoreDocs;
if (reducedCompletionSuggestions.isEmpty() == false) {
int numSuggestDocs = 0;
for (CompletionSuggestion completionSuggestion : reducedCompletionSuggestions) {
assert completionSuggestion != null;
numSuggestDocs += completionSuggestion.getOptions().size();
}
scoreDocs = new ScoreDoc[mergedScoreDocs.length + numSuggestDocs];
System.arraycopy(mergedScoreDocs, 0, scoreDocs, 0, mergedScoreDocs.length);
int offset = mergedScoreDocs.length;
for (CompletionSuggestion completionSuggestion : reducedCompletionSuggestions) {
for (CompletionSuggestion.Entry.Option option : completionSuggestion.getOptions()) {
scoreDocs[offset++] = option.getDoc();
}
}
}
final boolean hasHits = (reducedCompletionSuggestions.isEmpty() && topDocs.isEmpty()) == false;
if (hasHits) {
final TopDocs mergedTopDocs = mergeTopDocs(topDocs, size, ignoreFrom ? 0 : from);
final ScoreDoc[] mergedScoreDocs = mergedTopDocs == null ? EMPTY_DOCS : mergedTopDocs.scoreDocs;
ScoreDoc[] scoreDocs = mergedScoreDocs;
if (reducedCompletionSuggestions.isEmpty() == false) {
int numSuggestDocs = 0;
for (CompletionSuggestion completionSuggestion : reducedCompletionSuggestions) {
assert completionSuggestion != null;
numSuggestDocs += completionSuggestion.getOptions().size();
}
scoreDocs = new ScoreDoc[mergedScoreDocs.length + numSuggestDocs];
System.arraycopy(mergedScoreDocs, 0, scoreDocs, 0, mergedScoreDocs.length);
int offset = mergedScoreDocs.length;
for (CompletionSuggestion completionSuggestion : reducedCompletionSuggestions) {
for (CompletionSuggestion.Entry.Option option : completionSuggestion.getOptions()) {
scoreDocs[offset++] = option.getDoc();
}
}
boolean isSortedByField = false;
SortField[] sortFields = null;
String collapseField = null;
Object[] collapseValues = null;
if (mergedTopDocs instanceof TopFieldDocs) {
TopFieldDocs fieldDocs = (TopFieldDocs) mergedTopDocs;
sortFields = fieldDocs.fields;
if (fieldDocs instanceof CollapseTopFieldDocs) {
isSortedByField = (fieldDocs.fields.length == 1 && fieldDocs.fields[0].getType() == SortField.Type.SCORE) == false;
CollapseTopFieldDocs collapseTopFieldDocs = (CollapseTopFieldDocs) fieldDocs;
collapseField = collapseTopFieldDocs.field;
collapseValues = collapseTopFieldDocs.collapseValues;
} else {
isSortedByField = true;
}
boolean isSortedByField = false;
SortField[] sortFields = null;
String collapseField = null;
Object[] collapseValues = null;
if (mergedTopDocs instanceof TopFieldDocs) {
TopFieldDocs fieldDocs = (TopFieldDocs) mergedTopDocs;
sortFields = fieldDocs.fields;
if (fieldDocs instanceof CollapseTopFieldDocs) {
isSortedByField = (fieldDocs.fields.length == 1 && fieldDocs.fields[0].getType() == SortField.Type.SCORE) == false;
CollapseTopFieldDocs collapseTopFieldDocs = (CollapseTopFieldDocs) fieldDocs;
collapseField = collapseTopFieldDocs.field;
collapseValues = collapseTopFieldDocs.collapseValues;
} else {
isSortedByField = true;
}
}
return new SortedTopDocs(scoreDocs, isSortedByField, sortFields, collapseField, collapseValues);
} else {
// no relevant docs
return SortedTopDocs.EMPTY;
}
return new SortedTopDocs(scoreDocs, isSortedByField, sortFields, collapseField, collapseValues);
}
static TopDocs mergeTopDocs(Collection<TopDocs> results, int topN, int from) {
@ -251,7 +220,7 @@ public final class SearchPhaseController {
return mergedTopDocs;
}
private static void setShardIndex(TopDocs topDocs, int shardIndex) {
static void setShardIndex(TopDocs topDocs, int shardIndex) {
assert topDocs.scoreDocs.length == 0 || topDocs.scoreDocs[0].shardIndex == -1 : "shardIndex is already set";
for (ScoreDoc doc : topDocs.scoreDocs) {
doc.shardIndex = shardIndex;
@ -409,38 +378,38 @@ public final class SearchPhaseController {
throw new UnsupportedOperationException("Scroll requests don't have aggs");
}
};
return reducedQueryPhase(queryResults, true, SearchContext.TRACK_TOTAL_HITS_ACCURATE, aggReduceContextBuilder, true);
final TopDocsStats topDocsStats = new TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
final List<TopDocs> topDocs = new ArrayList<>();
for (SearchPhaseResult sortedResult : queryResults) {
QuerySearchResult queryResult = sortedResult.queryResult();
final TopDocsAndMaxScore td = queryResult.consumeTopDocs();
assert td != null;
topDocsStats.add(td, queryResult.searchTimedOut(), queryResult.terminatedEarly());
// make sure we set the shard index before we add it - the consumer didn't do that yet
if (td.topDocs.scoreDocs.length > 0) {
setShardIndex(td.topDocs, queryResult.getShardIndex());
topDocs.add(td.topDocs);
}
}
return reducedQueryPhase(queryResults, Collections.emptyList(), topDocs, topDocsStats,
0, true, aggReduceContextBuilder, true);
}
/**
* Reduces the given query results and consumes all aggregations and profile results.
* @param queryResults a list of non-null query shard results
*/
public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
boolean isScrollRequest, int trackTotalHitsUpTo,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce) {
return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(trackTotalHitsUpTo),
0, isScrollRequest, aggReduceContextBuilder, performFinalReduce);
}
/**
* Reduces the given query results and consumes all aggregations and profile results.
* @param queryResults a list of non-null query shard results
* @param bufferedAggs a list of pre-collected / buffered aggregations. if this list is non-null all aggregations have been consumed
* from all non-null query results.
* @param bufferedTopDocs a list of pre-collected / buffered top docs. if this list is non-null all top docs have been consumed
* from all non-null query results.
* @param bufferedAggs a list of pre-collected aggregations.
* @param bufferedTopDocs a list of pre-collected top docs.
* @param numReducePhases the number of non-final reduce phases applied to the query results.
* @see QuerySearchResult#consumeAggs()
* @see QuerySearchResult#consumeProfileResult()
*/
private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
List<DelayableWriteable<InternalAggregations>> bufferedAggs,
List<TopDocs> bufferedTopDocs,
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce) {
ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
List<InternalAggregations> bufferedAggs,
List<TopDocs> bufferedTopDocs,
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce) {
assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases;
numReducePhases++; // increment for this phase
if (queryResults.isEmpty()) { // early terminate we have nothing to reduce
@ -460,22 +429,6 @@ public final class SearchPhaseController {
final QuerySearchResult firstResult = queryResults.stream().findFirst().get().queryResult();
final boolean hasSuggest = firstResult.suggest() != null;
final boolean hasProfileResults = firstResult.hasProfileResults();
final boolean consumeAggs;
final List<DelayableWriteable<InternalAggregations>> aggregationsList;
if (bufferedAggs != null) {
consumeAggs = false;
// we already have results from intermediate reduces and just need to perform the final reduce
assert firstResult.hasAggs() : "firstResult has no aggs but we got non null buffered aggs?";
aggregationsList = bufferedAggs;
} else if (firstResult.hasAggs()) {
// the number of shards was less than the buffer size so we reduce agg results directly
aggregationsList = new ArrayList<>(queryResults.size());
consumeAggs = true;
} else {
// no aggregations
aggregationsList = Collections.emptyList();
consumeAggs = false;
}
// count the total (we use the query result provider here, since we might not get any hits (we scrolled past them))
final Map<String, List<Suggestion>> groupedSuggestions = hasSuggest ? new HashMap<>() : Collections.emptyMap();
@ -499,8 +452,8 @@ public final class SearchPhaseController {
}
}
}
if (consumeAggs) {
aggregationsList.add(result.consumeAggs());
if (bufferedTopDocs.isEmpty() == false) {
assert result.hasConsumedTopDocs() : "firstResult has no aggs but we got non null buffered aggs?";
}
if (hasProfileResults) {
String key = result.getSearchShardTarget().toString();
@ -516,31 +469,19 @@ public final class SearchPhaseController {
reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions));
reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class);
}
final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, aggregationsList);
final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, bufferedAggs);
final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size,
reducedCompletionSuggestions);
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, bufferedTopDocs, from, size, reducedCompletionSuggestions);
final TotalHits totalHits = topDocsStats.getTotalHits();
return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.getMaxScore(),
topDocsStats.timedOut, topDocsStats.terminatedEarly, reducedSuggest, aggregations, shardResults, sortedTopDocs,
firstResult.sortValueFormats(), numReducePhases, size, from, false);
}
private static InternalAggregations reduceAggs(
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce,
List<DelayableWriteable<InternalAggregations>> aggregationsList
) {
/*
* Parse the aggregations, clearing the list as we go so bits backing
* the DelayedWriteable can be collected immediately.
*/
List<InternalAggregations> toReduce = new ArrayList<>(aggregationsList.size());
for (int i = 0; i < aggregationsList.size(); i++) {
toReduce.add(aggregationsList.get(i).expand());
aggregationsList.set(i, null);
}
return aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(toReduce,
private static InternalAggregations reduceAggs(InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce,
List<InternalAggregations> toReduce) {
return toReduce.isEmpty() ? null : InternalAggregations.topLevelReduce(toReduce,
performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction());
}
@ -617,199 +558,23 @@ public final class SearchPhaseController {
}
}
/**
* A {@link ArraySearchPhaseResults} implementation
* that incrementally reduces aggregation results as shard results are consumed.
* This implementation can be configured to batch up a certain amount of results and only reduce them
* iff the buffer is exhausted.
*/
static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
private final NamedWriteableRegistry namedWriteableRegistry;
private final SearchShardTarget[] processedShards;
private final DelayableWriteable.Serialized<InternalAggregations>[] aggsBuffer;
private final TopDocs[] topDocsBuffer;
private final boolean hasAggs;
private final boolean hasTopDocs;
private final int bufferSize;
private int index;
private final SearchPhaseController controller;
private final SearchProgressListener progressListener;
private int numReducePhases = 0;
private final TopDocsStats topDocsStats;
private final int topNSize;
private final InternalAggregation.ReduceContextBuilder aggReduceContextBuilder;
private final boolean performFinalReduce;
private long aggsCurrentBufferSize;
private long aggsMaxBufferSize;
/**
* Creates a new {@link QueryPhaseResultConsumer}
* @param progressListener a progress listener to be notified when a successful response is received
* and when a partial or final reduce has completed.
* @param controller a controller instance to reduce the query response objects
* @param expectedResultSize the expected number of query results. Corresponds to the number of shards queried
* @param bufferSize the size of the reduce buffer. if the buffer size is smaller than the number of expected results
* the buffer is used to incrementally reduce aggregation results before all shards responded.
*/
private QueryPhaseResultConsumer(NamedWriteableRegistry namedWriteableRegistry, SearchProgressListener progressListener,
SearchPhaseController controller,
int expectedResultSize, int bufferSize, boolean hasTopDocs, boolean hasAggs,
int trackTotalHitsUpTo, int topNSize,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce) {
super(expectedResultSize);
this.namedWriteableRegistry = namedWriteableRegistry;
if (expectedResultSize != 1 && bufferSize < 2) {
throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result");
}
if (expectedResultSize <= bufferSize) {
throw new IllegalArgumentException("buffer size must be less than the expected result size");
}
if (hasAggs == false && hasTopDocs == false) {
throw new IllegalArgumentException("either aggs or top docs must be present");
}
this.controller = controller;
this.progressListener = progressListener;
this.processedShards = new SearchShardTarget[expectedResultSize];
// no need to buffer anything if we have less expected results. in this case we don't consume any results ahead of time.
@SuppressWarnings("unchecked")
DelayableWriteable.Serialized<InternalAggregations>[] aggsBuffer = new DelayableWriteable.Serialized[hasAggs ? bufferSize : 0];
this.aggsBuffer = aggsBuffer;
this.topDocsBuffer = new TopDocs[hasTopDocs ? bufferSize : 0];
this.hasTopDocs = hasTopDocs;
this.hasAggs = hasAggs;
this.bufferSize = bufferSize;
this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo);
this.topNSize = topNSize;
this.aggReduceContextBuilder = aggReduceContextBuilder;
this.performFinalReduce = performFinalReduce;
}
@Override
public void consumeResult(SearchPhaseResult result) {
super.consumeResult(result);
QuerySearchResult queryResult = result.queryResult();
consumeInternal(queryResult);
progressListener.notifyQueryResult(queryResult.getShardIndex());
}
private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
if (querySearchResult.isNull() == false) {
if (index == bufferSize) {
DelayableWriteable.Serialized<InternalAggregations> reducedAggs = null;
if (hasAggs) {
List<InternalAggregations> aggs = new ArrayList<>(aggsBuffer.length);
for (int i = 0; i < aggsBuffer.length; i++) {
aggs.add(aggsBuffer[i].expand());
aggsBuffer[i] = null; // null the buffer so it can be GCed now.
}
InternalAggregations reduced =
InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forPartialReduction());
reducedAggs = aggsBuffer[0] = DelayableWriteable.referencing(reduced)
.asSerialized(InternalAggregations::readFrom, namedWriteableRegistry);
long previousBufferSize = aggsCurrentBufferSize;
aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize);
aggsCurrentBufferSize = aggsBuffer[0].ramBytesUsed();
logger.trace("aggs partial reduction [{}->{}] max [{}]",
previousBufferSize, aggsCurrentBufferSize, aggsMaxBufferSize);
}
if (hasTopDocs) {
TopDocs reducedTopDocs = mergeTopDocs(Arrays.asList(topDocsBuffer),
// we have to merge here in the same way we collect on a shard
topNSize, 0);
Arrays.fill(topDocsBuffer, null);
topDocsBuffer[0] = reducedTopDocs;
}
numReducePhases++;
index = 1;
if (hasAggs || hasTopDocs) {
progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards),
topDocsStats.getTotalHits(), reducedAggs, numReducePhases);
}
}
final int i = index++;
if (hasAggs) {
aggsBuffer[i] = querySearchResult.consumeAggs().asSerialized(InternalAggregations::readFrom, namedWriteableRegistry);
aggsCurrentBufferSize += aggsBuffer[i].ramBytesUsed();
}
if (hasTopDocs) {
final TopDocsAndMaxScore topDocs = querySearchResult.consumeTopDocs(); // can't be null
topDocsStats.add(topDocs, querySearchResult.searchTimedOut(), querySearchResult.terminatedEarly());
setShardIndex(topDocs.topDocs, querySearchResult.getShardIndex());
topDocsBuffer[i] = topDocs.topDocs;
}
}
processedShards[querySearchResult.getShardIndex()] = querySearchResult.getSearchShardTarget();
}
private synchronized List<DelayableWriteable<InternalAggregations>> getRemainingAggs() {
return hasAggs ? Arrays.asList((DelayableWriteable<InternalAggregations>[]) aggsBuffer).subList(0, index) : null;
}
private synchronized List<TopDocs> getRemainingTopDocs() {
return hasTopDocs ? Arrays.asList(topDocsBuffer).subList(0, index) : null;
}
@Override
public ReducedQueryPhase reduce() {
aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize);
logger.trace("aggs final reduction [{}] max [{}]", aggsCurrentBufferSize, aggsMaxBufferSize);
ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(),
getRemainingAggs(), getRemainingTopDocs(), topDocsStats, numReducePhases, false,
aggReduceContextBuilder, performFinalReduce);
progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()),
reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
return reducePhase;
}
/**
* Returns the number of buffered results
*/
int getNumBuffered() {
return index;
}
int getNumReducePhases() { return numReducePhases; }
}
/**
* Returns a new ArraySearchPhaseResults instance. This might return an instance that reduces search responses incrementally.
*/
ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(SearchProgressListener listener,
ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(Executor executor,
SearchProgressListener listener,
SearchRequest request,
int numShards) {
int numShards,
Consumer<Exception> onPartialMergeFailure) {
SearchSourceBuilder source = request.source();
boolean isScrollRequest = request.scroll() != null;
final boolean hasAggs = source != null && source.aggregations() != null;
final boolean hasTopDocs = source == null || source.size() != 0;
final int trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo();
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder = requestToAggReduceContextBuilder.apply(request);
if (isScrollRequest == false && (hasAggs || hasTopDocs)) {
// no incremental reduce if scroll is used - we only hit a single shard or sometimes more...
if (request.getBatchedReduceSize() < numShards) {
int topNSize = getTopDocsSize(request);
// only use this if there are aggs and if there are more shards than we should reduce at once
return new QueryPhaseResultConsumer(namedWriteableRegistry, listener, this, numShards, request.getBatchedReduceSize(),
hasTopDocs, hasAggs, trackTotalHitsUpTo, topNSize, aggReduceContextBuilder, request.isFinalReduce());
}
}
return new ArraySearchPhaseResults<SearchPhaseResult>(numShards) {
@Override
void consumeResult(SearchPhaseResult result) {
super.consumeResult(result);
listener.notifyQueryResult(result.queryResult().getShardIndex());
}
@Override
ReducedQueryPhase reduce() {
List<SearchPhaseResult> resultList = results.asList();
final ReducedQueryPhase reducePhase =
reducedQueryPhase(resultList, isScrollRequest, trackTotalHitsUpTo, aggReduceContextBuilder, request.isFinalReduce());
listener.notifyFinalReduce(SearchProgressListener.buildSearchShards(resultList),
reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
return reducePhase;
}
};
int topNSize = getTopDocsSize(request);
int bufferSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), numShards) : numShards;
return new QueryPhaseResultConsumer(executor, this, listener, aggReduceContextBuilder, namedWriteableRegistry,
numShards, bufferSize, hasTopDocs, hasAggs, trackTotalHitsUpTo, topNSize, request.isFinalReduce(), onPartialMergeFailure);
}
static final class TopDocsStats {

View File

@ -48,8 +48,9 @@ abstract class SearchPhaseResults<Result extends SearchPhaseResult> {
/**
* Consumes a single shard result
* @param result the shards result
* @param next a {@link Runnable} that is executed when the response has been fully consumed
*/
abstract void consumeResult(Result result);
abstract void consumeResult(Result result, Runnable next);
/**
* Returns <code>true</code> iff a result if present for the given shard ID.
@ -65,7 +66,7 @@ abstract class SearchPhaseResults<Result extends SearchPhaseResult> {
/**
* Reduces the collected results
*/
SearchPhaseController.ReducedQueryPhase reduce() {
SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
throw new UnsupportedOperationException("reduce is not supported");
}
}

View File

@ -30,7 +30,6 @@ import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregations;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
@ -174,7 +173,7 @@ abstract class SearchProgressListener {
}
static List<SearchShard> buildSearchShards(List<? extends SearchPhaseResult> results) {
List<SearchShard> lst = results.stream()
List<SearchShard> lst = results.stream()
.filter(Objects::nonNull)
.map(SearchPhaseResult::getSearchShardTarget)
.map(e -> new SearchShard(e.getClusterAlias(), e.getShardId()))
@ -182,16 +181,8 @@ abstract class SearchProgressListener {
return Collections.unmodifiableList(lst);
}
static List<SearchShard> buildSearchShards(SearchShardTarget[] results) {
List<SearchShard> lst = Arrays.stream(results)
.filter(Objects::nonNull)
.map(e -> new SearchShard(e.getClusterAlias(), e.getShardId()))
.collect(Collectors.toList());
return Collections.unmodifiableList(lst);
}
static List<SearchShard> buildSearchShards(GroupShardsIterator<SearchShardIterator> its) {
List<SearchShard> lst = StreamSupport.stream(its.spliterator(), false)
List<SearchShard> lst = StreamSupport.stream(its.spliterator(), false)
.map(e -> new SearchShard(e.getClusterAlias(), e.shardId()))
.collect(Collectors.toList());
return Collections.unmodifiableList(lst);

View File

@ -38,6 +38,7 @@ import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize;
@ -59,11 +60,12 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
final SearchRequest request, final ActionListener<SearchResponse> listener,
final GroupShardsIterator<SearchShardIterator> shardsIts,
final TransportSearchAction.SearchTimeProvider timeProvider,
ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters) {
ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters,
Consumer<Exception> onPartialMergeFailure) {
super("query", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings,
executor, request, listener, shardsIts, timeProvider, clusterState, task,
searchPhaseController.newSearchPhaseResults(task.getProgressListener(), request, shardsIts.size()),
request.getMaxConcurrentShardRequests(), clusters);
searchPhaseController.newSearchPhaseResults(executor, task.getProgressListener(),
request, shardsIts.size(), onPartialMergeFailure), request.getMaxConcurrentShardRequests(), clusters);
this.topDocsSize = getTopDocsSize(request);
this.trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo();
this.searchPhaseController = searchPhaseController;

View File

@ -21,6 +21,7 @@ package org.elasticsearch.action.search;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
@ -28,6 +29,8 @@ import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
@ -57,6 +60,7 @@ import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.ProfileShardResult;
import org.elasticsearch.search.profile.SearchProfileShardResults;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterAware;
import org.elasticsearch.transport.RemoteClusterService;
@ -81,6 +85,7 @@ import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.LongSupplier;
import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
import static org.elasticsearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
import static org.elasticsearch.action.search.SearchType.QUERY_THEN_FETCH;
import static org.elasticsearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort;
@ -91,6 +96,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
public static final Setting<Long> SHARD_COUNT_LIMIT_SETTING = Setting.longSetting(
"action.search.shard_count.limit", Long.MAX_VALUE, 1L, Property.Dynamic, Property.NodeScope);
private final NodeClient client;
private final ThreadPool threadPool;
private final ClusterService clusterService;
private final SearchTransportService searchTransportService;
@ -100,11 +106,12 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
private final IndexNameExpressionResolver indexNameExpressionResolver;
@Inject
public TransportSearchAction(ThreadPool threadPool, TransportService transportService, SearchService searchService,
SearchTransportService searchTransportService, SearchPhaseController searchPhaseController,
ClusterService clusterService, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver) {
public TransportSearchAction(NodeClient client, ThreadPool threadPool, TransportService transportService,
SearchService searchService, SearchTransportService searchTransportService,
SearchPhaseController searchPhaseController, ClusterService clusterService,
ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver) {
super(SearchAction.NAME, transportService, actionFilters, (Writeable.Reader<SearchRequest>) SearchRequest::new);
this.client = client;
this.threadPool = threadPool;
this.searchPhaseController = searchPhaseController;
this.searchTransportService = searchTransportService;
@ -618,12 +625,12 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
case DFS_QUERY_THEN_FETCH:
searchAsyncAction = new SearchDfsQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup,
aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, searchRequest, listener,
shardIterators, timeProvider, clusterState, task, clusters);
shardIterators, timeProvider, clusterState, task, clusters, exc -> cancelTask(task, exc));
break;
case QUERY_THEN_FETCH:
searchAsyncAction = new SearchQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup,
aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, searchRequest, listener,
shardIterators, timeProvider, clusterState, task, clusters);
shardIterators, timeProvider, clusterState, task, clusters, exc -> cancelTask(task, exc));
break;
default:
throw new IllegalStateException("Unknown search type: [" + searchRequest.searchType() + "]");
@ -632,6 +639,15 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
}
}
private void cancelTask(SearchTask task, Exception exc) {
String errorMsg = exc.getMessage() != null ? exc.getMessage() : "";
CancelTasksRequest req = new CancelTasksRequest()
.setTaskId(new TaskId(client.getLocalNodeId(), task.getId()))
.setReason("Fatal failure during search: " + errorMsg);
// force the origin to execute the cancellation as a system user
new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {}));
}
private static void failIfOverShardCountLimit(ClusterService clusterService, int shardCount) {
final long shardCountLimit = clusterService.getClusterSettings().get(SHARD_COUNT_LIMIT_SETTING);
if (shardCount > shardCountLimit) {

View File

@ -236,6 +236,18 @@ public final class QuerySearchResult extends SearchPhaseResult {
return hasProfileResults;
}
public void consumeAll() {
if (hasProfileResults()) {
consumeProfileResult();
}
if (hasConsumedTopDocs() == false) {
consumeTopDocs();
}
if (hasAggs()) {
consumeAggs();
}
}
/**
* Sets the finalized profiling results for this query
* @param shardResults The finalized profile

View File

@ -65,5 +65,9 @@ public class SuggestPhase implements SearchPhase {
throw new ElasticsearchException("I/O exception during suggest phase", e);
}
}
static class SortedHits {
int[] docs;
}
}

View File

@ -282,7 +282,7 @@ public class AbstractSearchAsyncActionTests extends ESTestCase {
nodeLookups.add(Tuple.tuple(resultClusterAlias, resultNodeId));
phaseResult.setSearchShardTarget(new SearchShardTarget(resultNodeId, resultShardId, resultClusterAlias, OriginalIndices.NONE));
phaseResult.setShardIndex(i);
phaseResults.consumeResult(phaseResult);
phaseResults.consumeResult(phaseResult, () -> {});
}
return phaseResults;
}

View File

@ -35,10 +35,10 @@ import java.util.concurrent.Executor;
public class CountedCollectorTests extends ESTestCase {
public void testCollect() throws InterruptedException {
AtomicArray<SearchPhaseResult> results = new AtomicArray<>(randomIntBetween(1, 100));
ArraySearchPhaseResults<SearchPhaseResult> consumer = new ArraySearchPhaseResults<>(randomIntBetween(1, 100));
List<Integer> state = new ArrayList<>();
int numResultsExpected = randomIntBetween(1, results.length());
MockSearchPhaseContext context = new MockSearchPhaseContext(results.length());
int numResultsExpected = randomIntBetween(1, consumer.getAtomicArray().length());
MockSearchPhaseContext context = new MockSearchPhaseContext(consumer.getAtomicArray().length());
CountDownLatch latch = new CountDownLatch(1);
boolean maybeFork = randomBoolean();
Executor executor = (runnable) -> {
@ -49,7 +49,7 @@ public class CountedCollectorTests extends ESTestCase {
runnable.run();
}
};
CountedCollector<SearchPhaseResult> collector = new CountedCollector<>(r -> results.set(r.getShardIndex(), r), numResultsExpected,
CountedCollector<SearchPhaseResult> collector = new CountedCollector<>(consumer, numResultsExpected,
latch::countDown, context);
for (int i = 0; i < numResultsExpected; i++) {
int shardID = i;
@ -78,7 +78,7 @@ public class CountedCollectorTests extends ESTestCase {
}
latch.await();
assertEquals(numResultsExpected, state.size());
AtomicArray<SearchPhaseResult> results = consumer.getAtomicArray();
for (int i = 0; i < numResultsExpected; i++) {
switch (state.get(i)) {
case 0:

View File

@ -95,7 +95,7 @@ public class DfsQueryPhaseTests extends ESTestCase {
public void run() throws IOException {
responseRef.set(response.results);
}
}, mockSearchPhaseContext);
}, mockSearchPhaseContext, exc -> {});
assertEquals("dfs_query", phase.getName());
phase.run();
mockSearchPhaseContext.assertNoFailure();
@ -149,7 +149,7 @@ public class DfsQueryPhaseTests extends ESTestCase {
public void run() throws IOException {
responseRef.set(response.results);
}
}, mockSearchPhaseContext);
}, mockSearchPhaseContext, exc -> {});
assertEquals("dfs_query", phase.getName());
phase.run();
mockSearchPhaseContext.assertNoFailure();
@ -206,7 +206,7 @@ public class DfsQueryPhaseTests extends ESTestCase {
public void run() throws IOException {
responseRef.set(response.results);
}
}, mockSearchPhaseContext);
}, mockSearchPhaseContext, exc -> {});
assertEquals("dfs_query", phase.getName());
expectThrows(UncheckedIOException.class, phase::run);
assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); // phase execution will clean up on the contexts

View File

@ -26,6 +26,7 @@ import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchHit;
@ -52,7 +53,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 1);
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 1, exc -> {});
boolean hasHits = randomBoolean();
final int numHits;
if (hasHits) {
@ -66,7 +68,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
fetchResult.hits(new SearchHits(new SearchHit[] {new SearchHit(42)}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F));
QueryFetchSearchResult fetchSearchResult = new QueryFetchSearchResult(queryResult, fetchResult);
fetchSearchResult.setShardIndex(0);
results.consumeResult(fetchSearchResult);
results.consumeResult(fetchSearchResult, () -> {});
numHits = 1;
} else {
numHits = 0;
@ -95,7 +97,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10);
final SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);
QuerySearchResult queryResult = new QuerySearchResult(ctx1,
@ -104,7 +107,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
new ScoreDoc[] {new ScoreDoc(42, 1.0F)}), 2.0F), new DocValueFormat[0]);
queryResult.size(resultSetSize); // the size of the result set
queryResult.setShardIndex(0);
results.consumeResult(queryResult);
results.consumeResult(queryResult, () -> {});
final SearchContextId ctx2 = new SearchContextId(UUIDs.randomBase64UUID(), 312);
queryResult = new QuerySearchResult(ctx2,
@ -113,7 +116,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
new ScoreDoc[] {new ScoreDoc(84, 2.0F)}), 2.0F), new DocValueFormat[0]);
queryResult.size(resultSetSize);
queryResult.setShardIndex(1);
results.consumeResult(queryResult);
results.consumeResult(queryResult, () -> {});
mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) {
@Override
@ -155,8 +158,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
ArraySearchPhaseResults<SearchPhaseResult> results =
controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10);
SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);
QuerySearchResult queryResult = new QuerySearchResult(ctx1,
@ -165,7 +168,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
new ScoreDoc[] {new ScoreDoc(42, 1.0F)}), 2.0F), new DocValueFormat[0]);
queryResult.size(resultSetSize); // the size of the result set
queryResult.setShardIndex(0);
results.consumeResult(queryResult);
results.consumeResult(queryResult, () -> {});
SearchContextId ctx2 = new SearchContextId(UUIDs.randomBase64UUID(), 321);
queryResult = new QuerySearchResult(ctx2,
@ -174,7 +177,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
new ScoreDoc[] {new ScoreDoc(84, 2.0F)}), 2.0F), new DocValueFormat[0]);
queryResult.size(resultSetSize);
queryResult.setShardIndex(1);
results.consumeResult(queryResult);
results.consumeResult(queryResult, () -> {});
mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) {
@Override
@ -220,8 +223,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(numHits);
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(NOOP,
mockSearchPhaseContext.getRequest(), numHits);
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), NOOP,
mockSearchPhaseContext.getRequest(), numHits, exc -> {});
for (int i = 0; i < numHits; i++) {
QuerySearchResult queryResult = new QuerySearchResult(new SearchContextId("", i),
new SearchShardTarget("node1", new ShardId("test", "na", 0), null, OriginalIndices.NONE));
@ -229,7 +232,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
new ScoreDoc[] {new ScoreDoc(i+1, i)}), i), new DocValueFormat[0]);
queryResult.size(resultSetSize); // the size of the result set
queryResult.setShardIndex(i);
results.consumeResult(queryResult);
results.consumeResult(queryResult, () -> {});
}
mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) {
@Override
@ -278,7 +281,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
ArraySearchPhaseResults<SearchPhaseResult> results =
controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10);
QuerySearchResult queryResult = new QuerySearchResult(new SearchContextId("", 123),
new SearchShardTarget("node1", new ShardId("test", "na", 0), null, OriginalIndices.NONE));
@ -286,7 +290,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
new ScoreDoc[] {new ScoreDoc(42, 1.0F)}), 2.0F), new DocValueFormat[0]);
queryResult.size(resultSetSize); // the size of the result set
queryResult.setShardIndex(0);
results.consumeResult(queryResult);
results.consumeResult(queryResult, () -> {});
queryResult = new QuerySearchResult(new SearchContextId("", 321),
new SearchShardTarget("node2", new ShardId("test", "na", 1), null, OriginalIndices.NONE));
@ -294,7 +298,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
new ScoreDoc[] {new ScoreDoc(84, 2.0F)}), 2.0F), new DocValueFormat[0]);
queryResult.size(resultSetSize);
queryResult.setShardIndex(1);
results.consumeResult(queryResult);
results.consumeResult(queryResult, () -> {});
AtomicInteger numFetches = new AtomicInteger(0);
mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) {
@Override
@ -334,8 +338,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
ArraySearchPhaseResults<SearchPhaseResult> results =
controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = 1;
SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);
QuerySearchResult queryResult = new QuerySearchResult(ctx1,
@ -344,7 +348,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
new ScoreDoc[] {new ScoreDoc(42, 1.0F)}), 2.0F), new DocValueFormat[0]);
queryResult.size(resultSetSize); // the size of the result set
queryResult.setShardIndex(0);
results.consumeResult(queryResult);
results.consumeResult(queryResult, () -> {});
SearchContextId ctx2 = new SearchContextId(UUIDs.randomBase64UUID(), 321);
queryResult = new QuerySearchResult(ctx2,
@ -353,7 +357,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
new ScoreDoc[] {new ScoreDoc(84, 2.0F)}), 2.0F), new DocValueFormat[0]);
queryResult.size(resultSetSize);
queryResult.setShardIndex(1);
results.consumeResult(queryResult);
results.consumeResult(queryResult, () -> {});
mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) {
@Override

View File

@ -35,12 +35,17 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchHit;
@ -67,16 +72,21 @@ import org.elasticsearch.search.suggest.phrase.PhraseSuggestion;
import org.elasticsearch.search.suggest.term.TermSuggestion;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
@ -86,20 +96,25 @@ import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;
import static org.elasticsearch.action.search.SearchProgressListener.NOOP;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;
public class SearchPhaseControllerTests extends ESTestCase {
private ThreadPool threadPool;
private EsThreadPoolExecutor fixedExecutor;
private SearchPhaseController searchPhaseController;
private List<Boolean> reductions;
@Override
protected NamedWriteableRegistry writableRegistry() {
return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables());
List<NamedWriteableRegistry.Entry> entries =
new ArrayList<>(new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables());
entries.add(new NamedWriteableRegistry.Entry(InternalAggregation.class, "throwing", InternalThrowing::new));
return new NamedWriteableRegistry(entries);
}
@Before
@ -119,6 +134,15 @@ public class SearchPhaseControllerTests extends ESTestCase {
BigArrays.NON_RECYCLING_INSTANCE, null, b -> {}, PipelineTree.EMPTY);
};
});
threadPool = new TestThreadPool(SearchPhaseControllerTests.class.getName());
fixedExecutor = EsExecutors.newFixed("test", 1, 10,
EsExecutors.daemonThreadFactory("test"), threadPool.getThreadContext());
}
@After
public void cleanup() {
fixedExecutor.shutdownNow();
terminate(threadPool);
}
public void testSortDocs() {
@ -141,9 +165,15 @@ public class SearchPhaseControllerTests extends ESTestCase {
int suggestionSize = suggestion.getEntries().get(0).getOptions().size();
accumulatedLength += suggestionSize;
}
ScoreDoc[] sortedDocs = SearchPhaseController.sortDocs(true, results.asList(), null,
new SearchPhaseController.TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE), from, size,
reducedCompletionSuggestions).scoreDocs;
List<TopDocs> topDocsList = new ArrayList<>();
for (SearchPhaseResult result : results.asList()) {
QuerySearchResult queryResult = result.queryResult();
TopDocs topDocs = queryResult.consumeTopDocs().topDocs;
SearchPhaseController.setShardIndex(topDocs, result.getShardIndex());
topDocsList.add(topDocs);
}
ScoreDoc[] sortedDocs = SearchPhaseController.sortDocs(true, topDocsList,
from, size, reducedCompletionSuggestions).scoreDocs;
assertThat(sortedDocs.length, equalTo(accumulatedLength));
}
@ -161,14 +191,26 @@ public class SearchPhaseControllerTests extends ESTestCase {
from = first.get().queryResult().from();
size = first.get().queryResult().size();
}
SearchPhaseController.TopDocsStats topDocsStats = new SearchPhaseController.TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
ScoreDoc[] sortedDocs = SearchPhaseController.sortDocs(ignoreFrom, results.asList(), null, topDocsStats, from, size,
List<TopDocs> topDocsList = new ArrayList<>();
for (SearchPhaseResult result : results.asList()) {
QuerySearchResult queryResult = result.queryResult();
TopDocs topDocs = queryResult.consumeTopDocs().topDocs;
topDocsList.add(topDocs);
SearchPhaseController.setShardIndex(topDocs, result.getShardIndex());
}
ScoreDoc[] sortedDocs = SearchPhaseController.sortDocs(ignoreFrom, topDocsList, from, size,
Collections.emptyList()).scoreDocs;
results = generateSeededQueryResults(randomSeed, nShards, Collections.emptyList(), queryResultSize,
useConstantScore);
SearchPhaseController.TopDocsStats topDocsStats2 = new SearchPhaseController.TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
ScoreDoc[] sortedDocs2 = SearchPhaseController.sortDocs(ignoreFrom, results.asList(), null, topDocsStats2, from, size,
topDocsList = new ArrayList<>();
for (SearchPhaseResult result : results.asList()) {
QuerySearchResult queryResult = result.queryResult();
TopDocs topDocs = queryResult.consumeTopDocs().topDocs;
topDocsList.add(topDocs);
SearchPhaseController.setShardIndex(topDocs, result.getShardIndex());
}
ScoreDoc[] sortedDocs2 = SearchPhaseController.sortDocs(ignoreFrom, topDocsList, from, size,
Collections.emptyList()).scoreDocs;
assertEquals(sortedDocs.length, sortedDocs2.length);
for (int i = 0; i < sortedDocs.length; i++) {
@ -176,10 +218,6 @@ public class SearchPhaseControllerTests extends ESTestCase {
assertEquals(sortedDocs[i].shardIndex, sortedDocs2[i].shardIndex);
assertEquals(sortedDocs[i].score, sortedDocs2[i].score, 0.0f);
}
assertEquals(topDocsStats.getMaxScore(), topDocsStats2.getMaxScore(), 0.0f);
assertEquals(topDocsStats.getTotalHits().value, topDocsStats2.getTotalHits().value);
assertEquals(topDocsStats.getTotalHits().relation, topDocsStats2.getTotalHits().relation);
assertEquals(topDocsStats.fetchHits, topDocsStats2.fetchHits);
}
private AtomicArray<SearchPhaseResult> generateSeededQueryResults(long seed, int nShards,
@ -200,9 +238,10 @@ public class SearchPhaseControllerTests extends ESTestCase {
int nShards = randomIntBetween(1, 20);
int queryResultSize = randomBoolean() ? 0 : randomIntBetween(1, nShards * 2);
AtomicArray<SearchPhaseResult> queryResults = generateQueryResults(nShards, suggestions, queryResultSize, false);
for (int trackTotalHits : new int[] {SearchContext.TRACK_TOTAL_HITS_DISABLED, SearchContext.TRACK_TOTAL_HITS_ACCURATE}) {
SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(
queryResults.asList(), false, trackTotalHits, InternalAggregationTestCase.emptyReduceContextBuilder(), true);
for (int trackTotalHits : new int[] { SearchContext.TRACK_TOTAL_HITS_DISABLED, SearchContext.TRACK_TOTAL_HITS_ACCURATE }) {
SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(queryResults.asList(),
new ArrayList<>(), new ArrayList<>(), new SearchPhaseController.TopDocsStats(trackTotalHits),
0, true, InternalAggregationTestCase.emptyReduceContextBuilder(), true);
AtomicArray<SearchPhaseResult> fetchResults = generateFetchResults(nShards,
reducedQueryPhase.sortedTopDocs.scoreDocs, reducedQueryPhase.suggest);
InternalSearchResponse mergedResponse = searchPhaseController.merge(false,
@ -246,9 +285,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
* Generate random query results received from the provided number of shards, including the provided
* number of search hits and randomly generated completion suggestions based on the name and size of the provided ones.
* Note that <code>shardIndex</code> is already set to the generated completion suggestions to simulate what
* {@link SearchPhaseController#reducedQueryPhase(Collection, boolean, int, InternalAggregation.ReduceContextBuilder, boolean)} does,
* meaning that the returned query results can be fed directly to
* {@link SearchPhaseController#sortDocs(boolean, Collection, Collection, SearchPhaseController.TopDocsStats, int, int, List)}
* {@link SearchPhaseController#reducedQueryPhase} does,
* meaning that the returned query results can be fed directly to {@link SearchPhaseController#sortDocs}
*/
private static AtomicArray<SearchPhaseResult> generateQueryResults(int nShards, List<CompletionSuggestion> suggestions,
int searchHitsSize, boolean useConstantScore) {
@ -364,22 +402,24 @@ public class SearchPhaseControllerTests extends ESTestCase {
Strings.EMPTY_ARRAY, "remote", 0, randomBoolean());
}
public void testConsumer() {
public void testConsumer() throws Exception {
consumerTestCase(0);
}
public void testConsumerWithEmptyResponse() {
public void testConsumerWithEmptyResponse() throws Exception {
consumerTestCase(randomIntBetween(1, 5));
}
private void consumerTestCase(int numEmptyResponses) {
private void consumerTestCase(int numEmptyResponses) throws Exception {
long beforeCompletedTasks = fixedExecutor.getCompletedTaskCount();
int numShards = 3 + numEmptyResponses;
int bufferSize = randomIntBetween(2, 3);
CountDownLatch latch = new CountDownLatch(numShards);
SearchRequest request = randomSearchRequest();
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")));
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer =
searchPhaseController.newSearchPhaseResults(NOOP, request, 3+numEmptyResponses);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, 3+numEmptyResponses, exc -> {});
if (numEmptyResponses == 0) {
assertEquals(0, reductions.size());
}
@ -388,7 +428,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
int shardId = 2 + numEmptyResponses;
empty.setShardIndex(2+numEmptyResponses);
empty.setSearchShardTarget(new SearchShardTarget("node", new ShardId("a", "b", shardId), null, OriginalIndices.NONE));
consumer.consumeResult(empty);
consumer.consumeResult(empty, latch::countDown);
numEmptyResponses --;
}
@ -399,7 +439,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
InternalAggregations aggs = InternalAggregations.from(singletonList(new InternalMax("test", 1.0D, DocValueFormat.RAW, emptyMap())));
result.aggregations(aggs);
result.setShardIndex(0);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
result = new QuerySearchResult(new SearchContextId(UUIDs.randomBase64UUID(), 1),
new SearchShardTarget("node", new ShardId("a", "b", 0), null, OriginalIndices.NONE));
@ -408,7 +448,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
aggs = InternalAggregations.from(singletonList(new InternalMax("test", 3.0D, DocValueFormat.RAW, emptyMap())));
result.aggregations(aggs);
result.setShardIndex(2);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
result = new QuerySearchResult(new SearchContextId(UUIDs.randomBase64UUID(), 1),
new SearchShardTarget("node", new ShardId("a", "b", 0), null, OriginalIndices.NONE));
@ -417,40 +457,38 @@ public class SearchPhaseControllerTests extends ESTestCase {
aggs = InternalAggregations.from(singletonList(new InternalMax("test", 2.0D, DocValueFormat.RAW, emptyMap())));
result.aggregations(aggs);
result.setShardIndex(1);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
while (numEmptyResponses > 0) {
result = QuerySearchResult.nullInstance();
int shardId = 2 + numEmptyResponses;
result.setShardIndex(shardId);
result.setSearchShardTarget(new SearchShardTarget("node", new ShardId("a", "b", shardId), null, OriginalIndices.NONE));
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
numEmptyResponses--;
}
latch.await();
final int numTotalReducePhases;
if (numShards > bufferSize) {
assertThat(consumer, instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class));
if (bufferSize == 2) {
assertEquals(1, ((SearchPhaseController.QueryPhaseResultConsumer) consumer).getNumReducePhases());
assertEquals(2, ((SearchPhaseController.QueryPhaseResultConsumer) consumer).getNumBuffered());
assertEquals(1, ((QueryPhaseResultConsumer) consumer).getNumReducePhases());
assertEquals(1, reductions.size());
assertEquals(false, reductions.get(0));
numTotalReducePhases = 2;
} else {
assertEquals(0, ((SearchPhaseController.QueryPhaseResultConsumer) consumer).getNumReducePhases());
assertEquals(3, ((SearchPhaseController.QueryPhaseResultConsumer) consumer).getNumBuffered());
assertEquals(0, ((QueryPhaseResultConsumer) consumer).getNumReducePhases());
assertEquals(0, reductions.size());
numTotalReducePhases = 1;
}
} else {
assertThat(consumer, not(instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class)));
assertEquals(0, reductions.size());
numTotalReducePhases = 1;
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
int numCompletedTasks = (int) (fixedExecutor.getCompletedTaskCount() - beforeCompletedTasks);
assertEquals(numCompletedTasks, reduce.numReducePhases-1);
assertEquals(numTotalReducePhases, reduce.numReducePhases);
assertEquals(numTotalReducePhases, reductions.size());
assertAggReduction(request);
@ -462,17 +500,18 @@ public class SearchPhaseControllerTests extends ESTestCase {
assertNull(reduce.sortedTopDocs.collapseValues);
}
public void testConsumerConcurrently() throws InterruptedException {
public void testConsumerConcurrently() throws Exception {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = randomSearchRequest();
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")));
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer =
searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
Thread[] threads = new Thread[expectedNumResults];
CountDownLatch latch = new CountDownLatch(expectedNumResults);
for (int i = 0; i < expectedNumResults; i++) {
int id = i;
threads[i] = new Thread(() -> {
@ -488,7 +527,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
result.aggregations(aggs);
result.setShardIndex(id);
result.size(1);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
});
threads[i].start();
@ -496,6 +535,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
for (int i = 0; i < expectedNumResults; i++) {
threads[i].join();
}
latch.await();
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertAggReduction(request);
InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
@ -510,15 +551,16 @@ public class SearchPhaseControllerTests extends ESTestCase {
assertNull(reduce.sortedTopDocs.collapseValues);
}
public void testConsumerOnlyAggs() {
public void testConsumerOnlyAggs() throws Exception {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = randomSearchRequest();
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0));
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer =
searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(expectedNumResults);
for (int i = 0; i < expectedNumResults; i++) {
int number = randomIntBetween(1, 1000);
max.updateAndGet(prev -> Math.max(prev, number));
@ -531,8 +573,10 @@ public class SearchPhaseControllerTests extends ESTestCase {
result.aggregations(aggs);
result.setShardIndex(i);
result.size(1);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
}
latch.await();
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertAggReduction(request);
InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
@ -546,7 +590,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
assertNull(reduce.sortedTopDocs.collapseValues);
}
public void testConsumerOnlyHits() {
public void testConsumerOnlyHits() throws Exception {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = randomSearchRequest();
@ -554,9 +598,10 @@ public class SearchPhaseControllerTests extends ESTestCase {
request.source(new SearchSourceBuilder().size(randomIntBetween(1, 10)));
}
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer =
searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(expectedNumResults);
for (int i = 0; i < expectedNumResults; i++) {
int number = randomIntBetween(1, 1000);
max.updateAndGet(prev -> Math.max(prev, number));
@ -566,8 +611,9 @@ public class SearchPhaseControllerTests extends ESTestCase {
new ScoreDoc[] {new ScoreDoc(0, number)}), number), new DocValueFormat[0]);
result.setShardIndex(i);
result.size(1);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
}
latch.await();
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertAggReduction(request);
assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);
@ -591,47 +637,14 @@ public class SearchPhaseControllerTests extends ESTestCase {
}
}
public void testNewSearchPhaseResults() {
for (int i = 0; i < 10; i++) {
int expectedNumResults = randomIntBetween(1, 10);
int bufferSize = randomIntBetween(2, 10);
SearchRequest request = new SearchRequest();
final boolean hasAggs;
if ((hasAggs = randomBoolean())) {
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")));
}
final boolean hasTopDocs;
if ((hasTopDocs = randomBoolean())) {
if (request.source() != null) {
request.source().size(randomIntBetween(1, 100));
} // no source means size = 10
} else {
if (request.source() == null) {
request.source(new SearchSourceBuilder().size(0));
} else {
request.source().size(0);
}
}
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer
= searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults);
if ((hasAggs || hasTopDocs) && expectedNumResults > bufferSize) {
assertThat("expectedNumResults: " + expectedNumResults + " bufferSize: " + bufferSize,
consumer, instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class));
} else {
assertThat("expectedNumResults: " + expectedNumResults + " bufferSize: " + bufferSize,
consumer, not(instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class)));
}
}
}
public void testReduceTopNWithFromOffset() {
public void testReduceTopNWithFromOffset() throws Exception {
SearchRequest request = new SearchRequest();
request.source(new SearchSourceBuilder().size(5).from(5));
request.setBatchedReduceSize(randomIntBetween(2, 4));
ArraySearchPhaseResults<SearchPhaseResult> consumer =
searchPhaseController.newSearchPhaseResults(NOOP, request, 4);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, 4, exc -> {});
int score = 100;
CountDownLatch latch = new CountDownLatch(4);
for (int i = 0; i < 4; i++) {
QuerySearchResult result = new QuerySearchResult(new SearchContextId(UUIDs.randomBase64UUID(), i),
new SearchShardTarget("node", new ShardId("a", "b", i), null, OriginalIndices.NONE));
@ -644,8 +657,9 @@ public class SearchPhaseControllerTests extends ESTestCase {
result.setShardIndex(i);
result.size(5);
result.from(5);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
}
latch.await();
// 4*3 results = 12 we get result 5 to 10 here with from=5 and size=5
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
ScoreDoc[] scoreDocs = reduce.sortedTopDocs.scoreDocs;
@ -659,17 +673,18 @@ public class SearchPhaseControllerTests extends ESTestCase {
assertEquals(91.0f, scoreDocs[4].score, 0.0f);
}
public void testConsumerSortByField() {
public void testConsumerSortByField() throws Exception {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = randomSearchRequest();
int size = randomIntBetween(1, 10);
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer =
searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
SortField[] sortFields = {new SortField("field", SortField.Type.INT, true)};
DocValueFormat[] docValueFormats = {DocValueFormat.RAW};
CountDownLatch latch = new CountDownLatch(expectedNumResults);
for (int i = 0; i < expectedNumResults; i++) {
int number = randomIntBetween(1, 1000);
max.updateAndGet(prev -> Math.max(prev, number));
@ -680,8 +695,9 @@ public class SearchPhaseControllerTests extends ESTestCase {
result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats);
result.setShardIndex(i);
result.size(size);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
}
latch.await();
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertAggReduction(request);
assertEquals(Math.min(expectedNumResults, size), reduce.sortedTopDocs.scoreDocs.length);
@ -695,20 +711,21 @@ public class SearchPhaseControllerTests extends ESTestCase {
assertNull(reduce.sortedTopDocs.collapseValues);
}
public void testConsumerFieldCollapsing() {
public void testConsumerFieldCollapsing() throws Exception {
int expectedNumResults = randomIntBetween(30, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = randomSearchRequest();
int size = randomIntBetween(5, 10);
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer =
searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
SortField[] sortFields = {new SortField("field", SortField.Type.STRING)};
BytesRef a = new BytesRef("a");
BytesRef b = new BytesRef("b");
BytesRef c = new BytesRef("c");
Object[] collapseValues = new Object[]{a, b, c};
DocValueFormat[] docValueFormats = {DocValueFormat.RAW};
CountDownLatch latch = new CountDownLatch(expectedNumResults);
for (int i = 0; i < expectedNumResults; i++) {
Object[] values = {randomFrom(collapseValues)};
FieldDoc[] fieldDocs = {new FieldDoc(0, Float.NaN, values)};
@ -718,8 +735,9 @@ public class SearchPhaseControllerTests extends ESTestCase {
result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats);
result.setShardIndex(i);
result.size(size);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
}
latch.await();
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertAggReduction(request);
assertEquals(3, reduce.sortedTopDocs.scoreDocs.length);
@ -735,16 +753,17 @@ public class SearchPhaseControllerTests extends ESTestCase {
assertArrayEquals(collapseValues, reduce.sortedTopDocs.collapseValues);
}
public void testConsumerSuggestions() {
public void testConsumerSuggestions() throws Exception {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = randomSearchRequest();
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer =
searchPhaseController.newSearchPhaseResults(NOOP, request, expectedNumResults);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
int maxScoreTerm = -1;
int maxScorePhrase = -1;
int maxScoreCompletion = -1;
CountDownLatch latch = new CountDownLatch(expectedNumResults);
for (int i = 0; i < expectedNumResults; i++) {
QuerySearchResult result = new QuerySearchResult(new SearchContextId(UUIDs.randomBase64UUID(), i),
new SearchShardTarget("node", new ShardId("a", "b", i), null, OriginalIndices.NONE));
@ -792,8 +811,9 @@ public class SearchPhaseControllerTests extends ESTestCase {
result.topDocs(new TopDocsAndMaxScore(Lucene.EMPTY_TOP_DOCS, Float.NaN), new DocValueFormat[0]);
result.setShardIndex(i);
result.size(0);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
}
latch.await();
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertEquals(3, reduce.suggest.size());
{
@ -827,7 +847,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
assertNull(reduce.sortedTopDocs.collapseValues);
}
public void testProgressListener() throws InterruptedException {
public void testProgressListener() throws Exception {
int expectedNumResults = randomIntBetween(10, 100);
for (int bufferSize : new int[] {expectedNumResults, expectedNumResults/2, expectedNumResults/4, 2}) {
SearchRequest request = randomSearchRequest();
@ -861,13 +881,14 @@ public class SearchPhaseControllerTests extends ESTestCase {
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
totalHitsListener.set(totalHits);
finalAggsListener.set(aggs);
numReduceListener.incrementAndGet();
assertEquals(numReduceListener.incrementAndGet(), reducePhase);
}
};
ArraySearchPhaseResults<SearchPhaseResult> consumer =
searchPhaseController.newSearchPhaseResults(progressListener, request, expectedNumResults);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
progressListener, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
Thread[] threads = new Thread[expectedNumResults];
CountDownLatch latch = new CountDownLatch(expectedNumResults);
for (int i = 0; i < expectedNumResults; i++) {
int id = i;
threads[i] = new Thread(() -> {
@ -883,13 +904,14 @@ public class SearchPhaseControllerTests extends ESTestCase {
result.aggregations(aggs);
result.setShardIndex(id);
result.size(1);
consumer.consumeResult(result);
consumer.consumeResult(result, latch::countDown);
});
threads[i].start();
}
for (int i = 0; i < expectedNumResults; i++) {
threads[i].join();
}
latch.await();
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertAggReduction(request);
InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
@ -912,4 +934,94 @@ public class SearchPhaseControllerTests extends ESTestCase {
}
}
public void testPartialMergeFailure() throws InterruptedException {
int expectedNumResults = randomIntBetween(20, 200);
int bufferSize = randomIntBetween(2, expectedNumResults - 1);
SearchRequest request = new SearchRequest();
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0));
request.setBatchedReduceSize(bufferSize);
AtomicBoolean hasConsumedFailure = new AtomicBoolean();
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> hasConsumedFailure.set(true));
CountDownLatch latch = new CountDownLatch(expectedNumResults);
Thread[] threads = new Thread[expectedNumResults];
int failedIndex = randomIntBetween(0, expectedNumResults-1);
for (int i = 0; i < expectedNumResults; i++) {
final int index = i;
threads[index] = new Thread(() -> {
QuerySearchResult result = new QuerySearchResult(new SearchContextId(UUIDs.randomBase64UUID(), index),
new SearchShardTarget("node", new ShardId("a", "b", index), null, OriginalIndices.NONE));
result.topDocs(new TopDocsAndMaxScore(
new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), Float.NaN),
new DocValueFormat[0]);
InternalAggregations aggs = InternalAggregations.from(
Collections.singletonList(new InternalThrowing("test", (failedIndex == index), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(index);
result.size(1);
consumer.consumeResult(result, latch::countDown);
});
threads[index].start();
}
for (int i = 0; i < expectedNumResults; i++) {
threads[i].join();
}
latch.await();
IllegalStateException exc = expectThrows(IllegalStateException.class, () -> consumer.reduce());
if (exc.getMessage().contains("partial reduce")) {
assertTrue(hasConsumedFailure.get());
} else {
assertThat(exc.getMessage(), containsString("final reduce"));
}
}
private static class InternalThrowing extends InternalAggregation {
private final boolean shouldThrow;
protected InternalThrowing(String name, boolean shouldThrow, Map<String, Object> metadata) {
super(name, metadata);
this.shouldThrow = shouldThrow;
}
protected InternalThrowing(StreamInput in) throws IOException {
super(in);
this.shouldThrow = in.readBoolean();
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeBoolean(shouldThrow);
}
@Override
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
if (aggregations.stream()
.map(agg -> (InternalThrowing) agg)
.anyMatch(agg -> agg.shouldThrow)) {
if (reduceContext.isFinalReduce()) {
throw new IllegalStateException("final reduce");
} else {
throw new IllegalStateException("partial reduce");
}
}
return new InternalThrowing(name, false, metadata);
}
@Override
public Object getProperty(List<String> path) {
return null;
}
@Override
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
throw new IllegalStateException("not implemented");
}
@Override
public String getWriteableName() {
return "throwing";
}
}
}

View File

@ -57,15 +57,15 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
public class SearchQueryThenFetchAsyncActionTests extends ESTestCase {
public void testBottomFieldSort() throws InterruptedException {
public void testBottomFieldSort() throws Exception {
testCase(false);
}
public void testScrollDisableBottomFieldSort() throws InterruptedException {
public void testScrollDisableBottomFieldSort() throws Exception {
testCase(true);
}
private void testCase(boolean withScroll) throws InterruptedException {
private void testCase(boolean withScroll) throws Exception {
final TransportSearchAction.SearchTimeProvider timeProvider =
new TransportSearchAction.SearchTimeProvider(0, System.nanoTime(), System::nanoTime);
@ -131,7 +131,7 @@ public class SearchQueryThenFetchAsyncActionTests extends ESTestCase {
Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)),
Collections.emptyMap(), Collections.emptyMap(), controller, EsExecutors.newDirectExecutorService(), searchRequest,
null, shardsIter, timeProvider, null, task,
SearchResponse.Clusters.EMPTY) {
SearchResponse.Clusters.EMPTY, exc -> {}) {
@Override
protected SearchPhase getNextPhase(SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
return new SearchPhase("test") {

View File

@ -1610,7 +1610,7 @@ public class SnapshotResiliencyTests extends ESTestCase {
SearchPhaseController searchPhaseController = new SearchPhaseController(
writableRegistry(), searchService::aggReduceContextBuilder);
actions.put(SearchAction.INSTANCE,
new TransportSearchAction(threadPool, transportService, searchService,
new TransportSearchAction(client, threadPool, transportService, searchService,
searchTransportService, searchPhaseController, clusterService,
actionFilters, indexNameExpressionResolver));
actions.put(RestoreSnapshotAction.INSTANCE,