Reduce memory for big aggs run against many shards (#54758) (#55024)

This changes the behavior of aggregations when a search is performed
against enough shards to enable "batch reduce" mode. In that case we now
always store aggregations in serialized form rather than as traditional
Java object references. This should shrink the memory usage of large
aggregations at the cost of slightly slowing down aggregations where the
coordinating node is also a data node. Because we only do this when
there are many shards, that slowdown should be fairly rare.
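
The heart of the change looks roughly like this (a minimal sketch built
from the APIs this commit touches; `reducedAggs` stands in for a
partially reduced aggregation tree and `registry` for the node's
`NamedWriteableRegistry`):
```
// Before: the buffer held a live reference, pinning the whole
// aggregation object tree in memory until the final reduction.
Supplier<InternalAggregations> buffered = () -> reducedAggs;

// After: the buffer holds the serialized bytes; get() inflates them
// back into objects only when they are needed for a reduction.
DelayableWriteable.Serialized<InternalAggregations> serialized =
    DelayableWriteable.referencing(reducedAggs)
        .asSerialized(InternalAggregations::new, registry);
long bytesBuffered = serialized.ramBytesUsed(); // what the logs below report
```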

As a side effect, this lets us log the memory usage of the aggs buffer:
```
[2020-04-03T17:03:57,052][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [1320->448] max [1320]
[2020-04-03T17:03:57,089][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [1328->448] max [1328]
[2020-04-03T17:03:57,102][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [1328->448] max [1328]
[2020-04-03T17:03:57,103][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [1328->448] max [1328]
[2020-04-03T17:03:57,105][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs final reduction [888] max [1328]
```

These are useful: each partial reduction line shows the serialized size
of the aggs buffer in bytes before and after the reduction, and `max` is
the largest size seen so far. But you need to keep some things in mind
before trusting them:
1. The buffers are oversized à la Lucene's `ArrayUtil`. This means that
   we are using more space than we need, but probably not much more (see
   the sketch after this list).
2. Before they are merged, the aggregations are inflated back into their
   traditional Java objects, which *probably* take up a lot more space
   than the serialized form. That is, after all, the reason why we store
   them in serialized form in the first place.
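
To make the first caveat concrete, this is the flavor of oversizing
meant (a sketch using Lucene's `ArrayUtil.oversize` for illustration;
the aggs buffers themselves are grown by Elasticsearch's stream code,
not by this exact call):
```
import org.apache.lucene.util.ArrayUtil;

public class OversizeDemo {
    public static void main(String[] args) {
        // Ask for room for exactly 1000 one-byte elements. Lucene returns
        // a somewhat larger size so that repeated growth is amortized
        // rather than reallocating on every append.
        int allocated = ArrayUtil.oversize(1000, Byte.BYTES);
        System.out.println(allocated); // a bit over 1000 (~1/8 extra)
    }
}
```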

And, just because I can, here is another example of the log:
```
[2020-04-03T17:06:18,731][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [147528->49176] max [147528]
[2020-04-03T17:06:18,750][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [147528->49176] max [147528]
[2020-04-03T17:06:18,809][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [147528->49176] max [147528]
[2020-04-03T17:06:18,827][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [147528->49176] max [147528]
[2020-04-03T17:06:18,829][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs final reduction [98352] max [147528]
```

I got that last one by building a ten-shard index with a million docs in
it and running a `sum` under three layers of `terms` aggregations, all
on `long` fields, with a `batched_reduce_size` of `3`.
Nik Everett committed 2020-04-09 14:58:42 -04:00
parent 850ea7c0be
commit 62d6bc31bf
9 changed files with 155 additions and 68 deletions

SearchPhaseController.java:

```
@@ -19,18 +19,11 @@
 package org.elasticsearch.action.search;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.function.Function;
-import java.util.function.IntFunction;
-import java.util.function.Supplier;
-import java.util.stream.Collectors;
+import com.carrotsearch.hppc.IntArrayList;
+import com.carrotsearch.hppc.ObjectObjectHashMap;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.FieldDoc;
@@ -44,6 +37,8 @@ import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.TotalHits.Relation;
 import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
 import org.elasticsearch.common.collect.HppcMaps;
+import org.elasticsearch.common.io.stream.DelayableWriteable;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.SearchHit;
@@ -67,16 +62,28 @@ import org.elasticsearch.search.suggest.Suggest;
 import org.elasticsearch.search.suggest.Suggest.Suggestion;
 import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
 
-import com.carrotsearch.hppc.IntArrayList;
-import com.carrotsearch.hppc.ObjectObjectHashMap;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.function.IntFunction;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
 
 public final class SearchPhaseController {
+    private static final Logger logger = LogManager.getLogger(SearchPhaseController.class);
 
     private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];
 
+    private final NamedWriteableRegistry namedWriteableRegistry;
     private final Function<SearchRequest, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder;
 
-    public SearchPhaseController(
+    public SearchPhaseController(NamedWriteableRegistry namedWriteableRegistry,
             Function<SearchRequest, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder) {
+        this.namedWriteableRegistry = namedWriteableRegistry;
         this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder;
     }
@@ -430,7 +437,8 @@
      * @see QuerySearchResult#consumeProfileResult()
      */
     private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
-                                                List<Supplier<InternalAggregations>> bufferedAggs, List<TopDocs> bufferedTopDocs,
+                                                List<Supplier<InternalAggregations>> bufferedAggs,
+                                                List<TopDocs> bufferedTopDocs,
                                                 TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
                                                 InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
                                                 boolean performFinalReduce) {
@@ -522,7 +530,7 @@
     private InternalAggregations reduceAggs(
         InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
         boolean performFinalReduce,
-        List<Supplier<InternalAggregations>> aggregationsList
+        List<? extends Supplier<InternalAggregations>> aggregationsList
     ) {
         /*
          * Parse the aggregations, clearing the list as we go so bits backing
@@ -617,8 +625,9 @@
      * iff the buffer is exhausted.
      */
    static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
+        private final NamedWriteableRegistry namedWriteableRegistry;
         private final SearchShardTarget[] processedShards;
-        private final Supplier<InternalAggregations>[] aggsBuffer;
+        private final DelayableWriteable.Serialized<InternalAggregations>[] aggsBuffer;
         private final TopDocs[] topDocsBuffer;
         private final boolean hasAggs;
         private final boolean hasTopDocs;
@@ -631,6 +640,8 @@
         private final int topNSize;
         private final InternalAggregation.ReduceContextBuilder aggReduceContextBuilder;
         private final boolean performFinalReduce;
+        private long aggsCurrentBufferSize;
+        private long aggsMaxBufferSize;
 
         /**
          * Creates a new {@link QueryPhaseResultConsumer}
@@ -641,12 +652,14 @@
          * @param bufferSize the size of the reduce buffer. if the buffer size is smaller than the number of expected results
          *                   the buffer is used to incrementally reduce aggregation results before all shards responded.
          */
-        private QueryPhaseResultConsumer(SearchProgressListener progressListener, SearchPhaseController controller,
+        private QueryPhaseResultConsumer(NamedWriteableRegistry namedWriteableRegistry, SearchProgressListener progressListener,
+                                         SearchPhaseController controller,
                                          int expectedResultSize, int bufferSize, boolean hasTopDocs, boolean hasAggs,
                                          int trackTotalHitsUpTo, int topNSize,
                                          InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
                                          boolean performFinalReduce) {
             super(expectedResultSize);
+            this.namedWriteableRegistry = namedWriteableRegistry;
             if (expectedResultSize != 1 && bufferSize < 2) {
                 throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result");
             }
@@ -661,7 +674,7 @@
             this.processedShards = new SearchShardTarget[expectedResultSize];
             // no need to buffer anything if we have less expected results. in this case we don't consume any results ahead of time.
             @SuppressWarnings("unchecked")
-            Supplier<InternalAggregations>[] aggsBuffer = new Supplier[hasAggs ? bufferSize : 0];
+            DelayableWriteable.Serialized<InternalAggregations>[] aggsBuffer = new DelayableWriteable.Serialized[hasAggs ? bufferSize : 0];
             this.aggsBuffer = aggsBuffer;
             this.topDocsBuffer = new TopDocs[hasTopDocs ? bufferSize : 0];
             this.hasTopDocs = hasTopDocs;
@@ -684,15 +697,21 @@
         private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
             if (querySearchResult.isNull() == false) {
                 if (index == bufferSize) {
+                    InternalAggregations reducedAggs = null;
                     if (hasAggs) {
                         List<InternalAggregations> aggs = new ArrayList<>(aggsBuffer.length);
                         for (int i = 0; i < aggsBuffer.length; i++) {
                             aggs.add(aggsBuffer[i].get());
                             aggsBuffer[i] = null; // null the buffer so it can be GCed now.
                         }
-                        InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(
-                            aggs, aggReduceContextBuilder.forPartialReduction());
-                        aggsBuffer[0] = () -> reducedAggs;
+                        reducedAggs = InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forPartialReduction());
+                        aggsBuffer[0] = DelayableWriteable.referencing(reducedAggs)
+                                .asSerialized(InternalAggregations::new, namedWriteableRegistry);
+                        long previousBufferSize = aggsCurrentBufferSize;
+                        aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize);
+                        aggsCurrentBufferSize = aggsBuffer[0].ramBytesUsed();
+                        logger.trace("aggs partial reduction [{}->{}] max [{}]",
+                                previousBufferSize, aggsCurrentBufferSize, aggsMaxBufferSize);
                     }
                     if (hasTopDocs) {
                         TopDocs reducedTopDocs = mergeTopDocs(Arrays.asList(topDocsBuffer),
@@ -705,12 +724,13 @@
                 index = 1;
                 if (hasAggs || hasTopDocs) {
                     progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards),
-                        topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0].get() : null, numReducePhases);
+                        topDocsStats.getTotalHits(), reducedAggs, numReducePhases);
                 }
             }
             final int i = index++;
             if (hasAggs) {
-                aggsBuffer[i] = querySearchResult.consumeAggs();
+                aggsBuffer[i] = querySearchResult.consumeAggs().asSerialized(InternalAggregations::new, namedWriteableRegistry);
+                aggsCurrentBufferSize += aggsBuffer[i].ramBytesUsed();
             }
             if (hasTopDocs) {
                 final TopDocsAndMaxScore topDocs = querySearchResult.consumeTopDocs(); // can't be null
@@ -723,7 +743,7 @@
         }
 
         private synchronized List<Supplier<InternalAggregations>> getRemainingAggs() {
-            return hasAggs ? Arrays.asList(aggsBuffer).subList(0, index) : null;
+            return hasAggs ? Arrays.asList((Supplier<InternalAggregations>[]) aggsBuffer).subList(0, index) : null;
         }
 
         private synchronized List<TopDocs> getRemainingTopDocs() {
@@ -732,6 +752,8 @@
 
         @Override
         public ReducedQueryPhase reduce() {
+            aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize);
+            logger.trace("aggs final reduction [{}] max [{}]", aggsCurrentBufferSize, aggsMaxBufferSize);
             ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(),
                 getRemainingAggs(), getRemainingTopDocs(), topDocsStats, numReducePhases, false,
                 aggReduceContextBuilder, performFinalReduce);
@@ -767,8 +789,8 @@
             if (request.getBatchedReduceSize() < numShards) {
                 int topNSize = getTopDocsSize(request);
                 // only use this if there are aggs and if there are more shards than we should reduce at once
-                return new QueryPhaseResultConsumer(listener, this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs,
-                    trackTotalHitsUpTo, topNSize, aggReduceContextBuilder, request.isFinalReduce());
+                return new QueryPhaseResultConsumer(namedWriteableRegistry, listener, this, numShards, request.getBatchedReduceSize(),
+                    hasTopDocs, hasAggs, trackTotalHitsUpTo, topNSize, aggReduceContextBuilder, request.isFinalReduce());
             }
         }
         return new ArraySearchPhaseResults<SearchPhaseResult>(numShards) {
```

DelayableWriteable.java:

```
@@ -19,12 +19,14 @@
 package org.elasticsearch.common.io.stream;
 
-import java.io.IOException;
-import java.util.function.Supplier;
+import org.apache.lucene.util.Accountable;
+import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.bytes.BytesReference;
 
+import java.io.IOException;
+import java.util.function.Supplier;
+
 /**
  * A holder for {@link Writeable}s that can delays reading the underlying
  * {@linkplain Writeable} when it is read from a remote node.
@@ -43,12 +45,22 @@ public abstract class DelayableWriteable<T extends Writeable> implements Supplie
      * when {@link Supplier#get()} is called.
      */
     public static <T extends Writeable> DelayableWriteable<T> delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
-        return new Delayed<>(reader, in);
+        return new Serialized<>(reader, in.getVersion(), in.namedWriteableRegistry(), in.readBytesReference());
     }
 
     private DelayableWriteable() {}
 
-    public abstract boolean isDelayed();
+    /**
+     * Returns a {@linkplain DelayableWriteable} that stores its contents
+     * in serialized form.
+     */
+    public abstract Serialized<T> asSerialized(Writeable.Reader<T> reader, NamedWriteableRegistry registry);
+
+    /**
+     * {@code true} if the {@linkplain Writeable} is being stored in
+     * serialized form, {@code false} otherwise.
+     */
+    abstract boolean isSerialized();
 
     private static class Referencing<T extends Writeable> extends DelayableWriteable<T> {
         private T reference;
@@ -59,11 +71,7 @@
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            try (BytesStreamOutput buffer = new BytesStreamOutput()) {
-                buffer.setVersion(out.getVersion());
-                reference.writeTo(buffer);
-                out.writeBytesReference(buffer.bytes());
-            }
+            out.writeBytesReference(writeToBuffer(out.getVersion()).bytes());
         }
 
         @Override
@@ -72,27 +80,48 @@
         }
 
         @Override
-        public boolean isDelayed() {
+        public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry registry) {
+            try {
+                return new Serialized<T>(reader, Version.CURRENT, registry, writeToBuffer(Version.CURRENT).bytes());
+            } catch (IOException e) {
+                throw new RuntimeException("unexpected error expanding aggregations", e);
+            }
+        }
+
+        @Override
+        boolean isSerialized() {
             return false;
         }
+
+        private BytesStreamOutput writeToBuffer(Version version) throws IOException {
+            try (BytesStreamOutput buffer = new BytesStreamOutput()) {
+                buffer.setVersion(version);
+                reference.writeTo(buffer);
+                return buffer;
+            }
+        }
     }
 
-    private static class Delayed<T extends Writeable> extends DelayableWriteable<T> {
+    /**
+     * A {@link Writeable} stored in serialized form.
+     */
+    public static class Serialized<T extends Writeable> extends DelayableWriteable<T> implements Accountable {
         private final Writeable.Reader<T> reader;
-        private final Version remoteVersion;
-        private final BytesReference serialized;
+        private final Version serializedAtVersion;
         private final NamedWriteableRegistry registry;
+        private final BytesReference serialized;
 
-        Delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
+        Serialized(Writeable.Reader<T> reader, Version serializedAtVersion,
+                NamedWriteableRegistry registry, BytesReference serialized) throws IOException {
             this.reader = reader;
-            remoteVersion = in.getVersion();
-            serialized = in.readBytesReference();
-            registry = in.namedWriteableRegistry();
+            this.serializedAtVersion = serializedAtVersion;
+            this.registry = registry;
+            this.serialized = serialized;
         }
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            if (out.getVersion() == remoteVersion) {
+            if (out.getVersion() == serializedAtVersion) {
                 /*
                  * If the version *does* line up we can just copy the bytes
                  * which is good because this is how shard request caching
@@ -116,7 +145,7 @@
             try {
                 try (StreamInput in = registry == null ?
                         serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
-                    in.setVersion(remoteVersion);
+                    in.setVersion(serializedAtVersion);
                     return reader.read(in);
                 }
             } catch (IOException e) {
@@ -125,8 +154,18 @@
         }
 
         @Override
-        public boolean isDelayed() {
+        public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry registry) {
+            return this; // We're already serialized
+        }
+
+        @Override
+        boolean isSerialized() {
             return true;
         }
+
+        @Override
+        public long ramBytesUsed() {
+            return serialized.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 3 + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
+        }
     }
 }
```

Node.java:

```
@@ -585,7 +585,8 @@
                 b.bind(MetadataCreateIndexService.class).toInstance(metadataCreateIndexService);
                 b.bind(SearchService.class).toInstance(searchService);
                 b.bind(SearchTransportService.class).toInstance(searchTransportService);
-                b.bind(SearchPhaseController.class).toInstance(new SearchPhaseController(searchService::aggReduceContextBuilder));
+                b.bind(SearchPhaseController.class).toInstance(new SearchPhaseController(
+                    namedWriteableRegistry, searchService::aggReduceContextBuilder));
                 b.bind(Transport.class).toInstance(transport);
                 b.bind(TransportService.class).toInstance(transportService);
                 b.bind(NetworkService.class).toInstance(networkService);
```

DfsQueryPhaseTests.java:

```
@@ -213,6 +213,6 @@ public class DfsQueryPhaseTests extends ESTestCase {
     }
 
     private SearchPhaseController searchPhaseController() {
-        return new SearchPhaseController(request -> InternalAggregationTestCase.emptyReduceContextBuilder());
+        return new SearchPhaseController(writableRegistry(), request -> InternalAggregationTestCase.emptyReduceContextBuilder());
     }
 }
```

FetchSearchPhaseTests.java:

```
@@ -49,7 +49,8 @@ import static org.elasticsearch.action.search.SearchProgressListener.NOOP;
 public class FetchSearchPhaseTests extends ESTestCase {
 
     public void testShortcutQueryAndFetchOptimization() {
-        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
+        SearchPhaseController controller = new SearchPhaseController(
+            writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
         ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 1);
         boolean hasHits = randomBoolean();
@@ -92,7 +93,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
 
     public void testFetchTwoDocument() {
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
-        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
+        SearchPhaseController controller = new SearchPhaseController(
+            writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
         int resultSetSize = randomIntBetween(2, 10);
         final SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);
@@ -151,7 +153,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
 
     public void testFailFetchOneDoc() {
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
-        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
+        SearchPhaseController controller = new SearchPhaseController(
+            writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         ArraySearchPhaseResults<SearchPhaseResult> results =
             controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
         int resultSetSize = randomIntBetween(2, 10);
@@ -214,7 +217,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
         int resultSetSize = randomIntBetween(0, 100);
         // we use at least 2 hits otherwise this is subject to single shard optimization and we trip an assert...
         int numHits = randomIntBetween(2, 100); // also numshards --> 1 hit per shard
-        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
+        SearchPhaseController controller = new SearchPhaseController(
+            writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(numHits);
         ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(NOOP,
             mockSearchPhaseContext.getRequest(), numHits);
@@ -271,7 +275,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
 
     public void testExceptionFailsPhase() {
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
-        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
+        SearchPhaseController controller = new SearchPhaseController(
+            writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         ArraySearchPhaseResults<SearchPhaseResult> results =
             controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
         int resultSetSize = randomIntBetween(2, 10);
@@ -327,7 +332,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
 
     public void testCleanupIrrelevantContexts() { // contexts that are not fetched should be cleaned up
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
-        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
+        SearchPhaseController controller = new SearchPhaseController(
+            writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         ArraySearchPhaseResults<SearchPhaseResult> results =
             controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
         int resultSetSize = 1;
```

SearchPhaseControllerTests.java:

```
@@ -33,8 +33,10 @@ import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.text.Text;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
@@ -42,6 +44,7 @@ import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
@@ -78,6 +81,7 @@ import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
+import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.singletonList;
 import static org.elasticsearch.action.search.SearchProgressListener.NOOP;
@@ -92,10 +96,15 @@ public class SearchPhaseControllerTests extends ESTestCase {
     private SearchPhaseController searchPhaseController;
     private List<Boolean> reductions;
 
+    @Override
+    protected NamedWriteableRegistry writableRegistry() {
+        return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables());
+    }
+
     @Before
     public void setup() {
         reductions = new CopyOnWriteArrayList<>();
-        searchPhaseController = new SearchPhaseController(s -> new InternalAggregation.ReduceContextBuilder() {
+        searchPhaseController = new SearchPhaseController(writableRegistry(), s -> new InternalAggregation.ReduceContextBuilder() {
             @Override
             public ReduceContext forPartialReduction() {
                 reductions.add(false);
```

SearchQueryThenFetchAsyncActionTests.java:

```
@@ -123,7 +123,8 @@ public class SearchQueryThenFetchAsyncActionTests extends ESTestCase {
             searchRequest.source().trackTotalHitsUpTo(2);
         }
         searchRequest.allowPartialSearchResults(false);
-        SearchPhaseController controller = new SearchPhaseController(r -> InternalAggregationTestCase.emptyReduceContextBuilder());
+        SearchPhaseController controller = new SearchPhaseController(
+            writableRegistry(), r -> InternalAggregationTestCase.emptyReduceContextBuilder());
         SearchTask task = new SearchTask(0, "n/a", "n/a", "test", null, Collections.emptyMap());
         SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction(logger,
             searchTransportService, (clusterAlias, node) -> lookup.get(node),
```

DelayableWriteableTests.java:

```
@@ -117,42 +117,42 @@
     public void testRoundTripFromReferencing() throws IOException {
         Example e = new Example(randomAlphaOfLength(5));
         DelayableWriteable<Example> original = DelayableWriteable.referencing(e);
-        assertFalse(original.isDelayed());
+        assertFalse(original.isSerialized());
         roundTripTestCase(original, Example::new);
     }
 
     public void testRoundTripFromReferencingWithNamedWriteable() throws IOException {
         NamedHolder n = new NamedHolder(new Example(randomAlphaOfLength(5)));
         DelayableWriteable<NamedHolder> original = DelayableWriteable.referencing(n);
-        assertFalse(original.isDelayed());
+        assertFalse(original.isSerialized());
         roundTripTestCase(original, NamedHolder::new);
     }
 
     public void testRoundTripFromDelayed() throws IOException {
         Example e = new Example(randomAlphaOfLength(5));
-        DelayableWriteable<Example> original = roundTrip(DelayableWriteable.referencing(e), Example::new, Version.CURRENT);
-        assertTrue(original.isDelayed());
+        DelayableWriteable<Example> original = DelayableWriteable.referencing(e).asSerialized(Example::new, writableRegistry());
+        assertTrue(original.isSerialized());
         roundTripTestCase(original, Example::new);
     }
 
     public void testRoundTripFromDelayedWithNamedWriteable() throws IOException {
         NamedHolder n = new NamedHolder(new Example(randomAlphaOfLength(5)));
-        DelayableWriteable<NamedHolder> original = roundTrip(DelayableWriteable.referencing(n), NamedHolder::new, Version.CURRENT);
-        assertTrue(original.isDelayed());
+        DelayableWriteable<NamedHolder> original = DelayableWriteable.referencing(n).asSerialized(NamedHolder::new, writableRegistry());
+        assertTrue(original.isSerialized());
         roundTripTestCase(original, NamedHolder::new);
     }
 
     public void testRoundTripFromDelayedFromOldVersion() throws IOException {
         Example e = new Example(randomAlphaOfLength(5));
         DelayableWriteable<Example> original = roundTrip(DelayableWriteable.referencing(e), Example::new, randomOldVersion());
-        assertTrue(original.isDelayed());
+        assertTrue(original.isSerialized());
         roundTripTestCase(original, Example::new);
     }
 
     public void testRoundTripFromDelayedFromOldVersionWithNamedWriteable() throws IOException {
         NamedHolder n = new NamedHolder(new Example(randomAlphaOfLength(5)));
         DelayableWriteable<NamedHolder> original = roundTrip(DelayableWriteable.referencing(n), NamedHolder::new, randomOldVersion());
-        assertTrue(original.isDelayed());
+        assertTrue(original.isSerialized());
         roundTripTestCase(original, NamedHolder::new);
     }
@@ -162,9 +162,16 @@
         assertThat(roundTrip(original, SneakOtherSideVersionOnWire::new, remoteVersion).get().version, equalTo(remoteVersion));
     }
 
+    public void testAsSerializedIsNoopOnSerialized() throws IOException {
+        Example e = new Example(randomAlphaOfLength(5));
+        DelayableWriteable<Example> d = DelayableWriteable.referencing(e).asSerialized(Example::new, writableRegistry());
+        assertTrue(d.isSerialized());
+        assertSame(d, d.asSerialized(Example::new, writableRegistry()));
+    }
+
     private <T extends Writeable> void roundTripTestCase(DelayableWriteable<T> original, Writeable.Reader<T> reader) throws IOException {
         DelayableWriteable<T> roundTripped = roundTrip(original, reader, Version.CURRENT);
-        assertTrue(roundTripped.isDelayed());
+        assertTrue(roundTripped.isSerialized());
         assertThat(roundTripped.get(), equalTo(original.get()));
     }
```

SnapshotResiliencyTests.java:

```
@@ -1363,9 +1363,11 @@ public class SnapshotResiliencyTests extends ESTestCase {
                         SearchExecutionStatsCollector.makeWrapper(responseCollectorService));
                 final SearchService searchService = new SearchService(clusterService, indicesService, threadPool, scriptService,
                         bigArrays, new FetchPhase(Collections.emptyList()), responseCollectorService, new NoneCircuitBreakerService());
+                SearchPhaseController searchPhaseController = new SearchPhaseController(
+                    writableRegistry(), searchService::aggReduceContextBuilder);
                 actions.put(SearchAction.INSTANCE,
                     new TransportSearchAction(threadPool, transportService, searchService,
-                        searchTransportService, new SearchPhaseController(searchService::aggReduceContextBuilder), clusterService,
+                        searchTransportService, searchPhaseController, clusterService,
                         actionFilters, indexNameExpressionResolver));
                 actions.put(RestoreSnapshotAction.INSTANCE,
                     new TransportRestoreSnapshotAction(transportService, clusterService, threadPool, restoreService, actionFilters,
```