Cleanup usages of QueryPhaseResultConsumer (#61713)

This commit generalizes how QueryPhaseResultConsumer is initialized.
The query phase always uses this consumer so it doesn't need to be hidden behind
an abstract class.
This commit is contained in:
Jim Ferenczi 2020-09-02 13:42:40 +02:00 committed by jimczi
parent a8bbdd937e
commit a0e4331c49
6 changed files with 47 additions and 54 deletions

View File

@@ -41,7 +41,7 @@ import java.util.function.Function;
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
*/
final class DfsQueryPhase extends SearchPhase {
private final ArraySearchPhaseResults<SearchPhaseResult> queryResult;
private final QueryPhaseResultConsumer queryResult;
private final SearchPhaseController searchPhaseController;
private final AtomicArray<DfsSearchResult> dfsSearchResults;
private final Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;

View File

@@ -31,6 +31,7 @@ import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContextBuilder;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.query.QuerySearchResult;
import java.util.ArrayDeque;
@@ -43,6 +44,7 @@ import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize;
import static org.elasticsearch.action.search.SearchPhaseController.mergeTopDocs;
import static org.elasticsearch.action.search.SearchPhaseController.setShardIndex;
@@ -52,7 +54,7 @@ import static org.elasticsearch.action.search.SearchPhaseController.setShardInde
* This implementation can be configured to batch up a certain amount of results and reduce
* them asynchronously in the provided {@link Executor} iff the buffer is exhausted.
*/
class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class);
private final Executor executor;
@@ -76,35 +78,31 @@ class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult
* Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
* as shard results are consumed.
*/
QueryPhaseResultConsumer(Executor executor,
SearchPhaseController controller,
SearchProgressListener progressListener,
ReduceContextBuilder aggReduceContextBuilder,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
int bufferSize,
boolean hasTopDocs,
boolean hasAggs,
int trackTotalHitsUpTo,
int topNSize,
boolean performFinalReduce,
Consumer<Exception> onPartialMergeFailure) {
public QueryPhaseResultConsumer(SearchRequest request,
Executor executor,
SearchPhaseController controller,
SearchProgressListener progressListener,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
Consumer<Exception> onPartialMergeFailure) {
super(expectedResultSize);
this.executor = executor;
this.controller = controller;
this.progressListener = progressListener;
this.aggReduceContextBuilder = aggReduceContextBuilder;
this.aggReduceContextBuilder = controller.getReduceContext(request);
this.namedWriteableRegistry = namedWriteableRegistry;
this.topNSize = topNSize;
this.pendingMerges = new PendingMerges(bufferSize, trackTotalHitsUpTo);
this.hasTopDocs = hasTopDocs;
this.hasAggs = hasAggs;
this.performFinalReduce = performFinalReduce;
this.topNSize = getTopDocsSize(request);
this.performFinalReduce = request.isFinalReduce();
this.onPartialMergeFailure = onPartialMergeFailure;
SearchSourceBuilder source = request.source();
this.hasTopDocs = source == null || source.size() != 0;
this.hasAggs = source != null && source.aggregations() != null;
int bufferSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
this.pendingMerges = new PendingMerges(bufferSize, request.resolveTrackTotalHitsUpTo());
}
@Override
void consumeResult(SearchPhaseResult result, Runnable next) {
public void consumeResult(SearchPhaseResult result, Runnable next) {
super.consumeResult(result, () -> {});
QuerySearchResult querySearchResult = result.queryResult();
progressListener.notifyQueryResult(querySearchResult.getShardIndex());
@@ -112,7 +110,7 @@ class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult
}
@Override
SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
if (pendingMerges.hasPendingMerges()) {
throw new AssertionError("partial reduce in-flight");
} else if (pendingMerges.hasFailure()) {

View File

@@ -558,23 +558,19 @@ public final class SearchPhaseController {
}
}
InternalAggregation.ReduceContextBuilder getReduceContext(SearchRequest request) {
return requestToAggReduceContextBuilder.apply(request);
}
/**
* Returns a new ArraySearchPhaseResults instance. This might return an instance that reduces search responses incrementally.
* Returns a new {@link QueryPhaseResultConsumer} instance. This might return an instance that reduces search responses incrementally.
*/
ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(Executor executor,
SearchProgressListener listener,
SearchRequest request,
int numShards,
Consumer<Exception> onPartialMergeFailure) {
SearchSourceBuilder source = request.source();
final boolean hasAggs = source != null && source.aggregations() != null;
final boolean hasTopDocs = source == null || source.size() != 0;
final int trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo();
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder = requestToAggReduceContextBuilder.apply(request);
int topNSize = getTopDocsSize(request);
int bufferSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), numShards) : numShards;
return new QueryPhaseResultConsumer(executor, this, listener, aggReduceContextBuilder, namedWriteableRegistry,
numShards, bufferSize, hasTopDocs, hasAggs, trackTotalHitsUpTo, topNSize, request.isFinalReduce(), onPartialMergeFailure);
QueryPhaseResultConsumer newSearchPhaseResults(Executor executor,
SearchProgressListener listener,
SearchRequest request,
int numShards,
Consumer<Exception> onPartialMergeFailure) {
return new QueryPhaseResultConsumer(request, executor, this, listener, namedWriteableRegistry, numShards, onPartialMergeFailure);
}
static final class TopDocsStats {

View File

@@ -39,7 +39,7 @@ import java.util.stream.StreamSupport;
/**
* A listener that allows to track progress of the {@link SearchAction}.
*/
abstract class SearchProgressListener {
public abstract class SearchProgressListener {
private static final Logger logger = LogManager.getLogger(SearchProgressListener.class);
public static final SearchProgressListener NOOP = new SearchProgressListener() {};

View File

@@ -31,7 +31,6 @@ import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.fetch.FetchSearchResult;
import org.elasticsearch.search.fetch.QueryFetchSearchResult;
@@ -53,7 +52,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 1, exc -> {});
boolean hasHits = randomBoolean();
final int numHits;
@@ -97,7 +96,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10);
final SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);
@@ -158,7 +157,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10);
SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);
@@ -223,7 +222,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(numHits);
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), NOOP,
QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), NOOP,
mockSearchPhaseContext.getRequest(), numHits, exc -> {});
for (int i = 0; i < numHits; i++) {
QuerySearchResult queryResult = new QuerySearchResult(new SearchContextId("", i),
@@ -280,7 +279,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
ArraySearchPhaseResults<SearchPhaseResult> results =
QueryPhaseResultConsumer results =
controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10);
@@ -338,7 +337,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = 1;
SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);

View File

@@ -555,7 +555,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
SearchRequest request = randomSearchRequest();
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0));
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(expectedNumResults);
@@ -596,7 +596,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
request.source(new SearchSourceBuilder().size(randomIntBetween(1, 10)));
}
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(expectedNumResults);
@@ -639,7 +639,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
SearchRequest request = new SearchRequest();
request.source(new SearchSourceBuilder().size(5).from(5));
request.setBatchedReduceSize(randomIntBetween(2, 4));
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, 4, exc -> {});
int score = 100;
CountDownLatch latch = new CountDownLatch(4);
@@ -677,7 +677,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
SearchRequest request = randomSearchRequest();
int size = randomIntBetween(1, 10);
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
SortField[] sortFields = {new SortField("field", SortField.Type.INT, true)};
@@ -715,7 +715,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
SearchRequest request = randomSearchRequest();
int size = randomIntBetween(5, 10);
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
SortField[] sortFields = {new SortField("field", SortField.Type.STRING)};
BytesRef a = new BytesRef("a");
@@ -756,7 +756,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = randomSearchRequest();
request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {});
int maxScoreTerm = -1;
int maxScorePhrase = -1;
@@ -882,7 +882,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
assertEquals(numReduceListener.incrementAndGet(), reducePhase);
}
};
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
progressListener, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger();
Thread[] threads = new Thread[expectedNumResults];
@@ -940,7 +940,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0));
request.setBatchedReduceSize(bufferSize);
AtomicBoolean hasConsumedFailure = new AtomicBoolean();
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> hasConsumedFailure.set(true));
CountDownLatch latch = new CountDownLatch(expectedNumResults);
Thread[] threads = new Thread[expectedNumResults];