Merge pull request #10411 from jpountz/fix/multi_level_breadth_first

Aggregations: Fix multi-level breadth-first aggregations.

Close #10411
This commit is contained in:
Adrien Grand 2015-04-09 12:09:48 +02:00
commit 2a844fc457
9 changed files with 122 additions and 60 deletions

View File

@ -92,6 +92,7 @@ abstract class QueryCollector extends SimpleCollector {
context.aggregations().aggregators(aggregators); context.aggregations().aggregators(aggregators);
} }
aggregatorCollector = BucketCollector.wrap(aggregatorCollectors); aggregatorCollector = BucketCollector.wrap(aggregatorCollectors);
aggregatorCollector.preCollection();
} }
public void postMatch(int doc) throws IOException { public void postMatch(int doc) throws IOException {

View File

@ -76,18 +76,20 @@ public class AggregationPhase implements SearchPhase {
Aggregator[] aggregators; Aggregator[] aggregators;
try { try {
aggregators = context.aggregations().factories().createTopLevelAggregators(aggregationContext); aggregators = context.aggregations().factories().createTopLevelAggregators(aggregationContext);
for (int i = 0; i < aggregators.length; i++) {
if (aggregators[i] instanceof GlobalAggregator == false) {
collectors.add(aggregators[i]);
}
}
context.aggregations().aggregators(aggregators);
if (!collectors.isEmpty()) {
final BucketCollector collector = BucketCollector.wrap(collectors);
collector.preCollection();
context.searcher().queryCollectors().put(AggregationPhase.class, collector);
}
} catch (IOException e) { } catch (IOException e) {
throw new AggregationInitializationException("Could not initialize aggregators", e); throw new AggregationInitializationException("Could not initialize aggregators", e);
} }
for (int i = 0; i < aggregators.length; i++) {
if (aggregators[i] instanceof GlobalAggregator == false) {
collectors.add(aggregators[i]);
}
}
context.aggregations().aggregators(aggregators);
if (!collectors.isEmpty()) {
context.searcher().queryCollectors().put(AggregationPhase.class, (BucketCollector.wrap(collectors)));
}
} }
} }
@ -113,14 +115,15 @@ public class AggregationPhase implements SearchPhase {
// optimize the global collector based execution // optimize the global collector based execution
if (!globals.isEmpty()) { if (!globals.isEmpty()) {
BucketCollector collector = BucketCollector.wrap(globals); BucketCollector globalsCollector = BucketCollector.wrap(globals);
Query query = new ConstantScoreQuery(Queries.MATCH_ALL_FILTER); Query query = new ConstantScoreQuery(Queries.MATCH_ALL_FILTER);
Filter searchFilter = context.searchFilter(context.types()); Filter searchFilter = context.searchFilter(context.types());
if (searchFilter != null) { if (searchFilter != null) {
query = new FilteredQuery(query, searchFilter); query = new FilteredQuery(query, searchFilter);
} }
try { try {
context.searcher().search(query, collector); globalsCollector.preCollection();
context.searcher().search(query, globalsCollector);
} catch (Exception e) { } catch (Exception e) {
throw new QueryPhaseExecutionException(context, "Failed to execute global aggregators", e); throw new QueryPhaseExecutionException(context, "Failed to execute global aggregators", e);
} }

View File

@ -99,7 +99,12 @@ public abstract class AggregatorBase extends Aggregator {
*/ */
@Override @Override
public boolean needsScores() { public boolean needsScores() {
return collectableSubAggregators.needsScores(); for (Aggregator agg : subAggregators) {
if (agg.needsScores()) {
return true;
}
}
return false;
} }
public Map<String, Object> metaData() { public Map<String, Object> metaData() {
@ -145,6 +150,7 @@ public abstract class AggregatorBase extends Aggregator {
} }
collectableSubAggregators = BucketCollector.wrap(collectors); collectableSubAggregators = BucketCollector.wrap(collectors);
doPreCollection(); doPreCollection();
collectableSubAggregators.preCollection();
} }
/** /**

View File

@ -44,15 +44,6 @@ public class AggregatorFactories {
this.factories = factories; this.factories = factories;
} }
private static Aggregator createAndRegisterContextAware(AggregationContext context, AggregatorFactory factory, Aggregator parent, boolean collectsFromSingleBucket) throws IOException {
final Aggregator aggregator = factory.create(context, parent, collectsFromSingleBucket);
// Once the aggregator is fully constructed perform any initialisation -
// can't do everything in constructors if Aggregator base class needs
// to delegate to subclasses as part of construction.
aggregator.preCollection();
return aggregator;
}
/** /**
* Create all aggregators so that they can be consumed with multiple buckets. * Create all aggregators so that they can be consumed with multiple buckets.
*/ */
@ -64,7 +55,7 @@ public class AggregatorFactories {
// propagate the fact that only bucket 0 will be collected with single-bucket // propagate the fact that only bucket 0 will be collected with single-bucket
// aggs // aggs
final boolean collectsFromSingleBucket = false; final boolean collectsFromSingleBucket = false;
aggregators[i] = createAndRegisterContextAware(parent.context(), factories[i], parent, collectsFromSingleBucket); aggregators[i] = factories[i].create(parent.context(), parent, collectsFromSingleBucket);
} }
return aggregators; return aggregators;
} }
@ -75,7 +66,7 @@ public class AggregatorFactories {
for (int i = 0; i < factories.length; i++) { for (int i = 0; i < factories.length; i++) {
// top-level aggs only get called with bucket 0 // top-level aggs only get called with bucket 0
final boolean collectsFromSingleBucket = true; final boolean collectsFromSingleBucket = true;
aggregators[i] = createAndRegisterContextAware(ctx, factories[i], null, collectsFromSingleBucket); aggregators[i] = factories[i].create(ctx, null, collectsFromSingleBucket);
} }
return aggregators; return aggregators;
} }

View File

@ -73,7 +73,7 @@ public final class DeferringBucketCollector extends BucketCollector {
if (collector == null) { if (collector == null) {
throw new ElasticsearchIllegalStateException(); throw new ElasticsearchIllegalStateException();
} }
return collector.needsScores(); return false;
} }
/** Set the deferred collectors. */ /** Set the deferred collectors. */
@ -138,6 +138,9 @@ public final class DeferringBucketCollector extends BucketCollector {
this.selectedBuckets = hash; this.selectedBuckets = hash;
collector.preCollection(); collector.preCollection();
if (collector.needsScores()) {
throw new ElasticsearchIllegalStateException("Cannot defer if scores are needed");
}
for (Entry entry : entries) { for (Entry entry : entries) {
final LeafBucketCollector leafCollector = collector.getLeafCollector(entry.context); final LeafBucketCollector leafCollector = collector.getLeafCollector(entry.context);

View File

@ -43,7 +43,7 @@ public abstract class TermsAggregator extends BucketsAggregator {
private Explicit<Long> shardMinDocCount; private Explicit<Long> shardMinDocCount;
private Explicit<Integer> requiredSize; private Explicit<Integer> requiredSize;
private Explicit<Integer> shardSize; private Explicit<Integer> shardSize;
public BucketCountThresholds(long minDocCount, long shardMinDocCount, int requiredSize, int shardSize) { public BucketCountThresholds(long minDocCount, long shardMinDocCount, int requiredSize, int shardSize) {
this.minDocCount = new Explicit<>(minDocCount, false); this.minDocCount = new Explicit<>(minDocCount, false);
this.shardMinDocCount = new Explicit<>(shardMinDocCount, false); this.shardMinDocCount = new Explicit<>(shardMinDocCount, false);
@ -157,7 +157,9 @@ public abstract class TermsAggregator extends BucketsAggregator {
@Override @Override
protected boolean shouldDefer(Aggregator aggregator) { protected boolean shouldDefer(Aggregator aggregator) {
return (collectMode == SubAggCollectionMode.BREADTH_FIRST) && (!aggsUsedForSorting.contains(aggregator)); return collectMode == SubAggCollectionMode.BREADTH_FIRST
&& aggregator.needsScores() == false
&& !aggsUsedForSorting.contains(aggregator);
} }
} }

View File

@ -42,7 +42,10 @@ import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregatorFactory
import org.elasticsearch.search.aggregations.metrics.sum.Sum; import org.elasticsearch.search.aggregations.metrics.sum.Sum;
import org.elasticsearch.test.ElasticsearchIntegrationTest; import org.elasticsearch.test.ElasticsearchIntegrationTest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator;
import java.util.List; import java.util.List;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
@ -67,7 +70,7 @@ import static org.hamcrest.core.IsNull.notNullValue;
* the growth of dynamic arrays is tested. * the growth of dynamic arrays is tested.
*/ */
@Slow @Slow
public class RandomTests extends ElasticsearchIntegrationTest { public class EquivalenceTests extends ElasticsearchIntegrationTest {
// Make sure that unordered, reversed, disjoint and/or overlapping ranges are supported // Make sure that unordered, reversed, disjoint and/or overlapping ranges are supported
// Duel with filters // Duel with filters
@ -380,4 +383,56 @@ public class RandomTests extends ElasticsearchIntegrationTest {
assertEquals(value >= 6 ? value : 0, sum.getValue(), 0d); assertEquals(value >= 6 ? value : 0, sum.getValue(), 0d);
} }
private void assertEquals(Terms t1, Terms t2) {
List<Terms.Bucket> t1Buckets = t1.getBuckets();
List<Terms.Bucket> t2Buckets = t1.getBuckets();
assertEquals(t1Buckets.size(), t2Buckets.size());
for (Iterator<Terms.Bucket> it1 = t1Buckets.iterator(), it2 = t2Buckets.iterator(); it1.hasNext(); ) {
final Terms.Bucket b1 = it1.next();
final Terms.Bucket b2 = it2.next();
assertEquals(b1.getDocCount(), b2.getDocCount());
assertEquals(b1.getKey(), b2.getKey());
}
}
public void testDuelDepthBreadthFirst() throws Exception {
createIndex("idx");
final int numDocs = randomIntBetween(100, 500);
List<IndexRequestBuilder> reqs = new ArrayList<>();
for (int i = 0; i < numDocs; ++i) {
final int v1 = randomInt(1 << randomInt(7));
final int v2 = randomInt(1 << randomInt(7));
final int v3 = randomInt(1 << randomInt(7));
reqs.add(client().prepareIndex("idx", "type").setSource("f1", v1, "f2", v2, "f3", v3));
}
indexRandom(true, reqs);
final SearchResponse r1 = client().prepareSearch("idx").addAggregation(
terms("f1").field("f1").collectMode(SubAggCollectionMode.DEPTH_FIRST)
.subAggregation(terms("f2").field("f2").collectMode(SubAggCollectionMode.DEPTH_FIRST)
.subAggregation(terms("f3").field("f3").collectMode(SubAggCollectionMode.DEPTH_FIRST)))).get();
assertSearchResponse(r1);
final SearchResponse r2 = client().prepareSearch("idx").addAggregation(
terms("f1").field("f1").collectMode(SubAggCollectionMode.BREADTH_FIRST)
.subAggregation(terms("f2").field("f2").collectMode(SubAggCollectionMode.BREADTH_FIRST)
.subAggregation(terms("f3").field("f3").collectMode(SubAggCollectionMode.BREADTH_FIRST)))).get();
assertSearchResponse(r2);
final Terms t1 = r1.getAggregations().get("f1");
final Terms t2 = r2.getAggregations().get("f1");
assertEquals(t1, t2);
for (Terms.Bucket b1 : t1.getBuckets()) {
final Terms.Bucket b2 = t2.getBucketByKey(b1.getKeyAsString());
final Terms sub1 = b1.getAggregations().get("f2");
final Terms sub2 = b2.getAggregations().get("f2");
assertEquals(sub1, sub2);
for (Terms.Bucket subB1 : sub1.getBuckets()) {
final Terms.Bucket subB2 = sub2.getBucketByKey(subB1.getKeyAsString());
final Terms subSub1 = subB1.getAggregations().get("f3");
final Terms subSub2 = subB2.getAggregations().get("f3");
assertEquals(subSub1, subSub2);
}
}
}
} }

View File

@ -46,7 +46,6 @@ import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.common.xcontent.XContentFactory.smileBuilder; import static org.elasticsearch.common.xcontent.XContentFactory.smileBuilder;
@ -271,6 +270,38 @@ public class TopHitsTests extends ElasticsearchIntegrationTest {
} }
} }
@Test
public void testBreadthFirst() throws Exception {
// breadth_first will be ignored since we need scores
SearchResponse response = client().prepareSearch("idx").setTypes("type")
.addAggregation(terms("terms")
.executionHint(randomExecutionHint())
.collectMode(SubAggCollectionMode.BREADTH_FIRST)
.field(TERMS_AGGS_FIELD)
.subAggregation(topHits("hits").setSize(3))
).get();
assertSearchResponse(response);
Terms terms = response.getAggregations().get("terms");
assertThat(terms, notNullValue());
assertThat(terms.getName(), equalTo("terms"));
assertThat(terms.getBuckets().size(), equalTo(5));
for (int i = 0; i < 5; i++) {
Terms.Bucket bucket = terms.getBucketByKey("val" + i);
assertThat(bucket, notNullValue());
assertThat(key(bucket), equalTo("val" + i));
assertThat(bucket.getDocCount(), equalTo(10l));
TopHits topHits = bucket.getAggregations().get("hits");
SearchHits hits = topHits.getHits();
assertThat(hits.totalHits(), equalTo(10l));
assertThat(hits.getHits().length, equalTo(3));
assertThat(hits.getAt(0).sourceAsMap().size(), equalTo(4));
}
}
@Test @Test
public void testBasics_getProperty() throws Exception { public void testBasics_getProperty() throws Exception {
SearchResponse searchResponse = client().prepareSearch("idx").setQuery(matchAllQuery()) SearchResponse searchResponse = client().prepareSearch("idx").setQuery(matchAllQuery())
@ -531,37 +562,6 @@ public class TopHitsTests extends ElasticsearchIntegrationTest {
assertThat(e.getMessage(), containsString("Aggregator [top_tags_hits] of type [top_hits] cannot accept sub-aggregations")); assertThat(e.getMessage(), containsString("Aggregator [top_tags_hits] of type [top_hits] cannot accept sub-aggregations"));
} }
} }
@Test
public void testFailDeferredOnlyWhenScorerIsUsed() throws Exception {
// No track_scores or score based sort defined in top_hits agg, so don't fail:
SearchResponse response = client().prepareSearch("idx")
.setTypes("type")
.addAggregation(
terms("terms").executionHint(randomExecutionHint()).field(TERMS_AGGS_FIELD)
.collectMode(SubAggCollectionMode.BREADTH_FIRST)
.subAggregation(topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC))))
.get();
assertSearchResponse(response);
// Score based, so fail with deferred aggs:
try {
client().prepareSearch("idx")
.setTypes("type")
.addAggregation(
terms("terms").executionHint(randomExecutionHint()).field(TERMS_AGGS_FIELD)
.collectMode(SubAggCollectionMode.BREADTH_FIRST)
.subAggregation(topHits("hits")))
.get();
fail();
} catch (Exception e) {
// It is considered a parse failure if the search request asks for top_hits
// under an aggregation with collect_mode set to breadth_first as this would
// require the buffering of scores alongside each document ID and that is a
// a RAM cost we are not willing to pay
assertThat(e.getMessage(), containsString("ElasticsearchIllegalStateException"));
}
}
@Test @Test
public void testEmptyIndex() throws Exception { public void testEmptyIndex() throws Exception {

View File

@ -125,6 +125,7 @@ public class NestedAggregatorTest extends ElasticsearchSingleNodeLuceneTestCase
searchContext.aggregations(new SearchContextAggregations(factories)); searchContext.aggregations(new SearchContextAggregations(factories));
Aggregator[] aggs = factories.createTopLevelAggregators(context); Aggregator[] aggs = factories.createTopLevelAggregators(context);
BucketCollector collector = BucketCollector.wrap(Arrays.asList(aggs)); BucketCollector collector = BucketCollector.wrap(Arrays.asList(aggs));
collector.preCollection();
// A regular search always exclude nested docs, so we use NonNestedDocsFilter.INSTANCE here (otherwise MatchAllDocsQuery would be sufficient) // A regular search always exclude nested docs, so we use NonNestedDocsFilter.INSTANCE here (otherwise MatchAllDocsQuery would be sufficient)
// We exclude root doc with uid type#2, this will trigger the bug if we don't reset the root doc when we process a new segment, because // We exclude root doc with uid type#2, this will trigger the bug if we don't reset the root doc when we process a new segment, because
// root doc type#3 and root doc type#1 have the same segment docid // root doc type#3 and root doc type#1 have the same segment docid