Request-level circuit breaker support on coordinating nodes (#62884)

This commit allows coordinating node to account the memory used to perform partial and final reduce of
aggregations in the request circuit breaker. The search coordinator adds the memory that it used to save
and reduce the results of shard aggregations in the request circuit breaker. Before any partial or final
reduce, the memory needed to reduce the aggregations is estimated and a CircuitBreakingException} is thrown
if exceeds the maximum memory allowed in this breaker.
This size is estimated as roughly 1.5 times the size of the serialized aggregations that need to be reduced.
This estimation can be completely off for some aggregations but it is corrected with the real size after
the reduce completes.
If the reduce is successful, we update the circuit breaker to remove the size of the source aggregations
and replace the estimation with the serialized size of the newly reduced result.

As a follow up we could trigger partial reduces based on the memory accounted in the circuit breaker instead
of relying on a static number of shard responses. A simpler follow up that could be done in the mean time is
to [reduce the default batch reduce size](https://github.com/elastic/elasticsearch/issues/51857) of blocking
search request to a more sane number.

Closes #37182
This commit is contained in:
Jim Ferenczi 2020-09-24 18:59:28 +02:00 committed by GitHub
parent cd584d49dc
commit 78a93dc18f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 1219 additions and 484 deletions

View File

@ -0,0 +1,230 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.benchmark.search.aggregations;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.search.QueryPhaseResultConsumer;
import org.elasticsearch.action.search.SearchPhaseController;
import org.elasticsearch.action.search.SearchProgressListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.MultiBucketConsumerService;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.query.QuerySearchResult;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import static java.util.Collections.emptyList;
@Warmup(iterations = 5)
@Measurement(iterations = 7)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
@Fork(value = 1)
public class TermsReduceBenchmark {
private final SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList());
private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(searchModule.getNamedWriteables());
private final SearchPhaseController controller = new SearchPhaseController(
namedWriteableRegistry,
req -> new InternalAggregation.ReduceContextBuilder() {
@Override
public InternalAggregation.ReduceContext forPartialReduction() {
return InternalAggregation.ReduceContext.forPartialReduction(null, null, () -> PipelineAggregator.PipelineTree.EMPTY);
}
@Override
public InternalAggregation.ReduceContext forFinalReduction() {
final MultiBucketConsumerService.MultiBucketConsumer bucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer(
Integer.MAX_VALUE,
new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
);
return InternalAggregation.ReduceContext.forFinalReduction(
null,
null,
bucketConsumer,
PipelineAggregator.PipelineTree.EMPTY
);
}
}
);
@State(Scope.Benchmark)
public static class TermsList extends AbstractList<InternalAggregations> {
@Param({ "1600172297" })
long seed;
@Param({ "64", "128", "512" })
int numShards;
@Param({ "100" })
int topNSize;
@Param({ "1", "10", "100" })
int cardinalityFactor;
List<InternalAggregations> aggsList;
@Setup
public void setup() {
this.aggsList = new ArrayList<>();
Random rand = new Random(seed);
int cardinality = cardinalityFactor * topNSize;
BytesRef[] dict = new BytesRef[cardinality];
for (int i = 0; i < dict.length; i++) {
dict[i] = new BytesRef(Long.toString(rand.nextLong()));
}
for (int i = 0; i < numShards; i++) {
aggsList.add(InternalAggregations.from(Collections.singletonList(newTerms(rand, dict, true))));
}
}
private StringTerms newTerms(Random rand, BytesRef[] dict, boolean withNested) {
Set<BytesRef> randomTerms = new HashSet<>();
for (int i = 0; i < topNSize; i++) {
randomTerms.add(dict[rand.nextInt(dict.length)]);
}
List<StringTerms.Bucket> buckets = new ArrayList<>();
for (BytesRef term : randomTerms) {
InternalAggregations subAggs;
if (withNested) {
subAggs = InternalAggregations.from(Collections.singletonList(newTerms(rand, dict, false)));
} else {
subAggs = InternalAggregations.EMPTY;
}
buckets.add(new StringTerms.Bucket(term, rand.nextInt(10000), subAggs, true, 0L, DocValueFormat.RAW));
}
Collections.sort(buckets, (a, b) -> a.compareKey(b));
return new StringTerms(
"terms",
BucketOrder.key(true),
BucketOrder.count(false),
topNSize,
1,
Collections.emptyMap(),
DocValueFormat.RAW,
numShards,
true,
0,
buckets,
0
);
}
@Override
public InternalAggregations get(int index) {
return aggsList.get(index);
}
@Override
public int size() {
return aggsList.size();
}
}
@Param({ "32", "512" })
private int bufferSize;
@Benchmark
public SearchPhaseController.ReducedQueryPhase reduceAggs(TermsList candidateList) throws Exception {
List<QuerySearchResult> shards = new ArrayList<>();
for (int i = 0; i < candidateList.size(); i++) {
QuerySearchResult result = new QuerySearchResult();
result.setShardIndex(i);
result.from(0);
result.size(0);
result.topDocs(
new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1000, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]),
Float.NaN
),
new DocValueFormat[] { DocValueFormat.RAW }
);
result.aggregations(candidateList.get(i));
result.setSearchShardTarget(
new SearchShardTarget("node", new ShardId(new Index("index", "index"), i), null, OriginalIndices.NONE)
);
shards.add(result);
}
SearchRequest request = new SearchRequest();
request.source(new SearchSourceBuilder().size(0).aggregation(AggregationBuilders.terms("test")));
request.setBatchedReduceSize(bufferSize);
ExecutorService executor = Executors.newFixedThreadPool(1);
QueryPhaseResultConsumer consumer = new QueryPhaseResultConsumer(
request,
executor,
new NoopCircuitBreaker(CircuitBreaker.REQUEST),
controller,
SearchProgressListener.NOOP,
namedWriteableRegistry,
shards.size(),
exc -> {}
);
CountDownLatch latch = new CountDownLatch(shards.size());
for (int i = 0; i < shards.size(); i++) {
consumer.consumeResult(shards.get(i), () -> latch.countDown());
}
latch.await();
SearchPhaseController.ReducedQueryPhase phase = consumer.reduce();
executor.shutdownNow();
return phase;
}
}

View File

@ -23,7 +23,6 @@ import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.AggregationBuilders;
@ -174,8 +173,7 @@ public class SearchProgressActionListenerIT extends ESSingleNodeTestCase {
} }
@Override @Override
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) {
numReduces.incrementAndGet(); numReduces.incrementAndGet();
} }

View File

@ -19,22 +19,251 @@
package org.elasticsearch.action.search; package org.elasticsearch.action.search;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.index.query.RangeQueryBuilder; import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.indices.IndicesService; import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorBase;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.CardinalityUpperBound;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalMax;
import org.elasticsearch.search.aggregations.support.ValueType;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.fetch.FetchSubPhaseProcessor;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.ESIntegTestCase;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class TransportSearchIT extends ESIntegTestCase { public class TransportSearchIT extends ESIntegTestCase {
public static class TestPlugin extends Plugin implements SearchPlugin {
@Override
public List<AggregationSpec> getAggregations() {
return Collections.singletonList(
new AggregationSpec(TestAggregationBuilder.NAME, TestAggregationBuilder::new, TestAggregationBuilder.PARSER)
.addResultReader(InternalMax::new)
);
}
@Override
public List<FetchSubPhase> getFetchSubPhases(FetchPhaseConstructionContext context) {
/**
* Set up a fetch sub phase that throws an exception on indices whose name that start with "boom".
*/
return Collections.singletonList(fetchContext -> new FetchSubPhaseProcessor() {
@Override
public void setNextReader(LeafReaderContext readerContext) {
}
@Override
public void process(FetchSubPhase.HitContext hitContext) {
if (fetchContext.getIndexName().startsWith("boom")) {
throw new RuntimeException("boom");
}
}
});
}
}
@Override
protected Settings nodeSettings(int nodeOrdinal) {
return Settings.builder()
.put(super.nodeSettings(nodeOrdinal))
.put("indices.breaker.request.type", "memory")
.build();
}
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singletonList(TestPlugin.class);
}
public void testLocalClusterAlias() {
long nowInMillis = randomLongBetween(0, Long.MAX_VALUE);
IndexRequest indexRequest = new IndexRequest("test");
indexRequest.id("1");
indexRequest.source("field", "value");
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY,
"local", nowInMillis, randomBoolean());
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
SearchHit[] hits = searchResponse.getHits().getHits();
assertEquals(1, hits.length);
SearchHit hit = hits[0];
assertEquals("local", hit.getClusterAlias());
assertEquals("test", hit.getIndex());
assertEquals("1", hit.getId());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY,
"", nowInMillis, randomBoolean());
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
SearchHit[] hits = searchResponse.getHits().getHits();
assertEquals(1, hits.length);
SearchHit hit = hits[0];
assertEquals("", hit.getClusterAlias());
assertEquals("test", hit.getIndex());
assertEquals("1", hit.getId());
}
}
public void testAbsoluteStartMillis() {
{
IndexRequest indexRequest = new IndexRequest("test-1970.01.01");
indexRequest.id("1");
indexRequest.source("date", "1970-01-01");
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
}
{
IndexRequest indexRequest = new IndexRequest("test-1982.01.01");
indexRequest.id("1");
indexRequest.source("date", "1982-01-01");
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
}
{
SearchRequest searchRequest = new SearchRequest();
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
}
{
SearchRequest searchRequest = new SearchRequest("<test-{now/d}>");
searchRequest.indicesOptions(IndicesOptions.fromOptions(true, true, true, true));
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(0, searchResponse.getTotalShards());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
searchRequest.indices("<test-{now/d}>");
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
RangeQueryBuilder rangeQuery = new RangeQueryBuilder("date");
rangeQuery.gte("1970-01-01");
rangeQuery.lt("1982-01-01");
sourceBuilder.query(rangeQuery);
searchRequest.source(sourceBuilder);
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex());
}
}
public void testFinalReduce() {
long nowInMillis = randomLongBetween(0, Long.MAX_VALUE);
{
IndexRequest indexRequest = new IndexRequest("test");
indexRequest.id("1");
indexRequest.source("price", 10);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
}
{
IndexRequest indexRequest = new IndexRequest("test");
indexRequest.id("2");
indexRequest.source("price", 100);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
}
client().admin().indices().prepareRefresh("test").get();
SearchRequest originalRequest = new SearchRequest();
SearchSourceBuilder source = new SearchSourceBuilder();
source.size(0);
originalRequest.source(source);
TermsAggregationBuilder terms = new TermsAggregationBuilder("terms").userValueTypeHint(ValueType.NUMERIC);
terms.field("price");
terms.size(1);
source.aggregation(terms);
{
SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(originalRequest,
Strings.EMPTY_ARRAY, "remote", nowInMillis, true);
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
Aggregations aggregations = searchResponse.getAggregations();
LongTerms longTerms = aggregations.get("terms");
assertEquals(1, longTerms.getBuckets().size());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(originalRequest,
Strings.EMPTY_ARRAY, "remote", nowInMillis, false);
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
Aggregations aggregations = searchResponse.getAggregations();
LongTerms longTerms = aggregations.get("terms");
assertEquals(2, longTerms.getBuckets().size());
}
}
public void testShardCountLimit() throws Exception { public void testShardCountLimit() throws Exception {
try { try {
@ -103,4 +332,276 @@ public class TransportSearchIT extends ESIntegTestCase {
assertThat(resp.getHits().getTotalHits().value, equalTo(2L)); assertThat(resp.getHits().getTotalHits().value, equalTo(2L));
}); });
} }
public void testCircuitBreakerReduceFail() throws Exception {
int numShards = randomIntBetween(1, 10);
indexSomeDocs("test", numShards, numShards*3);
{
final AtomicArray<Boolean> responses = new AtomicArray<>(10);
final CountDownLatch latch = new CountDownLatch(10);
for (int i = 0; i < 10; i++) {
int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3));
SearchRequest request = client().prepareSearch("test")
.addAggregation(new TestAggregationBuilder("test"))
.setBatchedReduceSize(batchReduceSize)
.request();
final int index = i;
client().search(request, new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse response) {
responses.set(index, true);
latch.countDown();
}
@Override
public void onFailure(Exception e) {
responses.set(index, false);
latch.countDown();
}
});
}
latch.await();
assertThat(responses.asList().size(), equalTo(10));
for (boolean resp : responses.asList()) {
assertTrue(resp);
}
assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L)));
}
try {
Settings settings = Settings.builder()
.put("indices.breaker.request.limit", "1b")
.build();
assertAcked(client().admin().cluster().prepareUpdateSettings().setTransientSettings(settings));
final Client client = client();
assertBusy(() -> {
SearchPhaseExecutionException exc = expectThrows(SearchPhaseExecutionException.class, () -> client.prepareSearch("test")
.addAggregation(new TestAggregationBuilder("test"))
.get());
assertThat(ExceptionsHelper.unwrapCause(exc).getCause().getMessage(), containsString("<reduce_aggs>"));
});
final AtomicArray<Exception> exceptions = new AtomicArray<>(10);
final CountDownLatch latch = new CountDownLatch(10);
for (int i = 0; i < 10; i++) {
int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3));
SearchRequest request = client().prepareSearch("test")
.addAggregation(new TestAggregationBuilder("test"))
.setBatchedReduceSize(batchReduceSize)
.request();
final int index = i;
client().search(request, new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse response) {
latch.countDown();
}
@Override
public void onFailure(Exception exc) {
exceptions.set(index, exc);
latch.countDown();
}
});
}
latch.await();
assertThat(exceptions.asList().size(), equalTo(10));
for (Exception exc : exceptions.asList()) {
assertThat(ExceptionsHelper.unwrapCause(exc).getCause().getMessage(), containsString("<reduce_aggs>"));
}
assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L)));
} finally {
Settings settings = Settings.builder()
.putNull("indices.breaker.request.limit")
.build();
assertAcked(client().admin().cluster().prepareUpdateSettings().setTransientSettings(settings));
}
}
public void testCircuitBreakerFetchFail() throws Exception {
int numShards = randomIntBetween(1, 10);
int numDocs = numShards*10;
indexSomeDocs("boom", numShards, numDocs);
final AtomicArray<Exception> exceptions = new AtomicArray<>(10);
final CountDownLatch latch = new CountDownLatch(10);
for (int i = 0; i < 10; i++) {
int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3));
SearchRequest request = client().prepareSearch("boom")
.setBatchedReduceSize(batchReduceSize)
.setAllowPartialSearchResults(false)
.request();
final int index = i;
client().search(request, new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse response) {
latch.countDown();
}
@Override
public void onFailure(Exception exc) {
exceptions.set(index, exc);
latch.countDown();
}
});
}
latch.await();
assertThat(exceptions.asList().size(), equalTo(10));
for (Exception exc : exceptions.asList()) {
assertThat(ExceptionsHelper.unwrapCause(exc).getCause().getMessage(), containsString("boom"));
}
assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L)));
}
private void indexSomeDocs(String indexName, int numberOfShards, int numberOfDocs) {
createIndex(indexName, Settings.builder().put("index.number_of_shards", numberOfShards).build());
for (int i = 0; i < numberOfDocs; i++) {
IndexResponse indexResponse = client().prepareIndex(indexName, "_doc")
.setSource("number", randomInt())
.get();
assertEquals(RestStatus.CREATED, indexResponse.status());
}
client().admin().indices().prepareRefresh(indexName).get();
}
private long requestBreakerUsed() {
NodesStatsResponse stats = client().admin().cluster().prepareNodesStats()
.addMetric(NodesStatsRequest.Metric.BREAKER.metricName())
.get();
long estimated = 0;
for (NodeStats nodeStats : stats.getNodes()) {
estimated += nodeStats.getBreaker().getStats(CircuitBreaker.REQUEST).getEstimated();
}
return estimated;
}
/**
* A test aggregation that doesn't consume circuit breaker memory when running on shards.
* It is used to test the behavior of the circuit breaker when reducing multiple aggregations
* together (coordinator node).
*/
private static class TestAggregationBuilder extends AbstractAggregationBuilder<TestAggregationBuilder> {
static final String NAME = "test";
private static final ObjectParser<TestAggregationBuilder, String> PARSER =
ObjectParser.fromBuilder(NAME, TestAggregationBuilder::new);
TestAggregationBuilder(String name) {
super(name);
}
TestAggregationBuilder(StreamInput input) throws IOException {
super(input);
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
// noop
}
@Override
protected AggregatorFactory doBuild(QueryShardContext queryShardContext,
AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder) throws IOException {
return new AggregatorFactory(name, queryShardContext, parent, subFactoriesBuilder, metadata) {
@Override
protected Aggregator createInternal(SearchContext searchContext,
Aggregator parent,
CardinalityUpperBound cardinality,
Map<String, Object> metadata) throws IOException {
return new TestAggregator(name, parent, searchContext);
}
};
}
@Override
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
return builder;
}
@Override
protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map<String, Object> metadata) {
return new TestAggregationBuilder(name);
}
@Override
public BucketCardinality bucketCardinality() {
return BucketCardinality.NONE;
}
@Override
public String getType() {
return "test";
}
}
/**
* A test aggregator that extends {@link Aggregator} instead of {@link AggregatorBase}
* to avoid tripping the circuit breaker when executing on a shard.
*/
private static class TestAggregator extends Aggregator {
private final String name;
private final Aggregator parent;
private final SearchContext context;
private TestAggregator(String name, Aggregator parent, SearchContext context) {
this.name = name;
this.parent = parent;
this.context = context;
}
@Override
public String name() {
return name;
}
@Override
public SearchContext context() {
return context;
}
@Override
public Aggregator parent() {
return parent;
}
@Override
public Aggregator subAggregator(String name) {
return null;
}
@Override
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
return new InternalAggregation[] {
new InternalMax(name(), Double.NaN, DocValueFormat.RAW, Collections.emptyMap())
};
}
@Override
public InternalAggregation buildEmptyAggregation() {
return new InternalMax(name(), Double.NaN, DocValueFormat.RAW, Collections.emptyMap());
}
@Override
public void close() {}
@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException {
throw new CollectionTerminatedException();
}
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE_NO_SCORES;
}
@Override
public void preCollection() throws IOException {}
@Override
public void postCollection() throws IOException {}
}
} }

View File

@ -33,6 +33,8 @@ import org.elasticsearch.action.support.TransportActions;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
@ -77,7 +79,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
**/ **/
private final BiFunction<String, String, Transport.Connection> nodeIdToConnection; private final BiFunction<String, String, Transport.Connection> nodeIdToConnection;
private final SearchTask task; private final SearchTask task;
final SearchPhaseResults<Result> results; protected final SearchPhaseResults<Result> results;
private final ClusterState clusterState; private final ClusterState clusterState;
private final Map<String, AliasFilter> aliasFilter; private final Map<String, AliasFilter> aliasFilter;
private final Map<String, Float> concreteIndexBoosts; private final Map<String, Float> concreteIndexBoosts;
@ -98,6 +100,8 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>(); private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>();
private final boolean throttleConcurrentRequests; private final boolean throttleConcurrentRequests;
private final List<Releasable> releasables = new ArrayList<>();
AbstractSearchAsyncAction(String name, Logger logger, SearchTransportService searchTransportService, AbstractSearchAsyncAction(String name, Logger logger, SearchTransportService searchTransportService,
BiFunction<String, String, Transport.Connection> nodeIdToConnection, BiFunction<String, String, Transport.Connection> nodeIdToConnection,
Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts, Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
@ -133,7 +137,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
this.executor = executor; this.executor = executor;
this.request = request; this.request = request;
this.task = task; this.task = task;
this.listener = listener; this.listener = ActionListener.runAfter(listener, this::releaseContext);
this.nodeIdToConnection = nodeIdToConnection; this.nodeIdToConnection = nodeIdToConnection;
this.clusterState = clusterState; this.clusterState = clusterState;
this.concreteIndexBoosts = concreteIndexBoosts; this.concreteIndexBoosts = concreteIndexBoosts;
@ -143,6 +147,15 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
this.clusters = clusters; this.clusters = clusters;
} }
@Override
public void addReleasable(Releasable releasable) {
releasables.add(releasable);
}
public void releaseContext() {
Releasables.close(releasables);
}
/** /**
* Builds how long it took to execute the search. * Builds how long it took to execute the search.
*/ */
@ -529,7 +542,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
ShardSearchFailure[] failures = buildShardFailures(); ShardSearchFailure[] failures = buildShardFailures();
Boolean allowPartialResults = request.allowPartialSearchResults(); Boolean allowPartialResults = request.allowPartialSearchResults();
assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults";
if (request.pointInTimeBuilder() == null && allowPartialResults == false && failures.length > 0) { if (allowPartialResults == false && failures.length > 0) {
raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures)); raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures));
} else { } else {
final Version minNodeVersion = clusterState.nodes().getMinNodeVersion(); final Version minNodeVersion = clusterState.nodes().getMinNodeVersion();
@ -567,6 +580,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
} }
}); });
} }
Releasables.close(releasables);
listener.onFailure(exception); listener.onFailure(exception);
} }

View File

@ -23,6 +23,7 @@ import org.apache.lucene.util.FixedBitSet;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.search.SearchService.CanMatchResponse; import org.elasticsearch.search.SearchService.CanMatchResponse;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
@ -76,6 +77,11 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
this.shardsIts = shardsIts; this.shardsIts = shardsIts;
} }
@Override
public void addReleasable(Releasable releasable) {
throw new RuntimeException("cannot add releasable in " + getName() + " phase");
}
@Override @Override
protected void executePhaseOnShard(SearchShardIterator shardIt, SearchShardTarget shard, protected void executePhaseOnShard(SearchShardIterator shardIt, SearchShardTarget shard,
SearchActionListener<CanMatchResponse> listener) { SearchActionListener<CanMatchResponse> listener) {
@ -84,8 +90,7 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
} }
@Override @Override
protected SearchPhase getNextPhase(SearchPhaseResults<CanMatchResponse> results, protected SearchPhase getNextPhase(SearchPhaseResults<CanMatchResponse> results, SearchPhaseContext context) {
SearchPhaseContext context) {
return phaseFactory.apply(getIterator((CanMatchSearchPhaseResults) results, shardsIts)); return phaseFactory.apply(getIterator((CanMatchSearchPhaseResults) results, shardsIts));
} }

View File

@ -29,7 +29,6 @@ import org.elasticsearch.transport.Transport;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
/** /**
@ -50,18 +49,21 @@ final class DfsQueryPhase extends SearchPhase {
DfsQueryPhase(List<DfsSearchResult> searchResults, DfsQueryPhase(List<DfsSearchResult> searchResults,
AggregatedDfs dfs, AggregatedDfs dfs,
SearchPhaseController searchPhaseController, QueryPhaseResultConsumer queryResult,
Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory, Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
SearchPhaseContext context, Consumer<Exception> onPartialMergeFailure) { SearchPhaseContext context) {
super("dfs_query"); super("dfs_query");
this.progressListener = context.getTask().getProgressListener(); this.progressListener = context.getTask().getProgressListener();
this.queryResult = searchPhaseController.newSearchPhaseResults(context, progressListener, this.queryResult = queryResult;
context.getRequest(), context.getNumShards(), onPartialMergeFailure);
this.searchResults = searchResults; this.searchResults = searchResults;
this.dfs = dfs; this.dfs = dfs;
this.nextPhaseFactory = nextPhaseFactory; this.nextPhaseFactory = nextPhaseFactory;
this.context = context; this.context = context;
this.searchTransportService = context.getSearchTransport(); this.searchTransportService = context.getSearchTransport();
// register the release of the query consumer to free up the circuit breaker memory
// at the end of the search
context.addReleasable(queryResult);
} }
@Override @Override

View File

@ -23,8 +23,11 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.elasticsearch.action.search.SearchPhaseController.TopDocsStats; import org.elasticsearch.action.search.SearchPhaseController.TopDocsStats;
import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchPhaseResult;
@ -51,13 +54,16 @@ import static org.elasticsearch.action.search.SearchPhaseController.setShardInde
/** /**
* A {@link ArraySearchPhaseResults} implementation that incrementally reduces aggregation results * A {@link ArraySearchPhaseResults} implementation that incrementally reduces aggregation results
* as shard results are consumed. * as shard results are consumed.
* This implementation can be configured to batch up a certain amount of results and reduce * This implementation adds the memory that it used to save and reduce the results of shard aggregations
* them asynchronously in the provided {@link Executor} iff the buffer is exhausted. * in the {@link CircuitBreaker#REQUEST} circuit breaker. Before any partial or final reduce, the memory
* needed to reduce the aggregations is estimated and a {@link CircuitBreakingException} is thrown if it
* exceeds the maximum memory allowed in this breaker.
*/ */
public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> { public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> implements Releasable {
private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class); private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class);
private final Executor executor; private final Executor executor;
private final CircuitBreaker circuitBreaker;
private final SearchPhaseController controller; private final SearchPhaseController controller;
private final SearchProgressListener progressListener; private final SearchProgressListener progressListener;
private final ReduceContextBuilder aggReduceContextBuilder; private final ReduceContextBuilder aggReduceContextBuilder;
@ -71,15 +77,13 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
private final PendingMerges pendingMerges; private final PendingMerges pendingMerges;
private final Consumer<Exception> onPartialMergeFailure; private final Consumer<Exception> onPartialMergeFailure;
private volatile long aggsMaxBufferSize;
private volatile long aggsCurrentBufferSize;
/** /**
* Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results * Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
* as shard results are consumed. * as shard results are consumed.
*/ */
public QueryPhaseResultConsumer(SearchRequest request, public QueryPhaseResultConsumer(SearchRequest request,
Executor executor, Executor executor,
CircuitBreaker circuitBreaker,
SearchPhaseController controller, SearchPhaseController controller,
SearchProgressListener progressListener, SearchProgressListener progressListener,
NamedWriteableRegistry namedWriteableRegistry, NamedWriteableRegistry namedWriteableRegistry,
@ -87,6 +91,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
Consumer<Exception> onPartialMergeFailure) { Consumer<Exception> onPartialMergeFailure) {
super(expectedResultSize); super(expectedResultSize);
this.executor = executor; this.executor = executor;
this.circuitBreaker = circuitBreaker;
this.controller = controller; this.controller = controller;
this.progressListener = progressListener; this.progressListener = progressListener;
this.aggReduceContextBuilder = controller.getReduceContext(request); this.aggReduceContextBuilder = controller.getReduceContext(request);
@ -94,11 +99,17 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
this.topNSize = getTopDocsSize(request); this.topNSize = getTopDocsSize(request);
this.performFinalReduce = request.isFinalReduce(); this.performFinalReduce = request.isFinalReduce();
this.onPartialMergeFailure = onPartialMergeFailure; this.onPartialMergeFailure = onPartialMergeFailure;
SearchSourceBuilder source = request.source(); SearchSourceBuilder source = request.source();
this.hasTopDocs = source == null || source.size() != 0; this.hasTopDocs = source == null || source.size() != 0;
this.hasAggs = source != null && source.aggregations() != null; this.hasAggs = source != null && source.aggregations() != null;
int bufferSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize; int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
this.pendingMerges = new PendingMerges(bufferSize, request.resolveTrackTotalHitsUpTo()); this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo());
}
@Override
public void close() {
Releasables.close(pendingMerges);
} }
@Override @Override
@ -117,28 +128,35 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
throw pendingMerges.getFailure(); throw pendingMerges.getFailure();
} }
logger.trace("aggs final reduction [{}] max [{}]", aggsCurrentBufferSize, aggsMaxBufferSize);
// ensure consistent ordering // ensure consistent ordering
pendingMerges.sortBuffer(); pendingMerges.sortBuffer();
final TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats(); final TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats();
final List<TopDocs> topDocsList = pendingMerges.consumeTopDocs(); final List<TopDocs> topDocsList = pendingMerges.consumeTopDocs();
final List<InternalAggregations> aggsList = pendingMerges.consumeAggs(); final List<InternalAggregations> aggsList = pendingMerges.consumeAggs();
long breakerSize = pendingMerges.circuitBreakerBytes;
if (hasAggs) {
// Add an estimate of the final reduce size
breakerSize = pendingMerges.addEstimateAndMaybeBreak(pendingMerges.estimateRamBytesUsedForReduce(breakerSize));
}
SearchPhaseController.ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), aggsList, SearchPhaseController.ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), aggsList,
topDocsList, topDocsStats, pendingMerges.numReducePhases, false, aggReduceContextBuilder, performFinalReduce); topDocsList, topDocsStats, pendingMerges.numReducePhases, false, aggReduceContextBuilder, performFinalReduce);
if (hasAggs) {
// Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result
long finalSize = reducePhase.aggregations.getSerializedSize() - breakerSize;
pendingMerges.addWithoutBreaking(finalSize);
logger.trace("aggs final reduction [{}] max [{}]",
pendingMerges.aggsCurrentBufferSize, pendingMerges.maxAggsCurrentBufferSize);
}
progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()), progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()),
reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases); reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
return reducePhase; return reducePhase;
} }
private MergeResult partialReduce(MergeTask task, private MergeResult partialReduce(QuerySearchResult[] toConsume,
List<SearchShard> emptyResults,
TopDocsStats topDocsStats, TopDocsStats topDocsStats,
MergeResult lastMerge, MergeResult lastMerge,
int numReducePhases) { int numReducePhases) {
final QuerySearchResult[] toConsume = task.consumeBuffer();
if (toConsume == null) {
// the task is cancelled
return null;
}
// ensure consistent ordering // ensure consistent ordering
Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex)); Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex));
@ -164,27 +182,20 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
newTopDocs = null; newTopDocs = null;
} }
final DelayableWriteable.Serialized<InternalAggregations> newAggs; final InternalAggregations newAggs;
if (hasAggs) { if (hasAggs) {
List<InternalAggregations> aggsList = new ArrayList<>(); List<InternalAggregations> aggsList = new ArrayList<>();
if (lastMerge != null) { if (lastMerge != null) {
aggsList.add(lastMerge.reducedAggs.expand()); aggsList.add(lastMerge.reducedAggs);
} }
for (QuerySearchResult result : toConsume) { for (QuerySearchResult result : toConsume) {
aggsList.add(result.consumeAggs().expand()); aggsList.add(result.consumeAggs().expand());
} }
InternalAggregations result = InternalAggregations.topLevelReduce(aggsList, newAggs = InternalAggregations.topLevelReduce(aggsList, aggReduceContextBuilder.forPartialReduction());
aggReduceContextBuilder.forPartialReduction());
newAggs = DelayableWriteable.referencing(result).asSerialized(InternalAggregations::readFrom, namedWriteableRegistry);
long previousBufferSize = aggsCurrentBufferSize;
aggsCurrentBufferSize = newAggs.ramBytesUsed();
aggsMaxBufferSize = Math.max(aggsCurrentBufferSize, aggsMaxBufferSize);
logger.trace("aggs partial reduction [{}->{}] max [{}]",
previousBufferSize, aggsCurrentBufferSize, aggsMaxBufferSize);
} else { } else {
newAggs = null; newAggs = null;
} }
List<SearchShard> processedShards = new ArrayList<>(task.emptyResults); List<SearchShard> processedShards = new ArrayList<>(emptyResults);
if (lastMerge != null) { if (lastMerge != null) {
processedShards.addAll(lastMerge.processedShards); processedShards.addAll(lastMerge.processedShards);
} }
@ -193,49 +204,109 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId())); processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
} }
progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases); progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
return new MergeResult(processedShards, newTopDocs, newAggs); // we leave the results un-serialized because serializing is slow but we compute the serialized
// size as an estimate of the memory used by the newly reduced aggregations.
long serializedSize = hasAggs ? newAggs.getSerializedSize() : 0;
return new MergeResult(processedShards, newTopDocs, newAggs, hasAggs ? serializedSize : 0);
} }
public int getNumReducePhases() { public int getNumReducePhases() {
return pendingMerges.numReducePhases; return pendingMerges.numReducePhases;
} }
private class PendingMerges { private class PendingMerges implements Releasable {
private final int bufferSize; private final int batchReduceSize;
private final List<QuerySearchResult> buffer = new ArrayList<>();
private int index;
private final QuerySearchResult[] buffer;
private final List<SearchShard> emptyResults = new ArrayList<>(); private final List<SearchShard> emptyResults = new ArrayList<>();
// the memory that is accounted in the circuit breaker for this consumer
private volatile long circuitBreakerBytes;
// the memory that is currently used in the buffer
private volatile long aggsCurrentBufferSize;
private volatile long maxAggsCurrentBufferSize = 0;
private final TopDocsStats topDocsStats;
private MergeResult mergeResult;
private final ArrayDeque<MergeTask> queue = new ArrayDeque<>(); private final ArrayDeque<MergeTask> queue = new ArrayDeque<>();
private final AtomicReference<MergeTask> runningTask = new AtomicReference<>(); private final AtomicReference<MergeTask> runningTask = new AtomicReference<>();
private final AtomicReference<Exception> failure = new AtomicReference<>(); private final AtomicReference<Exception> failure = new AtomicReference<>();
private boolean hasPartialReduce; private final TopDocsStats topDocsStats;
private int numReducePhases; private volatile MergeResult mergeResult;
private volatile boolean hasPartialReduce;
private volatile int numReducePhases;
PendingMerges(int bufferSize, int trackTotalHitsUpTo) { PendingMerges(int batchReduceSize, int trackTotalHitsUpTo) {
this.bufferSize = bufferSize; this.batchReduceSize = batchReduceSize;
this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo); this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo);
this.buffer = new QuerySearchResult[bufferSize];
} }
public boolean hasFailure() { @Override
public synchronized void close() {
assert hasPendingMerges() == false : "cannot close with partial reduce in-flight";
if (hasFailure()) {
assert circuitBreakerBytes == 0;
return;
}
assert circuitBreakerBytes >= 0;
circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
circuitBreakerBytes = 0;
}
synchronized Exception getFailure() {
return failure.get();
}
boolean hasFailure() {
return failure.get() != null; return failure.get() != null;
} }
public synchronized boolean hasPendingMerges() { boolean hasPendingMerges() {
return queue.isEmpty() == false || runningTask.get() != null; return queue.isEmpty() == false || runningTask.get() != null;
} }
public synchronized void sortBuffer() { void sortBuffer() {
if (index > 0) { if (buffer.size() > 0) {
Arrays.sort(buffer, 0, index, Comparator.comparingInt(QuerySearchResult::getShardIndex)); Collections.sort(buffer, Comparator.comparingInt(QuerySearchResult::getShardIndex));
} }
} }
synchronized long addWithoutBreaking(long size) {
circuitBreaker.addWithoutBreaking(size);
circuitBreakerBytes += size;
maxAggsCurrentBufferSize = Math.max(maxAggsCurrentBufferSize, circuitBreakerBytes);
return circuitBreakerBytes;
}
synchronized long addEstimateAndMaybeBreak(long estimatedSize) {
circuitBreaker.addEstimateBytesAndMaybeBreak(estimatedSize, "<reduce_aggs>");
circuitBreakerBytes += estimatedSize;
maxAggsCurrentBufferSize = Math.max(maxAggsCurrentBufferSize, circuitBreakerBytes);
return circuitBreakerBytes;
}
/**
* Returns the size of the serialized aggregation that is contained in the
* provided {@link QuerySearchResult}.
*/
long ramBytesUsedQueryResult(QuerySearchResult result) {
if (hasAggs == false) {
return 0;
}
return result.aggregations()
.asSerialized(InternalAggregations::readFrom, namedWriteableRegistry)
.ramBytesUsed();
}
/**
* Returns an estimation of the size that a reduce of the provided size
* would take on memory.
* This size is estimated as roughly 1.5 times the size of the serialized
* aggregations that need to be reduced. This estimation can be completely
* off for some aggregations but it is corrected with the real size after
* the reduce completes.
*/
long estimateRamBytesUsedForReduce(long size) {
return Math.round(1.5d * size - size);
}
public void consume(QuerySearchResult result, Runnable next) { public void consume(QuerySearchResult result, Runnable next) {
boolean executeNextImmediately = true; boolean executeNextImmediately = true;
synchronized (this) { synchronized (this) {
@ -247,20 +318,24 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
} }
} else { } else {
// add one if a partial merge is pending // add one if a partial merge is pending
int size = index + (hasPartialReduce ? 1 : 0); int size = buffer.size() + (hasPartialReduce ? 1 : 0);
if (size >= bufferSize) { if (size >= batchReduceSize) {
hasPartialReduce = true; hasPartialReduce = true;
executeNextImmediately = false; executeNextImmediately = false;
QuerySearchResult[] clone = new QuerySearchResult[index]; QuerySearchResult[] clone = buffer.stream().toArray(QuerySearchResult[]::new);
System.arraycopy(buffer, 0, clone, 0, index); MergeTask task = new MergeTask(clone, aggsCurrentBufferSize, new ArrayList<>(emptyResults), next);
MergeTask task = new MergeTask(clone, new ArrayList<>(emptyResults), next); aggsCurrentBufferSize = 0;
Arrays.fill(buffer, null); buffer.clear();
emptyResults.clear(); emptyResults.clear();
index = 0;
queue.add(task); queue.add(task);
tryExecuteNext(); tryExecuteNext();
} }
buffer[index++] = result; if (hasAggs) {
long aggsSize = ramBytesUsedQueryResult(result);
addWithoutBreaking(aggsSize);
aggsCurrentBufferSize += aggsSize;
}
buffer.add(result);
} }
} }
if (executeNextImmediately) { if (executeNextImmediately) {
@ -268,56 +343,85 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
} }
} }
private void onMergeFailure(Exception exc) { private synchronized void onMergeFailure(Exception exc) {
synchronized (this) { if (hasFailure()) {
if (failure.get() != null) { assert circuitBreakerBytes == 0;
return; return;
} }
failure.compareAndSet(null, exc); assert circuitBreakerBytes >= 0;
MergeTask task = runningTask.get(); if (circuitBreakerBytes > 0) {
runningTask.compareAndSet(task, null); // make sure that we reset the circuit breaker
onPartialMergeFailure.accept(exc); circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
List<MergeTask> toCancel = new ArrayList<>(); circuitBreakerBytes = 0;
if (task != null) { }
toCancel.add(task); failure.compareAndSet(null, exc);
} MergeTask task = runningTask.get();
toCancel.addAll(queue); runningTask.compareAndSet(task, null);
queue.clear(); onPartialMergeFailure.accept(exc);
mergeResult = null; List<MergeTask> toCancels = new ArrayList<>();
toCancel.stream().forEach(MergeTask::cancel); if (task != null) {
toCancels.add(task);
}
queue.stream().forEach(toCancels::add);
queue.clear();
mergeResult = null;
for (MergeTask toCancel : toCancels) {
toCancel.cancel();
} }
} }
private void onAfterMerge(MergeTask task, MergeResult newResult) { private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedSize) {
synchronized (this) { synchronized (this) {
if (hasFailure()) {
return;
}
runningTask.compareAndSet(task, null); runningTask.compareAndSet(task, null);
mergeResult = newResult; mergeResult = newResult;
if (hasAggs) {
// Update the circuit breaker to remove the size of the source aggregations
// and replace the estimation with the serialized size of the newly reduced result.
long newSize = mergeResult.estimatedSize - estimatedSize;
addWithoutBreaking(newSize);
logger.trace("aggs partial reduction [{}->{}] max [{}]",
estimatedSize, mergeResult.estimatedSize, maxAggsCurrentBufferSize);
}
task.consumeListener();
} }
task.consumeListener();
} }
private void tryExecuteNext() { private void tryExecuteNext() {
final MergeTask task; final MergeTask task;
synchronized (this) { synchronized (this) {
if (queue.isEmpty() if (queue.isEmpty()
|| failure.get() != null || hasFailure()
|| runningTask.get() != null) { || runningTask.get() != null) {
return; return;
} }
task = queue.poll(); task = queue.poll();
runningTask.compareAndSet(null, task); runningTask.compareAndSet(null, task);
} }
executor.execute(new AbstractRunnable() { executor.execute(new AbstractRunnable() {
@Override @Override
protected void doRun() { protected void doRun() {
final MergeResult thisMergeResult = mergeResult;
long estimatedTotalSize = (thisMergeResult != null ? thisMergeResult.estimatedSize : 0) + task.aggsBufferSize;
final MergeResult newMerge; final MergeResult newMerge;
try { try {
newMerge = partialReduce(task, topDocsStats, mergeResult, ++numReducePhases); final QuerySearchResult[] toConsume = task.consumeBuffer();
if (toConsume == null) {
return;
}
long estimatedMergeSize = estimateRamBytesUsedForReduce(estimatedTotalSize);
addEstimateAndMaybeBreak(estimatedMergeSize);
estimatedTotalSize += estimatedMergeSize;
++ numReducePhases;
newMerge = partialReduce(toConsume, task.emptyResults, topDocsStats, thisMergeResult, numReducePhases);
} catch (Exception t) { } catch (Exception t) {
onMergeFailure(t); onMergeFailure(t);
return; return;
} }
onAfterMerge(task, newMerge); onAfterMerge(task, newMerge, estimatedTotalSize);
tryExecuteNext(); tryExecuteNext();
} }
@ -328,15 +432,14 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
}); });
} }
public TopDocsStats consumeTopDocsStats() { public synchronized TopDocsStats consumeTopDocsStats() {
for (int i = 0; i < index; i++) { for (QuerySearchResult result : buffer) {
QuerySearchResult result = buffer[i];
topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly()); topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
} }
return topDocsStats; return topDocsStats;
} }
public List<TopDocs> consumeTopDocs() { public synchronized List<TopDocs> consumeTopDocs() {
if (hasTopDocs == false) { if (hasTopDocs == false) {
return Collections.emptyList(); return Collections.emptyList();
} }
@ -344,8 +447,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
if (mergeResult != null) { if (mergeResult != null) {
topDocsList.add(mergeResult.reducedTopDocs); topDocsList.add(mergeResult.reducedTopDocs);
} }
for (int i = 0; i < index; i++) { for (QuerySearchResult result : buffer) {
QuerySearchResult result = buffer[i];
TopDocsAndMaxScore topDocs = result.consumeTopDocs(); TopDocsAndMaxScore topDocs = result.consumeTopDocs();
setShardIndex(topDocs.topDocs, result.getShardIndex()); setShardIndex(topDocs.topDocs, result.getShardIndex());
topDocsList.add(topDocs.topDocs); topDocsList.add(topDocs.topDocs);
@ -353,46 +455,45 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
return topDocsList; return topDocsList;
} }
public List<InternalAggregations> consumeAggs() { public synchronized List<InternalAggregations> consumeAggs() {
if (hasAggs == false) { if (hasAggs == false) {
return Collections.emptyList(); return Collections.emptyList();
} }
List<InternalAggregations> aggsList = new ArrayList<>(); List<InternalAggregations> aggsList = new ArrayList<>();
if (mergeResult != null) { if (mergeResult != null) {
aggsList.add(mergeResult.reducedAggs.expand()); aggsList.add(mergeResult.reducedAggs);
} }
for (int i = 0; i < index; i++) { for (QuerySearchResult result : buffer) {
QuerySearchResult result = buffer[i];
aggsList.add(result.consumeAggs().expand()); aggsList.add(result.consumeAggs().expand());
} }
return aggsList; return aggsList;
} }
public Exception getFailure() {
return failure.get();
}
} }
private static class MergeResult { private static class MergeResult {
private final List<SearchShard> processedShards; private final List<SearchShard> processedShards;
private final TopDocs reducedTopDocs; private final TopDocs reducedTopDocs;
private final DelayableWriteable.Serialized<InternalAggregations> reducedAggs; private final InternalAggregations reducedAggs;
private final long estimatedSize;
private MergeResult(List<SearchShard> processedShards, TopDocs reducedTopDocs, private MergeResult(List<SearchShard> processedShards, TopDocs reducedTopDocs,
DelayableWriteable.Serialized<InternalAggregations> reducedAggs) { InternalAggregations reducedAggs, long estimatedSize) {
this.processedShards = processedShards; this.processedShards = processedShards;
this.reducedTopDocs = reducedTopDocs; this.reducedTopDocs = reducedTopDocs;
this.reducedAggs = reducedAggs; this.reducedAggs = reducedAggs;
this.estimatedSize = estimatedSize;
} }
} }
private static class MergeTask { private static class MergeTask {
private final List<SearchShard> emptyResults; private final List<SearchShard> emptyResults;
private QuerySearchResult[] buffer; private QuerySearchResult[] buffer;
private long aggsBufferSize;
private Runnable next; private Runnable next;
private MergeTask(QuerySearchResult[] buffer, List<SearchShard> emptyResults, Runnable next) { private MergeTask(QuerySearchResult[] buffer, long aggsBufferSize, List<SearchShard> emptyResults, Runnable next) {
this.buffer = buffer; this.buffer = buffer;
this.aggsBufferSize = aggsBufferSize;
this.emptyResults = emptyResults; this.emptyResults = emptyResults;
this.next = next; this.next = next;
} }
@ -403,7 +504,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
return toRet; return toRet;
} }
public synchronized void consumeListener() { public void consumeListener() {
if (next != null) { if (next != null) {
next.run(); next.run();
next = null; next = null;

View File

@ -35,29 +35,29 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.function.BiFunction; import java.util.function.BiFunction;
import java.util.function.Consumer;
final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<DfsSearchResult> { final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<DfsSearchResult> {
private final SearchPhaseController searchPhaseController; private final SearchPhaseController searchPhaseController;
private final Consumer<Exception> onPartialMergeFailure;
private final QueryPhaseResultConsumer queryPhaseResultConsumer;
SearchDfsQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService, SearchDfsQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService,
final BiFunction<String, String, Transport.Connection> nodeIdToConnection, final BiFunction<String, String, Transport.Connection> nodeIdToConnection,
final Map<String, AliasFilter> aliasFilter, final Map<String, AliasFilter> aliasFilter,
final Map<String, Float> concreteIndexBoosts, final Map<String, Set<String>> indexRoutings, final Map<String, Float> concreteIndexBoosts, final Map<String, Set<String>> indexRoutings,
final SearchPhaseController searchPhaseController, final Executor executor, final SearchPhaseController searchPhaseController, final Executor executor,
final QueryPhaseResultConsumer queryPhaseResultConsumer,
final SearchRequest request, final ActionListener<SearchResponse> listener, final SearchRequest request, final ActionListener<SearchResponse> listener,
final GroupShardsIterator<SearchShardIterator> shardsIts, final GroupShardsIterator<SearchShardIterator> shardsIts,
final TransportSearchAction.SearchTimeProvider timeProvider, final TransportSearchAction.SearchTimeProvider timeProvider,
final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters, final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters) {
Consumer<Exception> onPartialMergeFailure) {
super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings,
executor, request, listener, executor, request, listener,
shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()), shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()),
request.getMaxConcurrentShardRequests(), clusters); request.getMaxConcurrentShardRequests(), clusters);
this.queryPhaseResultConsumer = queryPhaseResultConsumer;
this.searchPhaseController = searchPhaseController; this.searchPhaseController = searchPhaseController;
this.onPartialMergeFailure = onPartialMergeFailure;
SearchProgressListener progressListener = task.getProgressListener(); SearchProgressListener progressListener = task.getProgressListener();
SearchSourceBuilder sourceBuilder = request.source(); SearchSourceBuilder sourceBuilder = request.source();
progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts), progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts),
@ -72,11 +72,12 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
} }
@Override @Override
protected SearchPhase getNextPhase(final SearchPhaseResults<DfsSearchResult> results, final SearchPhaseContext context) { protected SearchPhase getNextPhase(final SearchPhaseResults<DfsSearchResult> results, SearchPhaseContext context) {
final List<DfsSearchResult> dfsSearchResults = results.getAtomicArray().asList(); final List<DfsSearchResult> dfsSearchResults = results.getAtomicArray().asList();
final AggregatedDfs aggregatedDfs = searchPhaseController.aggregateDfs(dfsSearchResults); final AggregatedDfs aggregatedDfs = searchPhaseController.aggregateDfs(dfsSearchResults);
return new DfsQueryPhase(dfsSearchResults, aggregatedDfs, searchPhaseController, (queryResults) -> return new DfsQueryPhase(dfsSearchResults, aggregatedDfs, queryPhaseResultConsumer,
new FetchSearchPhase(queryResults, searchPhaseController, aggregatedDfs, context), context, onPartialMergeFailure); (queryResults) -> new FetchSearchPhase(queryResults, searchPhaseController, aggregatedDfs, context),
context);
} }
} }

View File

@ -21,6 +21,7 @@ package org.elasticsearch.action.search;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
@ -123,4 +124,9 @@ interface SearchPhaseContext extends Executor {
* a response is returned to the user indicating that all shards have failed. * a response is returned to the user indicating that all shards have failed.
*/ */
void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase); void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase);
/**
* Registers a {@link Releasable} that will be closed when the search request finishes or fails.
*/
void addReleasable(Releasable releasable);
} }

View File

@ -34,6 +34,7 @@ import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.TotalHits.Relation; import org.apache.lucene.search.TotalHits.Relation;
import org.apache.lucene.search.grouping.CollapseTopFieldDocs; import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.collect.HppcMaps; import org.elasticsearch.common.collect.HppcMaps;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
@ -563,14 +564,16 @@ public final class SearchPhaseController {
} }
/** /**
* Returns a new {@link QueryPhaseResultConsumer} instance. This might return an instance that reduces search responses incrementally. * Returns a new {@link QueryPhaseResultConsumer} instance that reduces search responses incrementally.
*/ */
QueryPhaseResultConsumer newSearchPhaseResults(Executor executor, QueryPhaseResultConsumer newSearchPhaseResults(Executor executor,
CircuitBreaker circuitBreaker,
SearchProgressListener listener, SearchProgressListener listener,
SearchRequest request, SearchRequest request,
int numShards, int numShards,
Consumer<Exception> onPartialMergeFailure) { Consumer<Exception> onPartialMergeFailure) {
return new QueryPhaseResultConsumer(request, executor, this, listener, namedWriteableRegistry, numShards, onPartialMergeFailure); return new QueryPhaseResultConsumer(request, executor, circuitBreaker,
this, listener, namedWriteableRegistry, numShards, onPartialMergeFailure);
} }
static final class TopDocsStats { static final class TopDocsStats {

View File

@ -25,7 +25,6 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.search.SearchResponse.Clusters; import org.elasticsearch.action.search.SearchResponse.Clusters;
import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.aggregations.InternalAggregations;
@ -78,11 +77,10 @@ public abstract class SearchProgressListener {
* *
* @param shards The list of shards that are part of this reduce. * @param shards The list of shards that are part of this reduce.
* @param totalHits The total number of hits in this reduce. * @param totalHits The total number of hits in this reduce.
* @param aggs The partial result for aggregations stored in serialized form. * @param aggs The partial result for aggregations.
* @param reducePhase The version number for this reduce. * @param reducePhase The version number for this reduce.
*/ */
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) {}
/** /**
* Executed once when the final reduce is created. * Executed once when the final reduce is created.
@ -137,8 +135,7 @@ public abstract class SearchProgressListener {
} }
} }
final void notifyPartialReduce(List<SearchShard> shards, TotalHits totalHits, final void notifyPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) {
try { try {
onPartialReduce(shards, totalHits, aggs, reducePhase); onPartialReduce(shards, totalHits, aggs, reducePhase);
} catch (Exception e) { } catch (Exception e) {

View File

@ -26,7 +26,6 @@ import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.internal.ShardSearchRequest;
@ -37,7 +36,6 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.function.BiFunction; import java.util.function.BiFunction;
import java.util.function.Consumer;
import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize; import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize;
@ -56,22 +54,26 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
final Map<String, AliasFilter> aliasFilter, final Map<String, AliasFilter> aliasFilter,
final Map<String, Float> concreteIndexBoosts, final Map<String, Set<String>> indexRoutings, final Map<String, Float> concreteIndexBoosts, final Map<String, Set<String>> indexRoutings,
final SearchPhaseController searchPhaseController, final Executor executor, final SearchPhaseController searchPhaseController, final Executor executor,
final SearchRequest request, final ActionListener<SearchResponse> listener, final QueryPhaseResultConsumer resultConsumer, final SearchRequest request,
final ActionListener<SearchResponse> listener,
final GroupShardsIterator<SearchShardIterator> shardsIts, final GroupShardsIterator<SearchShardIterator> shardsIts,
final TransportSearchAction.SearchTimeProvider timeProvider, final TransportSearchAction.SearchTimeProvider timeProvider,
ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters, ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters) {
Consumer<Exception> onPartialMergeFailure) {
super("query", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, super("query", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings,
executor, request, listener, shardsIts, timeProvider, clusterState, task, executor, request, listener, shardsIts, timeProvider, clusterState, task,
searchPhaseController.newSearchPhaseResults(executor, task.getProgressListener(), resultConsumer, request.getMaxConcurrentShardRequests(), clusters);
request, shardsIts.size(), onPartialMergeFailure), request.getMaxConcurrentShardRequests(), clusters);
this.topDocsSize = getTopDocsSize(request); this.topDocsSize = getTopDocsSize(request);
this.trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo(); this.trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo();
this.searchPhaseController = searchPhaseController; this.searchPhaseController = searchPhaseController;
this.progressListener = task.getProgressListener(); this.progressListener = task.getProgressListener();
final SearchSourceBuilder sourceBuilder = request.source();
// register the release of the query consumer to free up the circuit breaker memory
// at the end of the search
addReleasable(resultConsumer);
boolean hasFetchPhase = request.source() == null ? true : request.source().size() > 0;
progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts), progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts),
SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0); SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, hasFetchPhase);
} }
protected void executePhaseOnShard(final SearchShardIterator shardIt, protected void executePhaseOnShard(final SearchShardIterator shardIt,
@ -108,8 +110,8 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
} }
@Override @Override
protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, final SearchPhaseContext context) { protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
return new FetchSearchPhase(results, searchPhaseController, null, context); return new FetchSearchPhase(results, searchPhaseController, null, this);
} }
private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {

View File

@ -43,6 +43,7 @@ import org.elasticsearch.cluster.routing.OperationRouting;
import org.elasticsearch.cluster.routing.ShardIterator; import org.elasticsearch.cluster.routing.ShardIterator;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@ -55,6 +56,7 @@ import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.index.Index; import org.elasticsearch.index.Index;
import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
@ -115,10 +117,12 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
private final SearchService searchService; private final SearchService searchService;
private final IndexNameExpressionResolver indexNameExpressionResolver; private final IndexNameExpressionResolver indexNameExpressionResolver;
private final NamedWriteableRegistry namedWriteableRegistry; private final NamedWriteableRegistry namedWriteableRegistry;
private final CircuitBreaker circuitBreaker;
@Inject @Inject
public TransportSearchAction(NodeClient client, public TransportSearchAction(NodeClient client,
ThreadPool threadPool, ThreadPool threadPool,
CircuitBreakerService circuitBreakerService,
TransportService transportService, TransportService transportService,
SearchService searchService, SearchService searchService,
SearchTransportService searchTransportService, SearchTransportService searchTransportService,
@ -130,6 +134,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
super(SearchAction.NAME, transportService, actionFilters, (Writeable.Reader<SearchRequest>) SearchRequest::new); super(SearchAction.NAME, transportService, actionFilters, (Writeable.Reader<SearchRequest>) SearchRequest::new);
this.client = client; this.client = client;
this.threadPool = threadPool; this.threadPool = threadPool;
this.circuitBreaker = circuitBreakerService.getBreaker(CircuitBreaker.REQUEST);
this.searchPhaseController = searchPhaseController; this.searchPhaseController = searchPhaseController;
this.searchTransportService = searchTransportService; this.searchTransportService = searchTransportService;
this.remoteClusterService = searchTransportService.getRemoteClusterService(); this.remoteClusterService = searchTransportService.getRemoteClusterService();
@ -796,17 +801,19 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
}; };
}, clusters); }, clusters);
} else { } else {
final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults(executor,
circuitBreaker, task.getProgressListener(), searchRequest, shardIterators.size(), exc -> cancelTask(task, exc));
AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction; AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction;
switch (searchRequest.searchType()) { switch (searchRequest.searchType()) {
case DFS_QUERY_THEN_FETCH: case DFS_QUERY_THEN_FETCH:
searchAsyncAction = new SearchDfsQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup, searchAsyncAction = new SearchDfsQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup,
aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, searchRequest, listener, aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController,
shardIterators, timeProvider, clusterState, task, clusters, exc -> cancelTask(task, exc)); executor, queryResultConsumer, searchRequest, listener, shardIterators, timeProvider, clusterState, task, clusters);
break; break;
case QUERY_THEN_FETCH: case QUERY_THEN_FETCH:
searchAsyncAction = new SearchQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup, searchAsyncAction = new SearchQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup,
aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, searchRequest, listener, aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, queryResultConsumer,
shardIterators, timeProvider, clusterState, task, clusters, exc -> cancelTask(task, exc)); searchRequest, listener, shardIterators, timeProvider, clusterState, task, clusters);
break; break;
default: default:
throw new IllegalStateException("Unknown search type: [" + searchRequest.searchType() + "]"); throw new IllegalStateException("Unknown search type: [" + searchRequest.searchType() + "]");

View File

@ -272,4 +272,47 @@ public final class InternalAggregations extends Aggregations implements Writeabl
public static InternalAggregations reduce(List<InternalAggregations> aggregationsList, ReduceContext context) { public static InternalAggregations reduce(List<InternalAggregations> aggregationsList, ReduceContext context) {
return reduce(aggregationsList, context, InternalAggregations::from); return reduce(aggregationsList, context, InternalAggregations::from);
} }
/**
* Returns the number of bytes required to serialize these aggregations in binary form.
*/
public long getSerializedSize() {
try (CountingStreamOutput out = new CountingStreamOutput()) {
out.setVersion(Version.CURRENT);
writeTo(out);
return out.size;
} catch (IOException exc) {
// should never happen
throw new RuntimeException(exc);
}
}
private static class CountingStreamOutput extends StreamOutput {
long size = 0;
@Override
public void writeByte(byte b) throws IOException {
++ size;
}
@Override
public void writeBytes(byte[] b, int offset, int length) throws IOException {
size += length;
}
@Override
public void flush() throws IOException {}
@Override
public void close() throws IOException {}
@Override
public void reset() throws IOException {
size = 0;
}
public long length() {
return size;
}
}
} }

View File

@ -96,7 +96,7 @@ public class AbstractSearchAsyncActionTests extends ESTestCase {
results, request.getMaxConcurrentShardRequests(), results, request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY) { SearchResponse.Clusters.EMPTY) {
@Override @Override
protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, final SearchPhaseContext context) { protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
return null; return null;
} }

View File

@ -25,8 +25,11 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.MockDirectoryWrapper; import org.apache.lucene.store.MockDirectoryWrapper;
import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchPhaseResult;
@ -86,15 +89,19 @@ public class DfsQueryPhaseTests extends ESTestCase {
} }
} }
}; };
SearchPhaseController searchPhaseController = searchPhaseController();
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
mockSearchPhaseContext.searchTransport = searchTransportService; mockSearchPhaseContext.searchTransport = searchTransportService;
DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, searchPhaseController(), QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest,
results.length(), exc -> {});
DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer,
(response) -> new SearchPhase("test") { (response) -> new SearchPhase("test") {
@Override @Override
public void run() throws IOException { public void run() throws IOException {
responseRef.set(response.results); responseRef.set(response.results);
} }
}, mockSearchPhaseContext, exc -> {}); }, mockSearchPhaseContext);
assertEquals("dfs_query", phase.getName()); assertEquals("dfs_query", phase.getName());
phase.run(); phase.run();
mockSearchPhaseContext.assertNoFailure(); mockSearchPhaseContext.assertNoFailure();
@ -141,15 +148,19 @@ public class DfsQueryPhaseTests extends ESTestCase {
} }
} }
}; };
SearchPhaseController searchPhaseController = searchPhaseController();
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
mockSearchPhaseContext.searchTransport = searchTransportService; mockSearchPhaseContext.searchTransport = searchTransportService;
DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, searchPhaseController(), QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest,
results.length(), exc -> {});
DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer,
(response) -> new SearchPhase("test") { (response) -> new SearchPhase("test") {
@Override @Override
public void run() throws IOException { public void run() throws IOException {
responseRef.set(response.results); responseRef.set(response.results);
} }
}, mockSearchPhaseContext, exc -> {}); }, mockSearchPhaseContext);
assertEquals("dfs_query", phase.getName()); assertEquals("dfs_query", phase.getName());
phase.run(); phase.run();
mockSearchPhaseContext.assertNoFailure(); mockSearchPhaseContext.assertNoFailure();
@ -198,15 +209,19 @@ public class DfsQueryPhaseTests extends ESTestCase {
} }
} }
}; };
SearchPhaseController searchPhaseController = searchPhaseController();
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
mockSearchPhaseContext.searchTransport = searchTransportService; mockSearchPhaseContext.searchTransport = searchTransportService;
DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, searchPhaseController(), QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest,
results.length(), exc -> {});
DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer,
(response) -> new SearchPhase("test") { (response) -> new SearchPhase("test") {
@Override @Override
public void run() throws IOException { public void run() throws IOException {
responseRef.set(response.results); responseRef.set(response.results);
} }
}, mockSearchPhaseContext, exc -> {}); }, mockSearchPhaseContext);
assertEquals("dfs_query", phase.getName()); assertEquals("dfs_query", phase.getName());
expectThrows(UncheckedIOException.class, phase::run); expectThrows(UncheckedIOException.class, phase::run);
assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); // phase execution will clean up on the contexts assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); // phase execution will clean up on the contexts

View File

@ -24,6 +24,8 @@ import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.MockDirectoryWrapper; import org.apache.lucene.store.MockDirectoryWrapper;
import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
@ -43,8 +45,6 @@ import org.elasticsearch.transport.Transport;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.action.search.SearchProgressListener.NOOP;
public class FetchSearchPhaseTests extends ESTestCase { public class FetchSearchPhaseTests extends ESTestCase {
public void testShortcutQueryAndFetchOptimization() { public void testShortcutQueryAndFetchOptimization() {
@ -52,7 +52,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 1, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
mockSearchPhaseContext.getRequest(), 1, exc -> {});
boolean hasHits = randomBoolean(); boolean hasHits = randomBoolean();
final int numHits; final int numHits;
if (hasHits) { if (hasHits) {
@ -96,7 +97,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
SearchPhaseController controller = new SearchPhaseController( SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10); int resultSetSize = randomIntBetween(2, 10);
ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123);
QuerySearchResult queryResult = new QuerySearchResult(ctx1, new SearchShardTarget("node1", new ShardId("test", "na", 0), QuerySearchResult queryResult = new QuerySearchResult(ctx1, new SearchShardTarget("node1", new ShardId("test", "na", 0),
@ -157,7 +159,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
SearchPhaseController controller = new SearchPhaseController( SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10); int resultSetSize = randomIntBetween(2, 10);
final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123);
QuerySearchResult queryResult = new QuerySearchResult(ctx, QuerySearchResult queryResult = new QuerySearchResult(ctx,
@ -220,7 +223,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
SearchPhaseController controller = new SearchPhaseController( SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(numHits); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(numHits);
QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), NOOP, QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
mockSearchPhaseContext.getRequest(), numHits, exc -> {}); mockSearchPhaseContext.getRequest(), numHits, exc -> {});
for (int i = 0; i < numHits; i++) { for (int i = 0; i < numHits; i++) {
QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", i), QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", i),
@ -279,7 +283,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
QueryPhaseResultConsumer results = QueryPhaseResultConsumer results =
controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = randomIntBetween(2, 10); int resultSetSize = randomIntBetween(2, 10);
QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", 123), QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", 123),
new SearchShardTarget("node1", new ShardId("test", "na", 0), new SearchShardTarget("node1", new ShardId("test", "na", 0),
@ -337,7 +342,8 @@ public class FetchSearchPhaseTests extends ESTestCase {
SearchPhaseController controller = new SearchPhaseController( SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder());
QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(),
NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
mockSearchPhaseContext.getRequest(), 2, exc -> {});
int resultSetSize = 1; int resultSetSize = 1;
final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123);
QuerySearchResult queryResult = new QuerySearchResult(ctx1, QuerySearchResult queryResult = new QuerySearchResult(ctx1,

View File

@ -23,6 +23,7 @@ import org.apache.logging.log4j.Logger;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
@ -131,6 +132,11 @@ public final class MockSearchPhaseContext implements SearchPhaseContext {
} }
} }
@Override
public void addReleasable(Releasable releasable) {
// Noop
}
@Override @Override
public void execute(Runnable command) { public void execute(Runnable command) {
command.run(); command.run();

View File

@ -23,7 +23,8 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
@ -93,8 +94,9 @@ public class QueryPhaseResultConsumerTests extends ESTestCase {
SearchRequest searchRequest = new SearchRequest("index"); SearchRequest searchRequest = new SearchRequest("index");
searchRequest.setBatchedReduceSize(2); searchRequest.setBatchedReduceSize(2);
AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>(); AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, searchPhaseController, QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(searchRequest, executor,
searchProgressListener, writableRegistry(), 10, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { new NoopCircuitBreaker(CircuitBreaker.REQUEST), searchPhaseController, searchProgressListener,
writableRegistry(), 10, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
curr.addSuppressed(prev); curr.addSuppressed(prev);
return curr; return curr;
})); }));
@ -140,7 +142,7 @@ public class QueryPhaseResultConsumerTests extends ESTestCase {
@Override @Override
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) { InternalAggregations aggs, int reducePhase) {
onPartialReduce.incrementAndGet(); onPartialReduce.incrementAndGet();
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }

View File

@ -460,8 +460,7 @@ public class SearchAsyncActionTests extends ESTestCase {
} }
@Override @Override
protected SearchPhase getNextPhase(SearchPhaseResults<TestSearchPhaseResult> results, protected SearchPhase getNextPhase(SearchPhaseResults<TestSearchPhaseResult> results, SearchPhaseContext context) {
SearchPhaseContext context) {
return new SearchPhase("test") { return new SearchPhase("test") {
@Override @Override
public void run() { public void run() {

View File

@ -33,10 +33,10 @@ import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
@ -45,7 +45,6 @@ import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor; import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
@ -77,7 +76,6 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -95,7 +93,6 @@ import java.util.stream.Stream;
import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap; import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static org.elasticsearch.action.search.SearchProgressListener.NOOP;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@ -111,9 +108,9 @@ public class SearchPhaseControllerTests extends ESTestCase {
@Override @Override
protected NamedWriteableRegistry writableRegistry() { protected NamedWriteableRegistry writableRegistry() {
List<NamedWriteableRegistry.Entry> entries = List<NamedWriteableRegistry.Entry> entries = new ArrayList<>(
new ArrayList<>(new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables()); new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables()
entries.add(new NamedWriteableRegistry.Entry(InternalAggregation.class, "throwing", InternalThrowing::new)); );
return new NamedWriteableRegistry(entries); return new NamedWriteableRegistry(entries);
} }
@ -419,7 +416,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo"))); request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")));
request.setBatchedReduceSize(bufferSize); request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, 3+numEmptyResponses, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
request, 3+numEmptyResponses, exc -> {});
if (numEmptyResponses == 0) { if (numEmptyResponses == 0) {
assertEquals(0, reductions.size()); assertEquals(0, reductions.size());
} }
@ -506,7 +504,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo"))); request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")));
request.setBatchedReduceSize(bufferSize); request.setBatchedReduceSize(bufferSize);
ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger(); AtomicInteger max = new AtomicInteger();
Thread[] threads = new Thread[expectedNumResults]; Thread[] threads = new Thread[expectedNumResults];
CountDownLatch latch = new CountDownLatch(expectedNumResults); CountDownLatch latch = new CountDownLatch(expectedNumResults);
@ -556,7 +555,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0)); request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0));
request.setBatchedReduceSize(bufferSize); request.setBatchedReduceSize(bufferSize);
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger(); AtomicInteger max = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(expectedNumResults); CountDownLatch latch = new CountDownLatch(expectedNumResults);
for (int i = 0; i < expectedNumResults; i++) { for (int i = 0; i < expectedNumResults; i++) {
@ -597,7 +597,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
} }
request.setBatchedReduceSize(bufferSize); request.setBatchedReduceSize(bufferSize);
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger(); AtomicInteger max = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(expectedNumResults); CountDownLatch latch = new CountDownLatch(expectedNumResults);
for (int i = 0; i < expectedNumResults; i++) { for (int i = 0; i < expectedNumResults; i++) {
@ -640,7 +641,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
request.source(new SearchSourceBuilder().size(5).from(5)); request.source(new SearchSourceBuilder().size(5).from(5));
request.setBatchedReduceSize(randomIntBetween(2, 4)); request.setBatchedReduceSize(randomIntBetween(2, 4));
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, 4, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP
, request, 4, exc -> {});
int score = 100; int score = 100;
CountDownLatch latch = new CountDownLatch(4); CountDownLatch latch = new CountDownLatch(4);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
@ -678,7 +680,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
int size = randomIntBetween(1, 10); int size = randomIntBetween(1, 10);
request.setBatchedReduceSize(bufferSize); request.setBatchedReduceSize(bufferSize);
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger(); AtomicInteger max = new AtomicInteger();
SortField[] sortFields = {new SortField("field", SortField.Type.INT, true)}; SortField[] sortFields = {new SortField("field", SortField.Type.INT, true)};
DocValueFormat[] docValueFormats = {DocValueFormat.RAW}; DocValueFormat[] docValueFormats = {DocValueFormat.RAW};
@ -716,7 +719,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
int size = randomIntBetween(5, 10); int size = randomIntBetween(5, 10);
request.setBatchedReduceSize(bufferSize); request.setBatchedReduceSize(bufferSize);
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
request, expectedNumResults, exc -> {});
SortField[] sortFields = {new SortField("field", SortField.Type.STRING)}; SortField[] sortFields = {new SortField("field", SortField.Type.STRING)};
BytesRef a = new BytesRef("a"); BytesRef a = new BytesRef("a");
BytesRef b = new BytesRef("b"); BytesRef b = new BytesRef("b");
@ -757,7 +761,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
SearchRequest request = randomSearchRequest(); SearchRequest request = randomSearchRequest();
request.setBatchedReduceSize(bufferSize); request.setBatchedReduceSize(bufferSize);
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP,
request, expectedNumResults, exc -> {});
int maxScoreTerm = -1; int maxScoreTerm = -1;
int maxScorePhrase = -1; int maxScorePhrase = -1;
int maxScoreCompletion = -1; int maxScoreCompletion = -1;
@ -871,7 +876,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
@Override @Override
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) { InternalAggregations aggs, int reducePhase) {
assertEquals(numReduceListener.incrementAndGet(), reducePhase); assertEquals(numReduceListener.incrementAndGet(), reducePhase);
} }
@ -883,7 +888,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
} }
}; };
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
progressListener, request, expectedNumResults, exc -> {}); new NoopCircuitBreaker(CircuitBreaker.REQUEST), progressListener, request, expectedNumResults, exc -> {});
AtomicInteger max = new AtomicInteger(); AtomicInteger max = new AtomicInteger();
Thread[] threads = new Thread[expectedNumResults]; Thread[] threads = new Thread[expectedNumResults];
CountDownLatch latch = new CountDownLatch(expectedNumResults); CountDownLatch latch = new CountDownLatch(expectedNumResults);
@ -932,7 +937,19 @@ public class SearchPhaseControllerTests extends ESTestCase {
} }
} }
public void testPartialMergeFailure() throws InterruptedException { public void testPartialReduce() throws Exception {
for (int i = 0; i < 10; i++) {
testReduceCase(false);
}
}
public void testPartialReduceWithFailure() throws Exception {
for (int i = 0; i < 10; i++) {
testReduceCase(true);
}
}
private void testReduceCase(boolean shouldFail) throws Exception {
int expectedNumResults = randomIntBetween(20, 200); int expectedNumResults = randomIntBetween(20, 200);
int bufferSize = randomIntBetween(2, expectedNumResults - 1); int bufferSize = randomIntBetween(2, expectedNumResults - 1);
SearchRequest request = new SearchRequest(); SearchRequest request = new SearchRequest();
@ -940,11 +957,16 @@ public class SearchPhaseControllerTests extends ESTestCase {
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0)); request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0));
request.setBatchedReduceSize(bufferSize); request.setBatchedReduceSize(bufferSize);
AtomicBoolean hasConsumedFailure = new AtomicBoolean(); AtomicBoolean hasConsumedFailure = new AtomicBoolean();
AssertingCircuitBreaker circuitBreaker = new AssertingCircuitBreaker(CircuitBreaker.REQUEST);
boolean shouldFailPartial = shouldFail && randomBoolean();
if (shouldFailPartial) {
circuitBreaker.shouldBreak.set(true);
}
QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor,
NOOP, request, expectedNumResults, exc -> hasConsumedFailure.set(true)); circuitBreaker, SearchProgressListener.NOOP,
request, expectedNumResults, exc -> hasConsumedFailure.set(true));
CountDownLatch latch = new CountDownLatch(expectedNumResults); CountDownLatch latch = new CountDownLatch(expectedNumResults);
Thread[] threads = new Thread[expectedNumResults]; Thread[] threads = new Thread[expectedNumResults];
int failedIndex = randomIntBetween(0, expectedNumResults-1);
for (int i = 0; i < expectedNumResults; i++) { for (int i = 0; i < expectedNumResults; i++) {
final int index = i; final int index = i;
threads[index] = new Thread(() -> { threads[index] = new Thread(() -> {
@ -955,7 +977,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), Float.NaN), new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), Float.NaN),
new DocValueFormat[0]); new DocValueFormat[0]);
InternalAggregations aggs = InternalAggregations.from( InternalAggregations aggs = InternalAggregations.from(
Collections.singletonList(new InternalThrowing("test", (failedIndex == index), Collections.emptyMap()))); Collections.singletonList(new InternalMax("test", 0d, DocValueFormat.RAW, Collections.emptyMap())));
result.aggregations(aggs); result.aggregations(aggs);
result.setShardIndex(index); result.setShardIndex(index);
result.size(1); result.size(1);
@ -967,65 +989,44 @@ public class SearchPhaseControllerTests extends ESTestCase {
threads[i].join(); threads[i].join();
} }
latch.await(); latch.await();
IllegalStateException exc = expectThrows(IllegalStateException.class, () -> consumer.reduce()); if (shouldFail) {
if (exc.getMessage().contains("partial reduce")) { if (shouldFailPartial == false) {
assertTrue(hasConsumedFailure.get()); circuitBreaker.shouldBreak.set(true);
} else {
assertThat(exc.getMessage(), containsString("final reduce"));
}
}
private static class InternalThrowing extends InternalAggregation {
private final boolean shouldThrow;
protected InternalThrowing(String name, boolean shouldThrow, Map<String, Object> metadata) {
super(name, metadata);
this.shouldThrow = shouldThrow;
}
protected InternalThrowing(StreamInput in) throws IOException {
super(in);
this.shouldThrow = in.readBoolean();
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeBoolean(shouldThrow);
}
@Override
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
if (aggregations.stream()
.map(agg -> (InternalThrowing) agg)
.anyMatch(agg -> agg.shouldThrow)) {
if (reduceContext.isFinalReduce()) {
throw new IllegalStateException("final reduce");
} else {
throw new IllegalStateException("partial reduce");
}
} }
return new InternalThrowing(name, false, metadata); CircuitBreakingException exc = expectThrows(CircuitBreakingException.class, () -> consumer.reduce());
} assertEquals(shouldFailPartial, hasConsumedFailure.get());
assertThat(exc.getMessage(), containsString("<reduce_aggs>"));
@Override circuitBreaker.shouldBreak.set(false);
protected boolean mustReduceOnSingleInternalAgg() { } else {
return true; SearchPhaseController.ReducedQueryPhase phase = consumer.reduce();
}
@Override
public Object getProperty(List<String> path) {
return null;
}
@Override
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
throw new IllegalStateException("not implemented");
}
@Override
public String getWriteableName() {
return "throwing";
} }
consumer.close();
assertThat(circuitBreaker.allocated, equalTo(0L));
} }
private static class AssertingCircuitBreaker extends NoopCircuitBreaker {
private final AtomicBoolean shouldBreak = new AtomicBoolean(false);
private volatile long allocated;
AssertingCircuitBreaker(String name) {
super(name);
}
@Override
public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException {
assert bytes >= 0;
if (shouldBreak.get()) {
throw new CircuitBreakingException(label, getDurability());
}
allocated += bytes;
return allocated;
}
@Override
public long addWithoutBreaking(long bytes) {
allocated += bytes;
return allocated;
}
}
} }

View File

@ -29,6 +29,8 @@ import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
@ -51,6 +53,7 @@ import java.util.Collections;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -144,15 +147,19 @@ public class SearchQueryThenFetchAsyncActionTests extends ESTestCase {
searchRequest.source().collapse(new CollapseBuilder("collapse_field")); searchRequest.source().collapse(new CollapseBuilder("collapse_field"));
} }
searchRequest.allowPartialSearchResults(false); searchRequest.allowPartialSearchResults(false);
Executor executor = EsExecutors.newDirectExecutorService();
SearchPhaseController controller = new SearchPhaseController( SearchPhaseController controller = new SearchPhaseController(
writableRegistry(), r -> InternalAggregationTestCase.emptyReduceContextBuilder()); writableRegistry(), r -> InternalAggregationTestCase.emptyReduceContextBuilder());
SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap()); SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap());
QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer(searchRequest, executor,
new NoopCircuitBreaker(CircuitBreaker.REQUEST), controller, task.getProgressListener(), writableRegistry(),
shardsIter.size(), exc -> {});
SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction(logger, SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction(logger,
searchTransportService, (clusterAlias, node) -> lookup.get(node), searchTransportService, (clusterAlias, node) -> lookup.get(node),
Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)), Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)),
Collections.emptyMap(), Collections.emptyMap(), controller, EsExecutors.newDirectExecutorService(), searchRequest, Collections.emptyMap(), Collections.emptyMap(), controller, executor,
null, shardsIter, timeProvider, null, task, resultConsumer, searchRequest, null, shardsIter, timeProvider, null,
SearchResponse.Clusters.EMPTY, exc -> {}) { task, SearchResponse.Clusters.EMPTY) {
@Override @Override
protected SearchPhase getNextPhase(SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) { protected SearchPhase getNextPhase(SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
return new SearchPhase("test") { return new SearchPhase("test") {

View File

@ -1,177 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.action.search;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValueType;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESSingleNodeTestCase;
public class TransportSearchActionSingleNodeTests extends ESSingleNodeTestCase {
public void testLocalClusterAlias() {
long nowInMillis = randomLongBetween(0, Long.MAX_VALUE);
IndexRequest indexRequest = new IndexRequest("test");
indexRequest.id("1");
indexRequest.source("field", "value");
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY,
"local", nowInMillis, randomBoolean());
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
SearchHit[] hits = searchResponse.getHits().getHits();
assertEquals(1, hits.length);
SearchHit hit = hits[0];
assertEquals("local", hit.getClusterAlias());
assertEquals("test", hit.getIndex());
assertEquals("1", hit.getId());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY,
"", nowInMillis, randomBoolean());
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
SearchHit[] hits = searchResponse.getHits().getHits();
assertEquals(1, hits.length);
SearchHit hit = hits[0];
assertEquals("", hit.getClusterAlias());
assertEquals("test", hit.getIndex());
assertEquals("1", hit.getId());
}
}
public void testAbsoluteStartMillis() {
{
IndexRequest indexRequest = new IndexRequest("test-1970.01.01");
indexRequest.id("1");
indexRequest.source("date", "1970-01-01");
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
}
{
IndexRequest indexRequest = new IndexRequest("test-1982.01.01");
indexRequest.id("1");
indexRequest.source("date", "1982-01-01");
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
}
{
SearchRequest searchRequest = new SearchRequest();
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
}
{
SearchRequest searchRequest = new SearchRequest("<test-{now/d}>");
searchRequest.indicesOptions(IndicesOptions.fromOptions(true, true, true, true));
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(0, searchResponse.getTotalShards());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
searchRequest.indices("<test-{now/d}>");
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
RangeQueryBuilder rangeQuery = new RangeQueryBuilder("date");
rangeQuery.gte("1970-01-01");
rangeQuery.lt("1982-01-01");
sourceBuilder.query(rangeQuery);
searchRequest.source(sourceBuilder);
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex());
}
}
public void testFinalReduce() {
long nowInMillis = randomLongBetween(0, Long.MAX_VALUE);
{
IndexRequest indexRequest = new IndexRequest("test");
indexRequest.id("1");
indexRequest.source("price", 10);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
}
{
IndexRequest indexRequest = new IndexRequest("test");
indexRequest.id("2");
indexRequest.source("price", 100);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
}
client().admin().indices().prepareRefresh("test").get();
SearchRequest originalRequest = new SearchRequest();
SearchSourceBuilder source = new SearchSourceBuilder();
source.size(0);
originalRequest.source(source);
TermsAggregationBuilder terms = new TermsAggregationBuilder("terms").userValueTypeHint(ValueType.NUMERIC);
terms.field("price");
terms.size(1);
source.aggregation(terms);
{
SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(originalRequest,
Strings.EMPTY_ARRAY, "remote", nowInMillis, true);
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
Aggregations aggregations = searchResponse.getAggregations();
LongTerms longTerms = aggregations.get("terms");
assertEquals(1, longTerms.getBuckets().size());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(originalRequest,
Strings.EMPTY_ARRAY, "remote", nowInMillis, false);
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
Aggregations aggregations = searchResponse.getAggregations();
LongTerms longTerms = aggregations.get("terms");
assertEquals(2, longTerms.getBuckets().size());
}
}
}

View File

@ -18,6 +18,8 @@
*/ */
package org.elasticsearch.search.aggregations; package org.elasticsearch.search.aggregations;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@ -44,6 +46,7 @@ import java.util.List;
import static java.util.Collections.emptyMap; import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static org.hamcrest.Matchers.equalTo;
public class InternalAggregationsTests extends ESTestCase { public class InternalAggregationsTests extends ESTestCase {
@ -126,19 +129,32 @@ public class InternalAggregationsTests extends ESTestCase {
public void testSerialization() throws Exception { public void testSerialization() throws Exception {
InternalAggregations aggregations = createTestInstance(); InternalAggregations aggregations = createTestInstance();
writeToAndReadFrom(aggregations, 0); writeToAndReadFrom(aggregations, Version.CURRENT, 0);
} }
private void writeToAndReadFrom(InternalAggregations aggregations, int iteration) throws IOException { public void testSerializedSize() throws Exception {
try (BytesStreamOutput out = new BytesStreamOutput()) { InternalAggregations aggregations = createTestInstance();
aggregations.writeTo(out); assertThat(aggregations.getSerializedSize(),
try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(out.bytes().toBytesRef().bytes), registry)) { equalTo((long) serialize(aggregations, Version.CURRENT).length));
InternalAggregations deserialized = InternalAggregations.readFrom(in); }
assertEquals(aggregations.aggregations, deserialized.aggregations);
if (iteration < 2) { private void writeToAndReadFrom(InternalAggregations aggregations, Version version, int iteration) throws IOException {
writeToAndReadFrom(deserialized, iteration + 1); BytesRef serializedAggs = serialize(aggregations, version);
} try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(serializedAggs.bytes), registry)) {
in.setVersion(version);
InternalAggregations deserialized = InternalAggregations.readFrom(in);
assertEquals(aggregations.aggregations, deserialized.aggregations);
if (iteration < 2) {
writeToAndReadFrom(deserialized, version, iteration + 1);
} }
} }
} }
private BytesRef serialize(InternalAggregations aggs, Version version) throws IOException {
try (BytesStreamOutput out = new BytesStreamOutput()) {
out.setVersion(version);
aggs.writeTo(out);
return out.bytes().toBytesRef();
}
}
} }

View File

@ -1616,7 +1616,7 @@ public class SnapshotResiliencyTests extends ESTestCase {
SearchPhaseController searchPhaseController = new SearchPhaseController( SearchPhaseController searchPhaseController = new SearchPhaseController(
writableRegistry(), searchService::aggReduceContextBuilder); writableRegistry(), searchService::aggReduceContextBuilder);
actions.put(SearchAction.INSTANCE, actions.put(SearchAction.INSTANCE,
new TransportSearchAction(client, threadPool, transportService, searchService, new TransportSearchAction(client, threadPool, new NoneCircuitBreakerService(), transportService, searchService,
searchTransportService, searchPhaseController, clusterService, searchTransportService, searchPhaseController, clusterService,
actionFilters, indexNameExpressionResolver, namedWriteableRegistry)); actionFilters, indexNameExpressionResolver, namedWriteableRegistry));
actions.put(RestoreSnapshotAction.INSTANCE, actions.put(RestoreSnapshotAction.INSTANCE,

View File

@ -20,7 +20,6 @@ import org.elasticsearch.action.search.SearchShard;
import org.elasticsearch.action.search.SearchTask; import org.elasticsearch.action.search.SearchTask;
import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregation;
@ -391,7 +390,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
@Override @Override
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
DelayableWriteable.Serialized<InternalAggregations> aggregations, int reducePhase) { InternalAggregations aggregations, int reducePhase) {
// best effort to cancel expired tasks // best effort to cancel expired tasks
checkCancellation(); checkCancellation();
// The way that the MutableSearchResponse will build the aggs. // The way that the MutableSearchResponse will build the aggs.
@ -401,16 +400,15 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
reducedAggs = () -> null; reducedAggs = () -> null;
} else { } else {
/* /*
* Keep a reference to the serialized form of the partially * Keep a reference to the partially reduced aggs and reduce it on the fly when someone asks
* reduced aggs and reduce it on the fly when someone asks
* for it. It's important that we wait until someone needs * for it. It's important that we wait until someone needs
* the result so we don't perform the final reduce only to * the result so we don't perform the final reduce only to
* throw it away. And it is important that we keep the reference * throw it away. And it is important that we keep the reference
* to the serialized aggregations because SearchPhaseController * to the aggregations because SearchPhaseController
* *already* has that reference so we're not creating more garbage. * *already* has that reference so we're not creating more garbage.
*/ */
reducedAggs = () -> reducedAggs = () ->
InternalAggregations.topLevelReduce(singletonList(aggregations.expand()), aggReduceContextSupplier.get()); InternalAggregations.topLevelReduce(singletonList(aggregations), aggReduceContextSupplier.get());
} }
searchResponse.get().updatePartialResponse(shards.size(), totalHits, reducedAggs, reducePhase); searchResponse.get().updatePartialResponse(shards.size(), totalHits, reducedAggs, reducePhase);
} }

View File

@ -16,8 +16,6 @@ import org.elasticsearch.action.search.SearchShard;
import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
@ -25,7 +23,6 @@ import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.BucketOrder; import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms; import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
@ -155,56 +152,14 @@ public class AsyncSearchTaskTests extends ESTestCase {
latch.await(); latch.await();
} }
public void testGetResponseFailureDuringReduction() throws InterruptedException {
AsyncSearchTask task = createAsyncSearchTask();
task.getSearchProgressActionListener().onListShards(Collections.emptyList(), Collections.emptyList(),
SearchResponse.Clusters.EMPTY, false);
InternalAggregations aggs = InternalAggregations.from(Collections.singletonList(new StringTerms("name", BucketOrder.key(true),
BucketOrder.key(true), 1, 1, Collections.emptyMap(), DocValueFormat.RAW, 1, false, 1, Collections.emptyList(), 0)));
//providing an empty named writeable registry will make the expansion fail, hence the delayed reduction will fail too
//causing an exception when executing getResponse as part of the completion listener callback
DelayableWriteable.Serialized<InternalAggregations> serializedAggs = DelayableWriteable.referencing(aggs)
.asSerialized(InternalAggregations::readFrom, new NamedWriteableRegistry(Collections.emptyList()));
task.getSearchProgressActionListener().onPartialReduce(Collections.emptyList(), new TotalHits(0, TotalHits.Relation.EQUAL_TO),
serializedAggs, 1);
AtomicReference<AsyncSearchResponse> response = new AtomicReference<>();
CountDownLatch latch = new CountDownLatch(1);
task.addCompletionListener(new ActionListener<AsyncSearchResponse>() {
@Override
public void onResponse(AsyncSearchResponse asyncSearchResponse) {
assertTrue(response.compareAndSet(null, asyncSearchResponse));
latch.countDown();
}
@Override
public void onFailure(Exception e) {
throw new AssertionError("onFailure should not be called");
}
}, TimeValue.timeValueMillis(10L));
assertTrue(latch.await(1, TimeUnit.SECONDS));
assertNotNull(response.get().getSearchResponse());
assertEquals(0, response.get().getSearchResponse().getTotalShards());
assertEquals(0, response.get().getSearchResponse().getSuccessfulShards());
assertEquals(0, response.get().getSearchResponse().getFailedShards());
assertThat(response.get().getFailure(), instanceOf(ElasticsearchException.class));
assertEquals("Async search: error while reducing partial results", response.get().getFailure().getMessage());
assertThat(response.get().getFailure().getCause(), instanceOf(IllegalArgumentException.class));
assertEquals("Unknown NamedWriteable category [" + InternalAggregation.class.getName() + "]",
response.get().getFailure().getCause().getMessage());
}
public void testWithFailureAndGetResponseFailureDuringReduction() throws InterruptedException { public void testWithFailureAndGetResponseFailureDuringReduction() throws InterruptedException {
AsyncSearchTask task = createAsyncSearchTask(); AsyncSearchTask task = createAsyncSearchTask();
task.getSearchProgressActionListener().onListShards(Collections.emptyList(), Collections.emptyList(), task.getSearchProgressActionListener().onListShards(Collections.emptyList(), Collections.emptyList(),
SearchResponse.Clusters.EMPTY, false); SearchResponse.Clusters.EMPTY, false);
InternalAggregations aggs = InternalAggregations.from(Collections.singletonList(new StringTerms("name", BucketOrder.key(true), InternalAggregations aggs = InternalAggregations.from(Collections.singletonList(new StringTerms("name", BucketOrder.key(true),
BucketOrder.key(true), 1, 1, Collections.emptyMap(), DocValueFormat.RAW, 1, false, 1, Collections.emptyList(), 0))); BucketOrder.key(true), 1, 1, Collections.emptyMap(), DocValueFormat.RAW, 1, false, 1, Collections.emptyList(), 0)));
//providing an empty named writeable registry will make the expansion fail, hence the delayed reduction will fail too
//causing an exception when executing getResponse as part of the completion listener callback
DelayableWriteable.Serialized<InternalAggregations> serializedAggs = DelayableWriteable.referencing(aggs)
.asSerialized(InternalAggregations::readFrom, new NamedWriteableRegistry(Collections.emptyList()));
task.getSearchProgressActionListener().onPartialReduce(Collections.emptyList(), new TotalHits(0, TotalHits.Relation.EQUAL_TO), task.getSearchProgressActionListener().onPartialReduce(Collections.emptyList(), new TotalHits(0, TotalHits.Relation.EQUAL_TO),
serializedAggs, 1); aggs, 1);
task.getSearchProgressActionListener().onFailure(new CircuitBreakingException("boom", CircuitBreaker.Durability.TRANSIENT)); task.getSearchProgressActionListener().onFailure(new CircuitBreakingException("boom", CircuitBreaker.Durability.TRANSIENT));
AtomicReference<AsyncSearchResponse> response = new AtomicReference<>(); AtomicReference<AsyncSearchResponse> response = new AtomicReference<>();
CountDownLatch latch = new CountDownLatch(1); CountDownLatch latch = new CountDownLatch(1);
@ -229,9 +184,6 @@ public class AsyncSearchTaskTests extends ESTestCase {
Exception failure = asyncSearchResponse.getFailure(); Exception failure = asyncSearchResponse.getFailure();
assertThat(failure, instanceOf(ElasticsearchException.class)); assertThat(failure, instanceOf(ElasticsearchException.class));
assertEquals("Async search: error while reducing partial results", failure.getMessage()); assertEquals("Async search: error while reducing partial results", failure.getMessage());
assertThat(failure.getCause(), instanceOf(IllegalArgumentException.class));
assertEquals("Unknown NamedWriteable category [" + InternalAggregation.class.getName() +
"]", failure.getCause().getMessage());
assertEquals(1, failure.getSuppressed().length); assertEquals(1, failure.getSuppressed().length);
assertThat(failure.getSuppressed()[0], instanceOf(ElasticsearchException.class)); assertThat(failure.getSuppressed()[0], instanceOf(ElasticsearchException.class));
assertEquals("error while executing search", failure.getSuppressed()[0].getMessage()); assertEquals("error while executing search", failure.getSuppressed()[0].getMessage());