Fix global aggregation that requires breadth first and scores (#27942)
* Fix global aggregation that requires breadth first and scores. This change fixes the deferring collector when it is executed in a global context with a sub collector that requires access to scores (e.g. a top_hits aggregation). The deferring collector replays the best buckets for each document and re-executes the original query if scores are needed. When executed in a global context, the query to replay is a simple match_all query and not the original query. Closes #22321 Closes #27928
This commit is contained in:
parent
480aeb7eb7
commit
cb783bcb57
|
@ -21,6 +21,8 @@ package org.elasticsearch.search.aggregations.bucket;
|
|||
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.Scorer;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.apache.lucene.util.packed.PackedInts;
|
||||
|
@ -59,6 +61,7 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
|
|||
final List<Entry> entries = new ArrayList<>();
|
||||
BucketCollector collector;
|
||||
final SearchContext searchContext;
|
||||
final boolean isGlobal;
|
||||
LeafReaderContext context;
|
||||
PackedLongValues.Builder docDeltas;
|
||||
PackedLongValues.Builder buckets;
|
||||
|
@ -66,9 +69,14 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
|
|||
boolean finished = false;
|
||||
LongHash selectedBuckets;
|
||||
|
||||
/** Sole constructor. */
|
||||
public BestBucketsDeferringCollector(SearchContext context) {
|
||||
/**
|
||||
* Sole constructor.
|
||||
* @param context The search context
|
||||
* @param isGlobal Whether this collector visits all documents (global context)
|
||||
*/
|
||||
public BestBucketsDeferringCollector(SearchContext context, boolean isGlobal) {
|
||||
this.searchContext = context;
|
||||
this.isGlobal = isGlobal;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -144,11 +152,11 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
|
|||
}
|
||||
this.selectedBuckets = hash;
|
||||
|
||||
boolean needsScores = collector.needsScores();
|
||||
boolean needsScores = needsScores();
|
||||
Weight weight = null;
|
||||
if (needsScores) {
|
||||
weight = searchContext.searcher()
|
||||
.createNormalizedWeight(searchContext.query(), true);
|
||||
Query query = isGlobal ? new MatchAllDocsQuery() : searchContext.query();
|
||||
weight = searchContext.searcher().createNormalizedWeight(query, true);
|
||||
}
|
||||
for (Entry entry : entries) {
|
||||
final LeafBucketCollector leafCollector = collector.getLeafCollector(entry.context);
|
||||
|
|
|
@ -22,6 +22,7 @@ package org.elasticsearch.search.aggregations.bucket;
|
|||
import org.elasticsearch.search.aggregations.Aggregator;
|
||||
import org.elasticsearch.search.aggregations.AggregatorFactories;
|
||||
import org.elasticsearch.search.aggregations.BucketCollector;
|
||||
import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregator;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.search.internal.SearchContext;
|
||||
|
||||
|
@ -61,10 +62,20 @@ public abstract class DeferableBucketAggregator extends BucketsAggregator {
|
|||
collectableSubAggregators = BucketCollector.wrap(collectors);
|
||||
}
|
||||
|
||||
public static boolean descendsFromGlobalAggregator(Aggregator parent) {
|
||||
while (parent != null) {
|
||||
if (parent.getClass() == GlobalAggregator.class) {
|
||||
return true;
|
||||
}
|
||||
parent = parent.parent();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public DeferringBucketCollector getDeferringCollector() {
|
||||
// Default impl is a collector that selects the best buckets
|
||||
// but an alternative defer policy may be based on best docs.
|
||||
return new BestBucketsDeferringCollector(context());
|
||||
return new BestBucketsDeferringCollector(context(), descendsFromGlobalAggregator(parent()));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -27,6 +27,8 @@ import org.apache.lucene.index.LeafReaderContext;
|
|||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
|
@ -41,6 +43,8 @@ import java.util.Collections;
|
|||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
|
||||
|
||||
public void testReplay() throws Exception {
|
||||
|
@ -59,10 +63,17 @@ public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
|
|||
IndexSearcher indexSearcher = new IndexSearcher(indexReader);
|
||||
|
||||
TermQuery termQuery = new TermQuery(new Term("field", String.valueOf(randomInt(maxNumValues))));
|
||||
Query rewrittenQuery = indexSearcher.rewrite(termQuery);
|
||||
TopDocs topDocs = indexSearcher.search(termQuery, numDocs);
|
||||
|
||||
SearchContext searchContext = createSearchContext(indexSearcher, createIndexSettings());
|
||||
BestBucketsDeferringCollector collector = new BestBucketsDeferringCollector(searchContext);
|
||||
when(searchContext.query()).thenReturn(rewrittenQuery);
|
||||
BestBucketsDeferringCollector collector = new BestBucketsDeferringCollector(searchContext, false) {
|
||||
@Override
|
||||
public boolean needsScores() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
Set<Integer> deferredCollectedDocIds = new HashSet<>();
|
||||
collector.setDeferredCollector(Collections.singleton(bla(deferredCollectedDocIds)));
|
||||
collector.preCollection();
|
||||
|
@ -70,6 +81,20 @@ public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
|
|||
collector.postCollection();
|
||||
collector.replay(0);
|
||||
|
||||
assertEquals(topDocs.scoreDocs.length, deferredCollectedDocIds.size());
|
||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||
assertTrue("expected docid [" + scoreDoc.doc + "] is missing", deferredCollectedDocIds.contains(scoreDoc.doc));
|
||||
}
|
||||
|
||||
topDocs = indexSearcher.search(new MatchAllDocsQuery(), numDocs);
|
||||
collector = new BestBucketsDeferringCollector(searchContext, true);
|
||||
deferredCollectedDocIds = new HashSet<>();
|
||||
collector.setDeferredCollector(Collections.singleton(bla(deferredCollectedDocIds)));
|
||||
collector.preCollection();
|
||||
indexSearcher.search(new MatchAllDocsQuery(), collector);
|
||||
collector.postCollection();
|
||||
collector.replay(0);
|
||||
|
||||
assertEquals(topDocs.scoreDocs.length, deferredCollectedDocIds.size());
|
||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||
assertTrue("expected docid [" + scoreDoc.doc + "] is missing", deferredCollectedDocIds.contains(scoreDoc.doc));
|
||||
|
|
|
@ -46,14 +46,21 @@ import org.elasticsearch.index.mapper.MappedFieldType;
|
|||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregator;
|
||||
import org.elasticsearch.search.aggregations.AggregatorTestCase;
|
||||
import org.elasticsearch.search.aggregations.BucketOrder;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregation;
|
||||
import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
|
||||
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.bucket.global.InternalGlobal;
|
||||
import org.elasticsearch.search.aggregations.metrics.tophits.InternalTopHits;
|
||||
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.support.ValueType;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -67,6 +74,8 @@ import java.util.Map;
|
|||
import java.util.function.BiFunction;
|
||||
import java.util.function.Function;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
||||
public class TermsAggregatorTests extends AggregatorTestCase {
|
||||
|
@ -933,6 +942,63 @@ public class TermsAggregatorTests extends AggregatorTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testGlobalAggregationWithScore() throws IOException {
|
||||
try (Directory directory = newDirectory()) {
|
||||
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
|
||||
Document document = new Document();
|
||||
document.add(new SortedDocValuesField("keyword", new BytesRef("a")));
|
||||
indexWriter.addDocument(document);
|
||||
document = new Document();
|
||||
document.add(new SortedDocValuesField("keyword", new BytesRef("c")));
|
||||
indexWriter.addDocument(document);
|
||||
document = new Document();
|
||||
document.add(new SortedDocValuesField("keyword", new BytesRef("e")));
|
||||
indexWriter.addDocument(document);
|
||||
try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) {
|
||||
IndexSearcher indexSearcher = newIndexSearcher(indexReader);
|
||||
String executionHint = randomFrom(TermsAggregatorFactory.ExecutionMode.values()).toString();
|
||||
Aggregator.SubAggCollectionMode collectionMode = randomFrom(Aggregator.SubAggCollectionMode.values());
|
||||
GlobalAggregationBuilder globalBuilder = new GlobalAggregationBuilder("global")
|
||||
.subAggregation(
|
||||
new TermsAggregationBuilder("terms", ValueType.STRING)
|
||||
.executionHint(executionHint)
|
||||
.collectMode(collectionMode)
|
||||
.field("keyword")
|
||||
.order(BucketOrder.key(true))
|
||||
.subAggregation(
|
||||
new TermsAggregationBuilder("sub_terms", ValueType.STRING)
|
||||
.executionHint(executionHint)
|
||||
.collectMode(collectionMode)
|
||||
.field("keyword").order(BucketOrder.key(true))
|
||||
.subAggregation(
|
||||
new TopHitsAggregationBuilder("top_hits")
|
||||
.storedField("_none_")
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType();
|
||||
fieldType.setName("keyword");
|
||||
fieldType.setHasDocValues(true);
|
||||
|
||||
InternalGlobal result = searchAndReduce(indexSearcher, new MatchAllDocsQuery(), globalBuilder, fieldType);
|
||||
InternalMultiBucketAggregation<?, ?> terms = result.getAggregations().get("terms");
|
||||
assertThat(terms.getBuckets().size(), equalTo(3));
|
||||
for (MultiBucketsAggregation.Bucket bucket : terms.getBuckets()) {
|
||||
InternalMultiBucketAggregation<?, ?> subTerms = bucket.getAggregations().get("sub_terms");
|
||||
assertThat(subTerms.getBuckets().size(), equalTo(1));
|
||||
MultiBucketsAggregation.Bucket subBucket = subTerms.getBuckets().get(0);
|
||||
InternalTopHits topHits = subBucket.getAggregations().get("top_hits");
|
||||
assertThat(topHits.getHits().getHits().length, equalTo(1));
|
||||
for (SearchHit hit : topHits.getHits()) {
|
||||
assertThat(hit.getScore(), greaterThan(0f));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private IndexReader createIndexWithLongs() throws IOException {
|
||||
Directory directory = newDirectory();
|
||||
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
|
||||
|
|
Loading…
Reference in New Issue