Save memory on aggs in async search (#55683) (#55879)

Async search used to keep a reference to the result of partially
reducing aggregations. This replaces that with a reference to the
serialized form of the partial reduction, which we need to keep
anyway.
Nik Everett 2020-04-28 16:23:30 -04:00 committed by GitHub
parent fed296ebb7
commit a5d0409a8f
7 changed files with 109 additions and 74 deletions
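Before the per-file hunks, here is a minimal, self-contained Java sketch of the pattern the change relies on: keep only the serialized bytes of a partially reduced result (they are needed for the transport layer anyway), and deserialize plus final-reduce lazily, memoizing the answer the first time a caller asks. The names used here (LazyReducedResult, Aggs, onPartialReduce, reducedAggs) are illustrative stand-ins, not the DelayableWriteable / InternalAggregations API that appears in the diff below.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.function.Supplier;

/**
 * Sketch only: hold the serialized form of a partially reduced result and
 * re-materialize plus final-reduce it lazily, caching the answer so repeated
 * status requests don't redo the work. All names are hypothetical.
 */
final class LazyReducedResult {

    /** Hypothetical stand-in for a partially reduced aggregation tree. */
    static final class Aggs implements Serializable {
        final long docCount;
        Aggs(long docCount) { this.docCount = docCount; }
    }

    /** The only thing kept between partial reductions: compact bytes. */
    private byte[] serializedPartial;

    /** How we produce the final aggs on demand; starts out as "no aggs". */
    private Supplier<Aggs> reducedAggsSource = () -> null;

    /** Called for every partial reduce: keep bytes, defer the final reduce. */
    synchronized void onPartialReduce(Aggs partial) throws IOException {
        serializedPartial = serialize(partial);
        byte[] bytes = serializedPartial;
        reducedAggsSource = () -> finalReduce(deserialize(bytes));
    }

    /** Called when a client asks for the result: reduce once, then memoize. */
    synchronized Aggs reducedAggs() {
        Aggs reduced = reducedAggsSource.get();
        reducedAggsSource = () -> reduced; // reuse until the next partial reduce
        return reduced;
    }

    // Helpers standing in for the real NamedWriteable serialization and the
    // final-reduction step; plain JDK serialization keeps the sketch runnable.

    private static Aggs finalReduce(Aggs partial) {
        return partial; // a real implementation would merge buckets, prune, etc.
    }

    private static byte[] serialize(Aggs aggs) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(out)) {
            oos.writeObject(aggs);
        }
        return out.toByteArray();
    }

    private static Aggs deserialize(byte[] bytes) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return (Aggs) in.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}

In the diff below, DelayableWriteable.Serialized<InternalAggregations> roughly plays the role of the byte holder, MutableSearchResponse's reducedAggsSource supplier plays the role of the deferred final reduce, and findOrBuildResponse does the memoization.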

SearchPhaseController.java

@@ -697,15 +697,16 @@ public final class SearchPhaseController {
private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
if (querySearchResult.isNull() == false) {
if (index == bufferSize) {
InternalAggregations reducedAggs = null;
DelayableWriteable.Serialized<InternalAggregations> reducedAggs = null;
if (hasAggs) {
List<InternalAggregations> aggs = new ArrayList<>(aggsBuffer.length);
for (int i = 0; i < aggsBuffer.length; i++) {
aggs.add(aggsBuffer[i].get());
aggsBuffer[i] = null; // null the buffer so it can be GCed now.
}
reducedAggs = InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forPartialReduction());
aggsBuffer[0] = DelayableWriteable.referencing(reducedAggs)
InternalAggregations reduced =
InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forPartialReduction());
reducedAggs = aggsBuffer[0] = DelayableWriteable.referencing(reduced)
.asSerialized(InternalAggregations::new, namedWriteableRegistry);
long previousBufferSize = aggsCurrentBufferSize;
aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize);

SearchProgressListener.java

@@ -25,6 +25,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.search.SearchResponse.Clusters;
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregations;
@@ -78,10 +79,11 @@ abstract class SearchProgressListener {
*
* @param shards The list of shards that are part of this reduce.
* @param totalHits The total number of hits in this reduce.
* @param aggs The partial result for aggregations.
* @param aggs The partial result for aggregations stored in serialized form.
* @param reducePhase The version number for this reduce.
*/
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) {}
/**
* Executed once when the final reduce is created.
@@ -136,7 +138,8 @@
}
}
final void notifyPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
final void notifyPartialReduce(List<SearchShard> shards, TotalHits totalHits,
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) {
try {
onPartialReduce(shards, totalHits, aggs, reducePhase);
} catch (Exception e) {

SearchPhaseControllerTests.java

@@ -33,6 +33,7 @@ import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
@@ -851,7 +852,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
}
@Override
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) {
assertEquals(numReduceListener.incrementAndGet(), reducePhase);
}

SearchProgressActionListenerIT.java

@@ -23,6 +23,7 @@ import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.AggregationBuilders;
@@ -173,7 +174,8 @@ public class SearchProgressActionListenerIT extends ESSingleNodeTestCase {
}
@Override
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) {
numReduces.incrementAndGet();
}

AsyncSearchTask.java

@@ -17,13 +17,12 @@ import org.elasticsearch.action.search.SearchShard;
import org.elasticsearch.action.search.SearchTask;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.Scheduler.Cancellable;
import org.elasticsearch.threadpool.ThreadPool;
@@ -41,6 +40,8 @@ import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Supplier;
import static java.util.Collections.singletonList;
/**
* Task that tracks the progress of a currently running {@link SearchRequest}.
*/
@@ -362,32 +363,48 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
// best effort to cancel expired tasks
checkCancellation();
searchResponse.compareAndSet(null,
new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters,
aggReduceContextSupplier, threadPool.getThreadContext()));
new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, threadPool.getThreadContext()));
executeInitListeners();
}
@Override
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
DelayableWriteable.Serialized<InternalAggregations> aggregations, int reducePhase) {
// best effort to cancel expired tasks
checkCancellation();
searchResponse.get().updatePartialResponse(shards.size(),
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
null, null, false, null, reducePhase), aggs == null);
// The way that the MutableSearchResponse will build the aggs.
Supplier<InternalAggregations> reducedAggs;
if (aggregations == null) {
// There aren't any aggs to reduce.
reducedAggs = () -> null;
} else {
/*
* Keep a reference to the serialized form of the partially
* reduced aggs and reduce it on the fly when someone asks
* for it. This will produce right-ish aggs. Much more right
* than if you don't do the final reduce. It's important that
* we wait until someone needs the result so we don't perform
* the final reduce only to throw it away. And it is important
* that we keep the reference to the serialized aggregations
* because the SearchPhaseController *already* has that
* reference so we're not creating more garbage.
*/
reducedAggs = () ->
InternalAggregations.topLevelReduce(singletonList(aggregations.get()), aggReduceContextSupplier.get());
}
searchResponse.get().updatePartialResponse(shards.size(), totalHits, reducedAggs, reducePhase);
}
@Override
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggregations, int reducePhase) {
// best effort to cancel expired tasks
checkCancellation();
searchResponse.get().updatePartialResponse(shards.size(),
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
null, null, false, null, reducePhase), true);
searchResponse.get().updatePartialResponse(shards.size(), totalHits, () -> aggregations, reducePhase);
}
@Override
public void onResponse(SearchResponse response) {
searchResponse.get().updateFinalResponse(response.getSuccessfulShards(), response.getInternalResponse());
searchResponse.get().updateFinalResponse(response);
executeCompletionListeners();
}
@@ -396,8 +413,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
if (searchResponse.get() == null) {
// if the failure occurred before calling onListShards
searchResponse.compareAndSet(null,
new MutableSearchResponse(-1, -1, null,
aggReduceContextSupplier, threadPool.getThreadContext()));
new MutableSearchResponse(-1, -1, null, threadPool.getThreadContext()));
}
searchResponse.get().updateWithFailure(exc);
executeInitListeners();

MutableSearchResponse.java

@@ -9,13 +9,11 @@ import org.apache.lucene.search.TotalHits;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchResponse.Clusters;
import org.elasticsearch.action.search.SearchResponseSections;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
@@ -25,9 +23,6 @@ import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import static java.util.Collections.singletonList;
import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
import static org.elasticsearch.search.aggregations.InternalAggregations.topLevelReduce;
import static org.elasticsearch.xpack.core.async.AsyncTaskIndexService.restoreResponseHeadersContext;
/**
@@ -41,13 +36,24 @@ class MutableSearchResponse {
private final int skippedShards;
private final Clusters clusters;
private final AtomicArray<ShardSearchFailure> shardFailures;
private final Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier;
private final ThreadContext threadContext;
private boolean isPartial;
private boolean isFinalReduce;
private int successfulShards;
private SearchResponseSections sections;
private TotalHits totalHits;
/**
* How we get the reduced aggs when {@link #finalResponse} isn't populated.
* We default to returning no aggs, that is {@code () -> null}. We'll replace
* this as we receive updates on the search progress listener.
*/
private Supplier<InternalAggregations> reducedAggsSource = () -> null;
private int reducePhase;
/**
* The response produced by the search API. Once we receive it we stop
* building our own {@linkplain SearchResponse}s when you get the status
* and instead return this.
*/
private SearchResponse finalResponse;
private ElasticsearchException failure;
private Map<String, List<String>> responseHeaders;
@@ -59,56 +65,49 @@ class MutableSearchResponse {
* @param totalShards The number of shards that participate in the request, or -1 to indicate a failure.
* @param skippedShards The number of skipped shards, or -1 to indicate a failure.
* @param clusters The remote clusters statistics.
* @param aggReduceContextSupplier A supplier to run final reduce on partial aggregations.
* @param threadContext The thread context to retrieve the final response headers.
*/
MutableSearchResponse(int totalShards,
int skippedShards,
Clusters clusters,
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier,
ThreadContext threadContext) {
this.totalShards = totalShards;
this.skippedShards = skippedShards;
this.clusters = clusters;
this.aggReduceContextSupplier = aggReduceContextSupplier;
this.shardFailures = totalShards == -1 ? null : new AtomicArray<>(totalShards-skippedShards);
this.isPartial = true;
this.threadContext = threadContext;
this.sections = totalShards == -1 ? null : new InternalSearchResponse(
new SearchHits(SearchHits.EMPTY, new TotalHits(0, GREATER_THAN_OR_EQUAL_TO), Float.NaN),
null, null, null, false, null, 0);
}
/**
* Updates the response with the partial {@link SearchResponseSections} merged from #<code>successfulShards</code>
* shards.
* Updates the response with the result of a partial reduction.
* @param reducedAggs is a strategy for producing the reduced aggs
*/
synchronized void updatePartialResponse(int successfulShards, SearchResponseSections newSections, boolean isFinalReduce) {
synchronized void updatePartialResponse(int successfulShards, TotalHits totalHits,
Supplier<InternalAggregations> reducedAggs, int reducePhase) {
failIfFrozen();
if (newSections.getNumReducePhases() < sections.getNumReducePhases()) {
if (reducePhase < this.reducePhase) {
// should never happen since partial responses are updated under a lock
// in the search phase controller
throw new IllegalStateException("received partial response out of order: "
+ newSections.getNumReducePhases() + " < " + sections.getNumReducePhases());
+ reducePhase + " < " + this.reducePhase);
}
this.successfulShards = successfulShards;
this.sections = newSections;
this.isPartial = true;
this.isFinalReduce = isFinalReduce;
this.totalHits = totalHits;
this.reducedAggsSource = reducedAggs;
this.reducePhase = reducePhase;
}
/**
* Updates the response with the final {@link SearchResponseSections} merged from #<code>successfulShards</code>
* shards.
* Updates the response with the final {@link SearchResponse} once the
* search is complete.
*/
synchronized void updateFinalResponse(int successfulShards, SearchResponseSections newSections) {
synchronized void updateFinalResponse(SearchResponse response) {
failIfFrozen();
// copy the response headers from the current context
this.responseHeaders = threadContext.getResponseHeaders();
this.successfulShards = successfulShards;
this.sections = newSections;
this.finalResponse = response;
this.isPartial = false;
this.isFinalReduce = true;
this.frozen = true;
}
@@ -141,23 +140,34 @@ class MutableSearchResponse {
* This method is synchronized to ensure that we don't perform final reduces concurrently.
*/
synchronized AsyncSearchResponse toAsyncSearchResponse(AsyncSearchTask task, long expirationTime) {
final SearchResponse resp;
if (totalShards != -1) {
if (sections.aggregations() != null && isFinalReduce == false) {
InternalAggregations oldAggs = (InternalAggregations) sections.aggregations();
InternalAggregations newAggs = topLevelReduce(singletonList(oldAggs), aggReduceContextSupplier.get());
sections = new InternalSearchResponse(sections.hits(), newAggs, sections.suggest(),
null, sections.timedOut(), sections.terminatedEarly(), sections.getNumReducePhases());
isFinalReduce = true;
return new AsyncSearchResponse(task.getExecutionId().getEncoded(), findOrBuildResponse(task),
failure, isPartial, frozen == false, task.getStartTime(), expirationTime);
}
private SearchResponse findOrBuildResponse(AsyncSearchTask task) {
if (finalResponse != null) {
// We have a final response, use it.
return finalResponse;
}
if (clusters == null) {
// An error occurred before we got the shard list
return null;
}
/*
* Build the response, reducing aggs if we haven't already and
* storing the result of the reduction so we won't have to reduce
* a second time if you get the response again and nothing has
* changed. This does cost memory because we have a reference
* to the reduced aggs sitting around so it can't be GCed until
* we get an update.
*/
InternalAggregations reducedAggs = reducedAggsSource.get();
reducedAggsSource = () -> reducedAggs;
InternalSearchResponse internal = new InternalSearchResponse(
new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), reducedAggs, null, null, false, false, reducePhase);
long tookInMillis = TimeValue.timeValueNanos(System.nanoTime() - task.getStartTimeNanos()).getMillis();
resp = new SearchResponse(sections, null, totalShards, successfulShards,
skippedShards, tookInMillis, buildShardFailures(), clusters);
} else {
resp = null;
}
return new AsyncSearchResponse(task.getExecutionId().getEncoded(), resp, failure, isPartial,
frozen == false, task.getStartTime(), expirationTime);
return new SearchResponse(internal, null, totalShards, successfulShards, skippedShards,
tookInMillis, buildShardFailures(), clusters);
}
/**

AsyncSearchTaskTests.java

@@ -173,24 +173,25 @@ public class AsyncSearchTaskTests extends ESTestCase {
task.getSearchProgressActionListener().onFinalReduce(shards,
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
int numFetchFailures = randomIntBetween(0, numShards);
ShardSearchFailure[] failures = new ShardSearchFailure[numFetchFailures];
for (int i = 0; i < numFetchFailures; i++) {
task.getSearchProgressActionListener().onFetchFailure(i,
new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE),
new IOException("boum"));
failures[i] = new ShardSearchFailure(new IOException("boum"),
new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE));
task.getSearchProgressActionListener().onFetchFailure(i, failures[i].shard(), (Exception) failures[i].getCause());
}
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numFetchFailures, true);
((AsyncSearchTask.Listener)task.getProgressListener()).onResponse(
newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards));
newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards, failures));
assertCompletionListeners(task, numShards+numSkippedShards,
numSkippedShards, numFetchFailures, false);
}
private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards) {
private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards,
ShardSearchFailure... failures) {
InternalSearchResponse response = new InternalSearchResponse(SearchHits.empty(),
InternalAggregations.EMPTY, null, null, false, null, 1);
return new SearchResponse(response, null, totalShards, successfulShards, skippedShards,
100, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY);
100, failures, SearchResponse.Clusters.EMPTY);
}
private void assertCompletionListeners(AsyncSearchTask task,