Fix global aggregation that requires breadth first and scores (#27942)

* Fix global aggregation that requires breadth first and scores

This change fixes the deferring collector when it is executed in a global context
with a sub collector thats requires to access scores (e.g. top_hits aggregation).
The deferring collector replays the best buckets for each document and re-executes the original query
if scores are needed. When executed in a global context, the query to replay is a simple match_all
 query and not the original query.

Closes #22321
Closes #27928
This commit is contained in:
Jim Ferenczi 2018-01-05 11:41:36 +01:00 committed by GitHub
parent 480aeb7eb7
commit cb783bcb57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 118 additions and 8 deletions

View File

@ -21,6 +21,8 @@ package org.elasticsearch.search.aggregations.bucket;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.packed.PackedInts;
@ -59,6 +61,7 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
final List<Entry> entries = new ArrayList<>();
BucketCollector collector;
final SearchContext searchContext;
final boolean isGlobal;
LeafReaderContext context;
PackedLongValues.Builder docDeltas;
PackedLongValues.Builder buckets;
@ -66,9 +69,14 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
boolean finished = false;
LongHash selectedBuckets;
/** Sole constructor. */
public BestBucketsDeferringCollector(SearchContext context) {
/**
* Sole constructor.
* @param context The search context
* @param isGlobal Whether this collector visits all documents (global context)
*/
public BestBucketsDeferringCollector(SearchContext context, boolean isGlobal) {
this.searchContext = context;
this.isGlobal = isGlobal;
}
@Override
@ -144,11 +152,11 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
}
this.selectedBuckets = hash;
boolean needsScores = collector.needsScores();
boolean needsScores = needsScores();
Weight weight = null;
if (needsScores) {
weight = searchContext.searcher()
.createNormalizedWeight(searchContext.query(), true);
Query query = isGlobal ? new MatchAllDocsQuery() : searchContext.query();
weight = searchContext.searcher().createNormalizedWeight(query, true);
}
for (Entry entry : entries) {
final LeafBucketCollector leafCollector = collector.getLeafCollector(entry.context);

View File

@ -22,6 +22,7 @@ package org.elasticsearch.search.aggregations.bucket;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.BucketCollector;
import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.internal.SearchContext;
@ -61,10 +62,20 @@ public abstract class DeferableBucketAggregator extends BucketsAggregator {
collectableSubAggregators = BucketCollector.wrap(collectors);
}
public static boolean descendsFromGlobalAggregator(Aggregator parent) {
while (parent != null) {
if (parent.getClass() == GlobalAggregator.class) {
return true;
}
parent = parent.parent();
}
return false;
}
public DeferringBucketCollector getDeferringCollector() {
// Default impl is a collector that selects the best buckets
// but an alternative defer policy may be based on best docs.
return new BestBucketsDeferringCollector(context());
return new BestBucketsDeferringCollector(context(), descendsFromGlobalAggregator(parent()));
}
/**

View File

@ -27,6 +27,8 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
@ -41,6 +43,8 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import static org.mockito.Mockito.when;
public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
public void testReplay() throws Exception {
@ -59,10 +63,17 @@ public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
IndexSearcher indexSearcher = new IndexSearcher(indexReader);
TermQuery termQuery = new TermQuery(new Term("field", String.valueOf(randomInt(maxNumValues))));
Query rewrittenQuery = indexSearcher.rewrite(termQuery);
TopDocs topDocs = indexSearcher.search(termQuery, numDocs);
SearchContext searchContext = createSearchContext(indexSearcher, createIndexSettings());
BestBucketsDeferringCollector collector = new BestBucketsDeferringCollector(searchContext);
when(searchContext.query()).thenReturn(rewrittenQuery);
BestBucketsDeferringCollector collector = new BestBucketsDeferringCollector(searchContext, false) {
@Override
public boolean needsScores() {
return true;
}
};
Set<Integer> deferredCollectedDocIds = new HashSet<>();
collector.setDeferredCollector(Collections.singleton(bla(deferredCollectedDocIds)));
collector.preCollection();
@ -70,6 +81,20 @@ public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
collector.postCollection();
collector.replay(0);
assertEquals(topDocs.scoreDocs.length, deferredCollectedDocIds.size());
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
assertTrue("expected docid [" + scoreDoc.doc + "] is missing", deferredCollectedDocIds.contains(scoreDoc.doc));
}
topDocs = indexSearcher.search(new MatchAllDocsQuery(), numDocs);
collector = new BestBucketsDeferringCollector(searchContext, true);
deferredCollectedDocIds = new HashSet<>();
collector.setDeferredCollector(Collections.singleton(bla(deferredCollectedDocIds)));
collector.preCollection();
indexSearcher.search(new MatchAllDocsQuery(), collector);
collector.postCollection();
collector.replay(0);
assertEquals(topDocs.scoreDocs.length, deferredCollectedDocIds.size());
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
assertTrue("expected docid [" + scoreDoc.doc + "] is missing", deferredCollectedDocIds.contains(scoreDoc.doc));

View File

@ -46,14 +46,21 @@ import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.global.InternalGlobal;
import org.elasticsearch.search.aggregations.metrics.tophits.InternalTopHits;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValueType;
import java.io.IOException;
@ -67,6 +74,8 @@ import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.instanceOf;
public class TermsAggregatorTests extends AggregatorTestCase {
@ -933,6 +942,63 @@ public class TermsAggregatorTests extends AggregatorTestCase {
}
}
public void testGlobalAggregationWithScore() throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
Document document = new Document();
document.add(new SortedDocValuesField("keyword", new BytesRef("a")));
indexWriter.addDocument(document);
document = new Document();
document.add(new SortedDocValuesField("keyword", new BytesRef("c")));
indexWriter.addDocument(document);
document = new Document();
document.add(new SortedDocValuesField("keyword", new BytesRef("e")));
indexWriter.addDocument(document);
try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) {
IndexSearcher indexSearcher = newIndexSearcher(indexReader);
String executionHint = randomFrom(TermsAggregatorFactory.ExecutionMode.values()).toString();
Aggregator.SubAggCollectionMode collectionMode = randomFrom(Aggregator.SubAggCollectionMode.values());
GlobalAggregationBuilder globalBuilder = new GlobalAggregationBuilder("global")
.subAggregation(
new TermsAggregationBuilder("terms", ValueType.STRING)
.executionHint(executionHint)
.collectMode(collectionMode)
.field("keyword")
.order(BucketOrder.key(true))
.subAggregation(
new TermsAggregationBuilder("sub_terms", ValueType.STRING)
.executionHint(executionHint)
.collectMode(collectionMode)
.field("keyword").order(BucketOrder.key(true))
.subAggregation(
new TopHitsAggregationBuilder("top_hits")
.storedField("_none_")
)
)
);
MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType();
fieldType.setName("keyword");
fieldType.setHasDocValues(true);
InternalGlobal result = searchAndReduce(indexSearcher, new MatchAllDocsQuery(), globalBuilder, fieldType);
InternalMultiBucketAggregation<?, ?> terms = result.getAggregations().get("terms");
assertThat(terms.getBuckets().size(), equalTo(3));
for (MultiBucketsAggregation.Bucket bucket : terms.getBuckets()) {
InternalMultiBucketAggregation<?, ?> subTerms = bucket.getAggregations().get("sub_terms");
assertThat(subTerms.getBuckets().size(), equalTo(1));
MultiBucketsAggregation.Bucket subBucket = subTerms.getBuckets().get(0);
InternalTopHits topHits = subBucket.getAggregations().get("top_hits");
assertThat(topHits.getHits().getHits().length, equalTo(1));
for (SearchHit hit : topHits.getHits()) {
assertThat(hit.getScore(), greaterThan(0f));
}
}
}
}
}
}
private IndexReader createIndexWithLongs() throws IOException {
Directory directory = newDirectory();
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);