From cb783bcb5707c422602e9d9c44de1efde87c9289 Mon Sep 17 00:00:00 2001
From: Jim Ferenczi
Date: Fri, 5 Jan 2018 11:41:36 +0100
Subject: [PATCH] Fix global aggregation that requires breadth first and scores
 (#27942)

* Fix global aggregation that requires breadth first and scores

This change fixes the deferring collector when it is executed in a global
context with a sub collector that requires access to scores (e.g. the
top_hits aggregation). The deferring collector replays the best buckets for
each document and re-executes the original query if scores are needed. When
executed in a global context, the query to replay is a simple match_all
query and not the original query.

Closes #22321
Closes #27928
---
 .../bucket/BestBucketsDeferringCollector.java   | 18 +++--
 .../bucket/DeferableBucketAggregator.java       | 15 ++++-
 .../BestBucketsDeferringCollectorTests.java     | 27 +++++++-
 .../bucket/terms/TermsAggregatorTests.java      | 66 +++++++++++++++++++
 4 files changed, 118 insertions(+), 8 deletions(-)

diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollector.java b/core/src/main/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollector.java
index 61d85b80ad9..d6be0f57866 100644
--- a/core/src/main/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollector.java
+++ b/core/src/main/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollector.java
@@ -21,6 +21,8 @@ package org.elasticsearch.search.aggregations.bucket;
 
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.packed.PackedInts;
@@ -59,6 +61,7 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
     final List<Entry> entries = new ArrayList<>();
     BucketCollector collector;
     final SearchContext searchContext;
+    final boolean isGlobal;
     LeafReaderContext context;
     PackedLongValues.Builder docDeltas;
     PackedLongValues.Builder buckets;
@@ -66,9 +69,14 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
     boolean finished = false;
     LongHash selectedBuckets;
 
-    /** Sole constructor. */
-    public BestBucketsDeferringCollector(SearchContext context) {
+    /**
+     * Sole constructor.
+     * @param context The search context
+     * @param isGlobal Whether this collector visits all documents (global context)
+     */
+    public BestBucketsDeferringCollector(SearchContext context, boolean isGlobal) {
         this.searchContext = context;
+        this.isGlobal = isGlobal;
     }
 
     @Override
@@ -144,11 +152,11 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
         }
         this.selectedBuckets = hash;
 
-        boolean needsScores = collector.needsScores();
+        boolean needsScores = needsScores();
         Weight weight = null;
         if (needsScores) {
-            weight = searchContext.searcher()
-                .createNormalizedWeight(searchContext.query(), true);
+            Query query = isGlobal ? new MatchAllDocsQuery() : searchContext.query();
+            weight = searchContext.searcher().createNormalizedWeight(query, true);
         }
         for (Entry entry : entries) {
             final LeafBucketCollector leafCollector = collector.getLeafCollector(entry.context);
diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/bucket/DeferableBucketAggregator.java b/core/src/main/java/org/elasticsearch/search/aggregations/bucket/DeferableBucketAggregator.java
index bbfcef0af40..0ff5ea12b97 100644
--- a/core/src/main/java/org/elasticsearch/search/aggregations/bucket/DeferableBucketAggregator.java
+++ b/core/src/main/java/org/elasticsearch/search/aggregations/bucket/DeferableBucketAggregator.java
@@ -22,6 +22,7 @@ package org.elasticsearch.search.aggregations.bucket;
 import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.BucketCollector;
+import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregator;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
 import org.elasticsearch.search.internal.SearchContext;
 
@@ -61,10 +62,20 @@ public abstract class DeferableBucketAggregator extends BucketsAggregator {
         collectableSubAggregators = BucketCollector.wrap(collectors);
     }
 
+    public static boolean descendsFromGlobalAggregator(Aggregator parent) {
+        while (parent != null) {
+            if (parent.getClass() == GlobalAggregator.class) {
+                return true;
+            }
+            parent = parent.parent();
+        }
+        return false;
+    }
+
     public DeferringBucketCollector getDeferringCollector() {
         // Default impl is a collector that selects the best buckets
         // but an alternative defer policy may be based on best docs.
-        return new BestBucketsDeferringCollector(context());
+        return new BestBucketsDeferringCollector(context(), descendsFromGlobalAggregator(parent()));
     }
 
     /**
@@ -74,7 +85,7 @@ public abstract class DeferableBucketAggregator extends BucketsAggregator {
      * recording of all doc/bucketIds from the first pass and then the sub class
     * should call {@link #runDeferredCollections(long...)} for the selected set
     * of buckets that survive the pruning.
-     * 
+     *
      * @param aggregator
      *            the child aggregator
      * @return true if the aggregator should be deferred until a first pass at
diff --git a/core/src/test/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollectorTests.java b/core/src/test/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollectorTests.java
index 975c1a0d466..8d60dde5834 100644
--- a/core/src/test/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollectorTests.java
+++ b/core/src/test/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollectorTests.java
@@ -27,6 +27,8 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.search.TopDocs;
@@ -41,6 +43,8 @@ import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;
 
+import static org.mockito.Mockito.when;
+
 public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
 
     public void testReplay() throws Exception {
@@ -59,10 +63,17 @@ public class BestBucketsDeferringCollectorTests extends AggregatorTestCase {
         IndexSearcher indexSearcher = new IndexSearcher(indexReader);
         TermQuery termQuery = new TermQuery(new Term("field", String.valueOf(randomInt(maxNumValues))));
+        Query rewrittenQuery = indexSearcher.rewrite(termQuery);
         TopDocs topDocs = indexSearcher.search(termQuery, numDocs);
 
         SearchContext searchContext = createSearchContext(indexSearcher, createIndexSettings());
-        BestBucketsDeferringCollector collector = new BestBucketsDeferringCollector(searchContext);
+        when(searchContext.query()).thenReturn(rewrittenQuery);
+        BestBucketsDeferringCollector collector = new BestBucketsDeferringCollector(searchContext, false) {
+            @Override
+            public boolean needsScores() {
+                return true;
+            }
+        };
         Set<Integer> deferredCollectedDocIds = new HashSet<>();
         collector.setDeferredCollector(Collections.singleton(bla(deferredCollectedDocIds)));
         collector.preCollection();
@@ -70,6 +81,20 @@
         indexSearcher.search(termQuery, collector);
         collector.postCollection();
         collector.replay(0);
+        assertEquals(topDocs.scoreDocs.length, deferredCollectedDocIds.size());
+        for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+            assertTrue("expected docid [" + scoreDoc.doc + "] is missing", deferredCollectedDocIds.contains(scoreDoc.doc));
+        }
+
+        topDocs = indexSearcher.search(new MatchAllDocsQuery(), numDocs);
+        collector = new BestBucketsDeferringCollector(searchContext, true);
+        deferredCollectedDocIds = new HashSet<>();
+        collector.setDeferredCollector(Collections.singleton(bla(deferredCollectedDocIds)));
+        collector.preCollection();
+        indexSearcher.search(new MatchAllDocsQuery(), collector);
+        collector.postCollection();
+        collector.replay(0);
+
         assertEquals(topDocs.scoreDocs.length, deferredCollectedDocIds.size());
         for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
             assertTrue("expected docid [" + scoreDoc.doc + "] is missing", deferredCollectedDocIds.contains(scoreDoc.doc));
diff --git a/core/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorTests.java b/core/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorTests.java
index 74cc35da300..cbd58bd4acd 100644
--- a/core/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorTests.java
+++ b/core/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorTests.java
@@ -46,14 +46,21 @@ import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.NumberFieldMapper;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
+import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
 import org.elasticsearch.search.aggregations.BucketOrder;
 import org.elasticsearch.search.aggregations.InternalAggregation;
+import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
+import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
 import org.elasticsearch.search.aggregations.bucket.filter.Filter;
 import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.global.InternalGlobal;
+import org.elasticsearch.search.aggregations.metrics.tophits.InternalTopHits;
+import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
 import org.elasticsearch.search.aggregations.support.ValueType;
 
 import java.io.IOException;
@@ -67,6 +74,8 @@ import java.util.Map;
 import java.util.function.BiFunction;
 import java.util.function.Function;
 
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.instanceOf;
 
 public class TermsAggregatorTests extends AggregatorTestCase {
@@ -933,6 +942,63 @@ public class TermsAggregatorTests extends AggregatorTestCase {
         }
     }
 
+    public void testGlobalAggregationWithScore() throws IOException {
+        try (Directory directory = newDirectory()) {
+            try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
+                Document document = new Document();
+                document.add(new SortedDocValuesField("keyword", new BytesRef("a")));
+                indexWriter.addDocument(document);
+                document = new Document();
+                document.add(new SortedDocValuesField("keyword", new BytesRef("c")));
+                indexWriter.addDocument(document);
+                document = new Document();
+                document.add(new SortedDocValuesField("keyword", new BytesRef("e")));
+                indexWriter.addDocument(document);
+                try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) {
+                    IndexSearcher indexSearcher = newIndexSearcher(indexReader);
+                    String executionHint = randomFrom(TermsAggregatorFactory.ExecutionMode.values()).toString();
+                    Aggregator.SubAggCollectionMode collectionMode = randomFrom(Aggregator.SubAggCollectionMode.values());
+                    GlobalAggregationBuilder globalBuilder = new GlobalAggregationBuilder("global")
+                        .subAggregation(
+                            new TermsAggregationBuilder("terms", ValueType.STRING)
+                                .executionHint(executionHint)
+                                .collectMode(collectionMode)
+                                .field("keyword")
+                                .order(BucketOrder.key(true))
+                                .subAggregation(
+                                    new TermsAggregationBuilder("sub_terms", ValueType.STRING)
+                                        .executionHint(executionHint)
+                                        .collectMode(collectionMode)
+                                        .field("keyword").order(BucketOrder.key(true))
+                                        .subAggregation(
+                                            new TopHitsAggregationBuilder("top_hits")
+                                                .storedField("_none_")
+                                        )
+                                )
+                        );
+
+                    MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType();
+                    fieldType.setName("keyword");
+                    fieldType.setHasDocValues(true);
+
+                    InternalGlobal result = searchAndReduce(indexSearcher, new MatchAllDocsQuery(), globalBuilder, fieldType);
+                    InternalMultiBucketAggregation<?, ?> terms = result.getAggregations().get("terms");
+                    assertThat(terms.getBuckets().size(), equalTo(3));
+                    for (MultiBucketsAggregation.Bucket bucket : terms.getBuckets()) {
+                        InternalMultiBucketAggregation<?, ?> subTerms = bucket.getAggregations().get("sub_terms");
+                        assertThat(subTerms.getBuckets().size(), equalTo(1));
+                        MultiBucketsAggregation.Bucket subBucket = subTerms.getBuckets().get(0);
+                        InternalTopHits topHits = subBucket.getAggregations().get("top_hits");
+                        assertThat(topHits.getHits().getHits().length, equalTo(1));
+                        for (SearchHit hit : topHits.getHits()) {
+                            assertThat(hit.getScore(), greaterThan(0f));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
     private IndexReader createIndexWithLongs() throws IOException {
         Directory directory = newDirectory();
         RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
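
The request shape described in the commit message above, a global aggregation that forces a match-all context around a breadth_first terms aggregation whose top_hits sub-aggregation needs scores, can be put together with the regular Java aggregation builders. The sketch below is illustrative only and is not part of the patch: the field names ("tag", "keyword"), the query value, and the aggregation names are hypothetical; only the builder classes are the ones exercised by the tests in this change.

import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregator.SubAggCollectionMode;
import org.elasticsearch.search.builder.SearchSourceBuilder;

public class GlobalTopHitsRequestSketch {

    // Builds the aggregation tree whose deferred buckets must be replayed with a
    // match_all query (the global context) rather than the original query.
    static SearchSourceBuilder source() {
        return new SearchSourceBuilder()
            // Hypothetical original query; before this fix it was also re-executed
            // to fetch scores when replaying buckets under the global aggregation.
            .query(QueryBuilders.termQuery("tag", "elasticsearch"))
            .aggregation(
                AggregationBuilders.global("all_docs")
                    .subAggregation(
                        AggregationBuilders.terms("keywords")
                            .field("keyword")
                            // breadth_first defers sub-aggregations to a second (replay) pass
                            .collectMode(SubAggCollectionMode.BREADTH_FIRST)
                            // top_hits needs scores during that replay pass
                            .subAggregation(AggregationBuilders.topHits("top").size(1))));
    }
}

With this change, the replay pass uses MatchAllDocsQuery whenever the deferring collector descends from a GlobalAggregator (see descendsFromGlobalAggregator above), so the scores seen by top_hits match the global, all-documents context instead of the original query.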