Fix global aggregation that requires breadth first and scores (#27942)
* Fix global aggregation that requires breadth first and scores. This change fixes the deferring collector when it is executed in a global context with a sub-collector that requires access to scores (e.g. the top_hits aggregation). The deferring collector replays the best buckets for each document and re-executes the original query if scores are needed. When executed in a global context, the query to replay is a simple match_all query and not the original query. Closes #22321 Closes #27928
This commit is contained in:
parent
480aeb7eb7
commit
cb783bcb57
|
@ -21,6 +21,8 @@ package org.elasticsearch.search.aggregations.bucket;
|
||||||
|
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
|
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.Scorer;
|
import org.apache.lucene.search.Scorer;
|
||||||
import org.apache.lucene.search.Weight;
|
import org.apache.lucene.search.Weight;
|
||||||
import org.apache.lucene.util.packed.PackedInts;
|
import org.apache.lucene.util.packed.PackedInts;
|
||||||
|
@ -59,6 +61,7 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
|
||||||
final List<Entry> entries = new ArrayList<>();
|
final List<Entry> entries = new ArrayList<>();
|
||||||
BucketCollector collector;
|
BucketCollector collector;
|
||||||
final SearchContext searchContext;
|
final SearchContext searchContext;
|
||||||
|
final boolean isGlobal;
|
||||||
LeafReaderContext context;
|
LeafReaderContext context;
|
||||||
PackedLongValues.Builder docDeltas;
|
PackedLongValues.Builder docDeltas;
|
||||||
PackedLongValues.Builder buckets;
|
PackedLongValues.Builder buckets;
|
||||||
|
@ -66,9 +69,14 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
|
||||||
boolean finished = false;
|
boolean finished = false;
|
||||||
LongHash selectedBuckets;
|
LongHash selectedBuckets;
|
||||||
|
|
||||||
/** Sole constructor. */
|
/**
|
||||||
public BestBucketsDeferringCollector(SearchContext context) {
|
* Sole constructor.
|
||||||
|
* @param context The search context
|
||||||
|
* @param isGlobal Whether this collector visits all documents (global context)
|
||||||
|
*/
|
||||||
|
public BestBucketsDeferringCollector(SearchContext context, boolean isGlobal) {
|
||||||
this.searchContext = context;
|
this.searchContext = context;
|
||||||
|
this.isGlobal = isGlobal;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -144,11 +152,11 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
|
||||||
}
|
}
|
||||||
this.selectedBuckets = hash;
|
this.selectedBuckets = hash;
|
||||||
|
|
||||||
boolean needsScores = collector.needsScores();
|
boolean needsScores = needsScores();
|
||||||
Weight weight = null;
|
Weight weight = null;
|
||||||
if (needsScores) {
|
if (needsScores) {
|
||||||
weight = searchContext.searcher()
|
Query query = isGlobal ? new MatchAllDocsQuery() : searchContext.query();
|
||||||
.createNormalizedWeight(searchContext.query(), true);
|
weight = searchContext.searcher().createNormalizedWeight(query, true);
|
||||||
}
|
}
|
||||||
for (Entry entry : entries) {
|
for (Entry entry : entries) {
|
||||||
final LeafBucketCollector leafCollector = collector.getLeafCollector(entry.context);
|
final LeafBucketCollector leafCollector = collector.getLeafCollector(entry.context);
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.elasticsearch.search.aggregations.bucket;
|
||||||
import org.elasticsearch.search.aggregations.Aggregator;
|
import org.elasticsearch.search.aggregations.Aggregator;
|
||||||
import org.elasticsearch.search.aggregations.AggregatorFactories;
|
import org.elasticsearch.search.aggregations.AggregatorFactories;
|
||||||
import org.elasticsearch.search.aggregations.BucketCollector;
|
import org.elasticsearch.search.aggregations.BucketCollector;
|
||||||
|
import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregator;
|
||||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||||
import org.elasticsearch.search.internal.SearchContext;
|
import org.elasticsearch.search.internal.SearchContext;
|
||||||
|
|
||||||
|
@ -61,10 +62,20 @@ public abstract class DeferableBucketAggregator extends BucketsAggregator {
|
||||||
collectableSubAggregators = BucketCollector.wrap(collectors);
|
collectableSubAggregators = BucketCollector.wrap(collectors);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static boolean descendsFromGlobalAggregator(Aggregator parent) {
|
||||||
|
while (parent != null) {
|
||||||
|
if (parent.getClass() == GlobalAggregator.class) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
parent = parent.parent();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
public DeferringBucketCollector getDeferringCollector() {
|
public DeferringBucketCollector getDeferringCollector() {
|
||||||
// Default impl is a collector that selects the best buckets
|
// Default impl is a collector that selects the best buckets
|
||||||
// but an alternative defer policy may be based on best docs.
|
// but an alternative defer policy may be based on best docs.
|
||||||
return new BestBucketsDeferringCollector(context());
|
return new BestBucketsDeferringCollector(context(), descendsFromGlobalAggregator(parent()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -74,7 +85,7 @@ public abstract class DeferableBucketAggregator extends BucketsAggregator {
|
||||||
* recording of all doc/bucketIds from the first pass and then the sub class
|
* recording of all doc/bucketIds from the first pass and then the sub class
|
||||||
* should call {@link #runDeferredCollections(long...)} for the selected set
|
* should call {@link #runDeferredCollections(long...)} for the selected set
|
||||||
* of buckets that survive the pruning.
|
* of buckets that survive the pruning.
|
||||||
*
|
*
|
||||||
* @param aggregator
|
* @param aggregator
|
||||||
* the child aggregator
|
* the child aggregator
|
||||||
* @return true if the aggregator should be deferred until a first pass at
|
* @return true if the aggregator should be deferred until a first pass at
|
||||||
|
|
|
@ -27,6 +27,8 @@ import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.RandomIndexWriter;
|
import org.apache.lucene.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
|
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.TermQuery;
|
import org.apache.lucene.search.TermQuery;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
|
@ -41,6 +43,8 @@ import java.util.Collections;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
|
public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
|
||||||
|
|
||||||
public void testReplay() throws Exception {
|
public void testReplay() throws Exception {
|
||||||
|
@ -59,10 +63,17 @@ public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
|
||||||
IndexSearcher indexSearcher = new IndexSearcher(indexReader);
|
IndexSearcher indexSearcher = new IndexSearcher(indexReader);
|
||||||
|
|
||||||
TermQuery termQuery = new TermQuery(new Term("field", String.valueOf(randomInt(maxNumValues))));
|
TermQuery termQuery = new TermQuery(new Term("field", String.valueOf(randomInt(maxNumValues))));
|
||||||
|
Query rewrittenQuery = indexSearcher.rewrite(termQuery);
|
||||||
TopDocs topDocs = indexSearcher.search(termQuery, numDocs);
|
TopDocs topDocs = indexSearcher.search(termQuery, numDocs);
|
||||||
|
|
||||||
SearchContext searchContext = createSearchContext(indexSearcher, createIndexSettings());
|
SearchContext searchContext = createSearchContext(indexSearcher, createIndexSettings());
|
||||||
BestBucketsDeferringCollector collector = new BestBucketsDeferringCollector(searchContext);
|
when(searchContext.query()).thenReturn(rewrittenQuery);
|
||||||
|
BestBucketsDeferringCollector collector = new BestBucketsDeferringCollector(searchContext, false) {
|
||||||
|
@Override
|
||||||
|
public boolean needsScores() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
Set<Integer> deferredCollectedDocIds = new HashSet<>();
|
Set<Integer> deferredCollectedDocIds = new HashSet<>();
|
||||||
collector.setDeferredCollector(Collections.singleton(bla(deferredCollectedDocIds)));
|
collector.setDeferredCollector(Collections.singleton(bla(deferredCollectedDocIds)));
|
||||||
collector.preCollection();
|
collector.preCollection();
|
||||||
|
@ -70,6 +81,20 @@ public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
|
||||||
collector.postCollection();
|
collector.postCollection();
|
||||||
collector.replay(0);
|
collector.replay(0);
|
||||||
|
|
||||||
|
assertEquals(topDocs.scoreDocs.length, deferredCollectedDocIds.size());
|
||||||
|
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||||
|
assertTrue("expected docid [" + scoreDoc.doc + "] is missing", deferredCollectedDocIds.contains(scoreDoc.doc));
|
||||||
|
}
|
||||||
|
|
||||||
|
topDocs = indexSearcher.search(new MatchAllDocsQuery(), numDocs);
|
||||||
|
collector = new BestBucketsDeferringCollector(searchContext, true);
|
||||||
|
deferredCollectedDocIds = new HashSet<>();
|
||||||
|
collector.setDeferredCollector(Collections.singleton(bla(deferredCollectedDocIds)));
|
||||||
|
collector.preCollection();
|
||||||
|
indexSearcher.search(new MatchAllDocsQuery(), collector);
|
||||||
|
collector.postCollection();
|
||||||
|
collector.replay(0);
|
||||||
|
|
||||||
assertEquals(topDocs.scoreDocs.length, deferredCollectedDocIds.size());
|
assertEquals(topDocs.scoreDocs.length, deferredCollectedDocIds.size());
|
||||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||||
assertTrue("expected docid [" + scoreDoc.doc + "] is missing", deferredCollectedDocIds.contains(scoreDoc.doc));
|
assertTrue("expected docid [" + scoreDoc.doc + "] is missing", deferredCollectedDocIds.contains(scoreDoc.doc));
|
||||||
|
|
|
@ -46,14 +46,21 @@ import org.elasticsearch.index.mapper.MappedFieldType;
|
||||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||||
import org.elasticsearch.index.query.QueryBuilders;
|
import org.elasticsearch.index.query.QueryBuilders;
|
||||||
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
|
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
|
||||||
|
import org.elasticsearch.search.SearchHit;
|
||||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||||
import org.elasticsearch.search.aggregations.Aggregator;
|
import org.elasticsearch.search.aggregations.Aggregator;
|
||||||
import org.elasticsearch.search.aggregations.AggregatorTestCase;
|
import org.elasticsearch.search.aggregations.AggregatorTestCase;
|
||||||
import org.elasticsearch.search.aggregations.BucketOrder;
|
import org.elasticsearch.search.aggregations.BucketOrder;
|
||||||
import org.elasticsearch.search.aggregations.InternalAggregation;
|
import org.elasticsearch.search.aggregations.InternalAggregation;
|
||||||
|
import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
|
||||||
|
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
|
||||||
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
||||||
import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
|
import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
|
||||||
|
import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder;
|
||||||
|
import org.elasticsearch.search.aggregations.bucket.global.InternalGlobal;
|
||||||
|
import org.elasticsearch.search.aggregations.metrics.tophits.InternalTopHits;
|
||||||
|
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.support.ValueType;
|
import org.elasticsearch.search.aggregations.support.ValueType;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -67,6 +74,8 @@ import java.util.Map;
|
||||||
import java.util.function.BiFunction;
|
import java.util.function.BiFunction;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.greaterThan;
|
||||||
import static org.hamcrest.Matchers.instanceOf;
|
import static org.hamcrest.Matchers.instanceOf;
|
||||||
|
|
||||||
public class TermsAggregatorTests extends AggregatorTestCase {
|
public class TermsAggregatorTests extends AggregatorTestCase {
|
||||||
|
@ -933,6 +942,63 @@ public class TermsAggregatorTests extends AggregatorTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testGlobalAggregationWithScore() throws IOException {
|
||||||
|
try (Directory directory = newDirectory()) {
|
||||||
|
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
|
||||||
|
Document document = new Document();
|
||||||
|
document.add(new SortedDocValuesField("keyword", new BytesRef("a")));
|
||||||
|
indexWriter.addDocument(document);
|
||||||
|
document = new Document();
|
||||||
|
document.add(new SortedDocValuesField("keyword", new BytesRef("c")));
|
||||||
|
indexWriter.addDocument(document);
|
||||||
|
document = new Document();
|
||||||
|
document.add(new SortedDocValuesField("keyword", new BytesRef("e")));
|
||||||
|
indexWriter.addDocument(document);
|
||||||
|
try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) {
|
||||||
|
IndexSearcher indexSearcher = newIndexSearcher(indexReader);
|
||||||
|
String executionHint = randomFrom(TermsAggregatorFactory.ExecutionMode.values()).toString();
|
||||||
|
Aggregator.SubAggCollectionMode collectionMode = randomFrom(Aggregator.SubAggCollectionMode.values());
|
||||||
|
GlobalAggregationBuilder globalBuilder = new GlobalAggregationBuilder("global")
|
||||||
|
.subAggregation(
|
||||||
|
new TermsAggregationBuilder("terms", ValueType.STRING)
|
||||||
|
.executionHint(executionHint)
|
||||||
|
.collectMode(collectionMode)
|
||||||
|
.field("keyword")
|
||||||
|
.order(BucketOrder.key(true))
|
||||||
|
.subAggregation(
|
||||||
|
new TermsAggregationBuilder("sub_terms", ValueType.STRING)
|
||||||
|
.executionHint(executionHint)
|
||||||
|
.collectMode(collectionMode)
|
||||||
|
.field("keyword").order(BucketOrder.key(true))
|
||||||
|
.subAggregation(
|
||||||
|
new TopHitsAggregationBuilder("top_hits")
|
||||||
|
.storedField("_none_")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType();
|
||||||
|
fieldType.setName("keyword");
|
||||||
|
fieldType.setHasDocValues(true);
|
||||||
|
|
||||||
|
InternalGlobal result = searchAndReduce(indexSearcher, new MatchAllDocsQuery(), globalBuilder, fieldType);
|
||||||
|
InternalMultiBucketAggregation<?, ?> terms = result.getAggregations().get("terms");
|
||||||
|
assertThat(terms.getBuckets().size(), equalTo(3));
|
||||||
|
for (MultiBucketsAggregation.Bucket bucket : terms.getBuckets()) {
|
||||||
|
InternalMultiBucketAggregation<?, ?> subTerms = bucket.getAggregations().get("sub_terms");
|
||||||
|
assertThat(subTerms.getBuckets().size(), equalTo(1));
|
||||||
|
MultiBucketsAggregation.Bucket subBucket = subTerms.getBuckets().get(0);
|
||||||
|
InternalTopHits topHits = subBucket.getAggregations().get("top_hits");
|
||||||
|
assertThat(topHits.getHits().getHits().length, equalTo(1));
|
||||||
|
for (SearchHit hit : topHits.getHits()) {
|
||||||
|
assertThat(hit.getScore(), greaterThan(0f));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private IndexReader createIndexWithLongs() throws IOException {
|
private IndexReader createIndexWithLongs() throws IOException {
|
||||||
Directory directory = newDirectory();
|
Directory directory = newDirectory();
|
||||||
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
|
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
|
||||||
|
|
Loading…
Reference in New Issue