From a5d0409a8f47434699529a74c952387b52932617 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Tue, 28 Apr 2020 16:23:30 -0400 Subject: [PATCH] Save memory on aggs in async search (#55683) (#55879) This replaces a reference to the result of partially reducing aggregations that async search keeps with a reference to the serialized form of the result of the partial reduction which we need to keep anyway. --- .../action/search/SearchPhaseController.java | 7 +- .../action/search/SearchProgressListener.java | 9 +- .../search/SearchPhaseControllerTests.java | 4 +- .../SearchProgressActionListenerIT.java | 4 +- .../xpack/search/AsyncSearchTask.java | 46 ++++++--- .../xpack/search/MutableSearchResponse.java | 98 ++++++++++--------- .../xpack/search/AsyncSearchTaskTests.java | 15 +-- 7 files changed, 109 insertions(+), 74 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index 0bb7bd75341..d5b65f887b5 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -697,15 +697,16 @@ public final class SearchPhaseController { private synchronized void consumeInternal(QuerySearchResult querySearchResult) { if (querySearchResult.isNull() == false) { if (index == bufferSize) { - InternalAggregations reducedAggs = null; + DelayableWriteable.Serialized reducedAggs = null; if (hasAggs) { List aggs = new ArrayList<>(aggsBuffer.length); for (int i = 0; i < aggsBuffer.length; i++) { aggs.add(aggsBuffer[i].get()); aggsBuffer[i] = null; // null the buffer so it can be GCed now. 
} - reducedAggs = InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forPartialReduction()); - aggsBuffer[0] = DelayableWriteable.referencing(reducedAggs) + InternalAggregations reduced = + InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forPartialReduction()); + reducedAggs = aggsBuffer[0] = DelayableWriteable.referencing(reduced) .asSerialized(InternalAggregations::new, namedWriteableRegistry); long previousBufferSize = aggsCurrentBufferSize; aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java index 7d6b435fe4d..1a3596bd725 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java @@ -25,6 +25,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.search.TotalHits; import org.elasticsearch.action.search.SearchResponse.Clusters; import org.elasticsearch.cluster.routing.GroupShardsIterator; +import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.InternalAggregations; @@ -78,10 +79,11 @@ abstract class SearchProgressListener { * * @param shards The list of shards that are part of this reduce. * @param totalHits The total number of hits in this reduce. - * @param aggs The partial result for aggregations. + * @param aggs The partial result for aggregations stored in serialized form. * @param reducePhase The version number for this reduce. 
*/ - protected void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {} + protected void onPartialReduce(List shards, TotalHits totalHits, + DelayableWriteable.Serialized aggs, int reducePhase) {} /** * Executed once when the final reduce is created. @@ -136,7 +138,8 @@ abstract class SearchProgressListener { } } - final void notifyPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + final void notifyPartialReduce(List shards, TotalHits totalHits, + DelayableWriteable.Serialized aggs, int reducePhase) { try { onPartialReduce(shards, totalHits, aggs, reducePhase); } catch (Exception e) { diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index f9c5baf4b61..e8abe21bfc2 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -33,6 +33,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; @@ -851,7 +852,8 @@ public class SearchPhaseControllerTests extends ESTestCase { } @Override - public void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + public void onPartialReduce(List shards, TotalHits totalHits, + DelayableWriteable.Serialized aggs, int reducePhase) { assertEquals(numReduceListener.incrementAndGet(), reducePhase); } diff --git 
a/server/src/test/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java b/server/src/test/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java index fa78487843b..85d21c22ab7 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java @@ -23,6 +23,7 @@ import org.apache.lucene.search.TotalHits; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse; import org.elasticsearch.client.Client; import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.AggregationBuilders; @@ -173,7 +174,8 @@ public class SearchProgressActionListenerIT extends ESSingleNodeTestCase { } @Override - public void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + public void onPartialReduce(List shards, TotalHits totalHits, + DelayableWriteable.Serialized aggs, int reducePhase) { numReduces.incrementAndGet(); } diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java index 57d2a072976..e7ce86a0dad 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java @@ -17,13 +17,12 @@ import org.elasticsearch.action.search.SearchShard; import org.elasticsearch.action.search.SearchTask; import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.client.Client; +import org.elasticsearch.common.io.stream.DelayableWriteable; import 
org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; -import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregations; -import org.elasticsearch.search.internal.InternalSearchResponse; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.Scheduler.Cancellable; import org.elasticsearch.threadpool.ThreadPool; @@ -41,6 +40,8 @@ import java.util.function.BooleanSupplier; import java.util.function.Consumer; import java.util.function.Supplier; +import static java.util.Collections.singletonList; + /** * Task that tracks the progress of a currently running {@link SearchRequest}. */ @@ -362,32 +363,48 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask { // best effort to cancel expired tasks checkCancellation(); searchResponse.compareAndSet(null, - new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, - aggReduceContextSupplier, threadPool.getThreadContext())); + new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, threadPool.getThreadContext())); executeInitListeners(); } @Override - public void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + public void onPartialReduce(List shards, TotalHits totalHits, + DelayableWriteable.Serialized aggregations, int reducePhase) { // best effort to cancel expired tasks checkCancellation(); - searchResponse.get().updatePartialResponse(shards.size(), - new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs, - null, null, false, null, reducePhase), aggs == null); + // The way that the MutableSearchResponse will build the aggs. + Supplier reducedAggs; + if (aggregations == null) { + // There aren't any aggs to reduce. 
+ reducedAggs = () -> null; + } else { + /* + * Keep a reference to the serialized form of the partially + * reduced aggs and reduce it on the fly when someone asks + * for it. This will produce right-ish aggs. Much more right + * than if you don't do the final reduce. It's important that + * we wait until someone needs the result so we don't perform + * the final reduce only to throw it away. And it is important + * that we keep the reference to the serialized aggregations + * because the SearchPhaseController *already* has that + * reference so we're not creating more garbage. + */ + reducedAggs = () -> + InternalAggregations.topLevelReduce(singletonList(aggregations.get()), aggReduceContextSupplier.get()); + } + searchResponse.get().updatePartialResponse(shards.size(), totalHits, reducedAggs, reducePhase); } @Override - public void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + public void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggregations, int reducePhase) { // best effort to cancel expired tasks checkCancellation(); - searchResponse.get().updatePartialResponse(shards.size(), - new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs, - null, null, false, null, reducePhase), true); + searchResponse.get().updatePartialResponse(shards.size(), totalHits, () -> aggregations, reducePhase); } @Override public void onResponse(SearchResponse response) { - searchResponse.get().updateFinalResponse(response.getSuccessfulShards(), response.getInternalResponse()); + searchResponse.get().updateFinalResponse(response); executeCompletionListeners(); } @@ -396,8 +413,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask { if (searchResponse.get() == null) { // if the failure occurred before calling onListShards searchResponse.compareAndSet(null, - new MutableSearchResponse(-1, -1, null, - aggReduceContextSupplier, threadPool.getThreadContext())); + new 
MutableSearchResponse(-1, -1, null, threadPool.getThreadContext())); } searchResponse.get().updateWithFailure(exc); executeInitListeners(); diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java index 985e22b5e7f..a3f4a48fdf1 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java @@ -9,13 +9,11 @@ import org.apache.lucene.search.TotalHits; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse.Clusters; -import org.elasticsearch.action.search.SearchResponseSections; import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.search.SearchHits; -import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.internal.InternalSearchResponse; import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse; @@ -25,9 +23,6 @@ import java.util.List; import java.util.Map; import java.util.function.Supplier; -import static java.util.Collections.singletonList; -import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; -import static org.elasticsearch.search.aggregations.InternalAggregations.topLevelReduce; import static org.elasticsearch.xpack.core.async.AsyncTaskIndexService.restoreResponseHeadersContext; /** @@ -41,13 +36,24 @@ class MutableSearchResponse { private final int skippedShards; private final Clusters clusters; private 
final AtomicArray shardFailures; - private final Supplier aggReduceContextSupplier; private final ThreadContext threadContext; private boolean isPartial; - private boolean isFinalReduce; private int successfulShards; - private SearchResponseSections sections; + private TotalHits totalHits; + /** + * How we get the reduced aggs when {@link #finalResponse} isn't populated. + * We default to returning no aggs, that is {@code () -> null}. We'll replace + * this as we receive updates on the search progress listener. + */ + private Supplier reducedAggsSource = () -> null; + private int reducePhase; + /** + * The response produced by the search API. Once we receive it we stop + * building our own {@linkplain SearchResponse}s when you get the status + * and instead return this. + */ + private SearchResponse finalResponse; private ElasticsearchException failure; private Map> responseHeaders; @@ -59,56 +65,49 @@ class MutableSearchResponse { * @param totalShards The number of shards that participate in the request, or -1 to indicate a failure. * @param skippedShards The number of skipped shards, or -1 to indicate a failure. * @param clusters The remote clusters statistics. - * @param aggReduceContextSupplier A supplier to run final reduce on partial aggregations. * @param threadContext The thread context to retrieve the final response headers. */ MutableSearchResponse(int totalShards, int skippedShards, Clusters clusters, - Supplier aggReduceContextSupplier, ThreadContext threadContext) { this.totalShards = totalShards; this.skippedShards = skippedShards; this.clusters = clusters; - this.aggReduceContextSupplier = aggReduceContextSupplier; this.shardFailures = totalShards == -1 ? null : new AtomicArray<>(totalShards-skippedShards); this.isPartial = true; this.threadContext = threadContext; - this.sections = totalShards == -1 ? 
null : new InternalSearchResponse( - new SearchHits(SearchHits.EMPTY, new TotalHits(0, GREATER_THAN_OR_EQUAL_TO), Float.NaN), - null, null, null, false, null, 0); } /** - * Updates the response with the partial {@link SearchResponseSections} merged from #successfulShards - * shards. + * Updates the response with the result of a partial reduction. + * @param reducedAggs is a strategy for producing the reduced aggs */ - synchronized void updatePartialResponse(int successfulShards, SearchResponseSections newSections, boolean isFinalReduce) { + synchronized void updatePartialResponse(int successfulShards, TotalHits totalHits, + Supplier reducedAggs, int reducePhase) { failIfFrozen(); - if (newSections.getNumReducePhases() < sections.getNumReducePhases()) { + if (reducePhase < this.reducePhase) { // should never happen since partial response are updated under a lock // in the search phase controller throw new IllegalStateException("received partial response out of order: " - + newSections.getNumReducePhases() + " < " + sections.getNumReducePhases()); + + reducePhase + " < " + this.reducePhase); } this.successfulShards = successfulShards; - this.sections = newSections; - this.isPartial = true; - this.isFinalReduce = isFinalReduce; + this.totalHits = totalHits; + this.reducedAggsSource = reducedAggs; + this.reducePhase = reducePhase; } /** - * Updates the response with the final {@link SearchResponseSections} merged from #successfulShards - * shards. + * Updates the response with the final {@link SearchResponse} once the + * search is complete. 
*/ - synchronized void updateFinalResponse(int successfulShards, SearchResponseSections newSections) { + synchronized void updateFinalResponse(SearchResponse response) { failIfFrozen(); // copy the response headers from the current context this.responseHeaders = threadContext.getResponseHeaders(); - this.successfulShards = successfulShards; - this.sections = newSections; + this.finalResponse = response; this.isPartial = false; - this.isFinalReduce = true; this.frozen = true; } @@ -141,23 +140,34 @@ class MutableSearchResponse { * This method is synchronized to ensure that we don't perform final reduces concurrently. */ synchronized AsyncSearchResponse toAsyncSearchResponse(AsyncSearchTask task, long expirationTime) { - final SearchResponse resp; - if (totalShards != -1) { - if (sections.aggregations() != null && isFinalReduce == false) { - InternalAggregations oldAggs = (InternalAggregations) sections.aggregations(); - InternalAggregations newAggs = topLevelReduce(singletonList(oldAggs), aggReduceContextSupplier.get()); - sections = new InternalSearchResponse(sections.hits(), newAggs, sections.suggest(), - null, sections.timedOut(), sections.terminatedEarly(), sections.getNumReducePhases()); - isFinalReduce = true; - } - long tookInMillis = TimeValue.timeValueNanos(System.nanoTime() - task.getStartTimeNanos()).getMillis(); - resp = new SearchResponse(sections, null, totalShards, successfulShards, - skippedShards, tookInMillis, buildShardFailures(), clusters); - } else { - resp = null; + return new AsyncSearchResponse(task.getExecutionId().getEncoded(), findOrBuildResponse(task), + failure, isPartial, frozen == false, task.getStartTime(), expirationTime); + } + + private SearchResponse findOrBuildResponse(AsyncSearchTask task) { + if (finalResponse != null) { + // We have a final response, use it. 
+ return finalResponse; } - return new AsyncSearchResponse(task.getExecutionId().getEncoded(), resp, failure, isPartial, - frozen == false, task.getStartTime(), expirationTime); + if (clusters == null) { + // An error occurred before we got the shard list + return null; + } + /* + * Build the response, reducing aggs if we haven't already and + * storing the result of the reduction so we won't have to reduce + * a second time if you get the response again and nothing has + * changed. This does cost memory because we have a reference + * to the reduced aggs sitting around so it can't be GCed until + * we get an update. + */ + InternalAggregations reducedAggs = reducedAggsSource.get(); + reducedAggsSource = () -> reducedAggs; + InternalSearchResponse internal = new InternalSearchResponse( + new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), reducedAggs, null, null, false, false, reducePhase); + long tookInMillis = TimeValue.timeValueNanos(System.nanoTime() - task.getStartTimeNanos()).getMillis(); + return new SearchResponse(internal, null, totalShards, successfulShards, skippedShards, + tookInMillis, buildShardFailures(), clusters); } /** diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java index 7ab8f5f6104..8cdc497a1e5 100644 --- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java @@ -173,24 +173,25 @@ public class AsyncSearchTaskTests extends ESTestCase { task.getSearchProgressActionListener().onFinalReduce(shards, new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0); int numFetchFailures = randomIntBetween(0, numShards); + ShardSearchFailure[] failures = new ShardSearchFailure[numFetchFailures]; for (int i = 0; i < numFetchFailures; 
i++) { - task.getSearchProgressActionListener().onFetchFailure(i, - new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE), - new IOException("boum")); - + failures[i] = new ShardSearchFailure(new IOException("boum"), + new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE)); + task.getSearchProgressActionListener().onFetchFailure(i, failures[i].shard(), (Exception) failures[i].getCause()); } assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numFetchFailures, true); ((AsyncSearchTask.Listener)task.getProgressListener()).onResponse( - newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards)); + newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards, failures)); assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numFetchFailures, false); } - private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards) { + private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards, + ShardSearchFailure... failures) { InternalSearchResponse response = new InternalSearchResponse(SearchHits.empty(), InternalAggregations.EMPTY, null, null, false, null, 1); return new SearchResponse(response, null, totalShards, successfulShards, skippedShards, - 100, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY); + 100, failures, SearchResponse.Clusters.EMPTY); } private void assertCompletionListeners(AsyncSearchTask task,