From 3812d3cb4334e8b64983e671d6131099b868e21e Mon Sep 17 00:00:00 2001 From: Adrien Grand Date: Tue, 31 Oct 2017 09:59:06 +0100 Subject: [PATCH] TopHitsAggregator must propagate calls to `setScorer`. (#27138) It is required in order to work correctly with bulk scorer implementations that change the scorer during the collection process. Otherwise sub collectors might call `Scorer.score()` on the wrong scorer. Closes #27131 --- .../metrics/tophits/TopHitsAggregator.java | 5 ++ .../tophits/TopHitsAggregatorTests.java | 51 +++++++++++++++++++ .../aggregations/AggregatorTestCase.java | 4 +- 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java index 0f42118683a..700acdf797a 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java @@ -20,6 +20,8 @@ package org.elasticsearch.search.aggregations.metrics.tophits; import com.carrotsearch.hppc.LongObjectHashMap; +import com.carrotsearch.hppc.cursors.ObjectCursor; + import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.LeafCollector; @@ -93,6 +95,9 @@ public class TopHitsAggregator extends MetricsAggregator { public void setScorer(Scorer scorer) throws IOException { this.scorer = scorer; super.setScorer(scorer); + for (ObjectCursor cursor : leafCollectors.values()) { + cursor.value.setScorer(scorer); + } } @Override diff --git a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregatorTests.java b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregatorTests.java index 43686a1465e..5555e987ec4 100644 --- a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregatorTests.java +++ b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregatorTests.java @@ -21,15 +21,22 @@ package org.elasticsearch.search.aggregations.metrics.tophits; import org.apache.lucene.analysis.core.KeywordAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; +import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.SortedSetDocValuesField; +import org.apache.lucene.document.StringField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; import org.apache.lucene.queryparser.classic.QueryParser; +import org.apache.lucene.search.BooleanClause.Occur; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.mapper.KeywordFieldMapper; @@ -39,6 +46,7 @@ import org.elasticsearch.index.mapper.UidFieldMapper; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.sort.SortOrder; @@ -148,4 +156,47 @@ public class TopHitsAggregatorTests extends AggregatorTestCase { } return document; } + + public void testSetScorer() throws Exception { + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig() + // only merge adjacent segments + .setMergePolicy(newLogMergePolicy())); + // first window (see BooleanScorer) has matches on one clause only + for (int i = 0; i < 2048; ++i) { + Document doc = new Document(); + doc.add(new StringField("_id", Uid.encodeId(Integer.toString(i)), Store.YES)); + if (i == 1000) { // any doc in 0..2048 + doc.add(new StringField("string", "bar", Store.NO)); + } + w.addDocument(doc); + } + // second window has matches in two clauses + for (int i = 0; i < 2048; ++i) { + Document doc = new Document(); + doc.add(new StringField("_id", Uid.encodeId(Integer.toString(2048 + i)), Store.YES)); + if (i == 500) { // any doc in 0..2048 + doc.add(new StringField("string", "baz", Store.NO)); + } else if (i == 1500) { + doc.add(new StringField("string", "bar", Store.NO)); + } + w.addDocument(doc); + } + + w.forceMerge(1); // we need all docs to be in the same segment + + IndexReader reader = DirectoryReader.open(w); + w.close(); + + IndexSearcher searcher = new IndexSearcher(reader); + Query query = new BooleanQuery.Builder() + .add(new TermQuery(new Term("string", "bar")), Occur.SHOULD) + .add(new TermQuery(new Term("string", "baz")), Occur.SHOULD) + .build(); + AggregationBuilder agg = AggregationBuilders.topHits("top_hits"); + TopHits result = searchAndReduce(searcher, query, agg, STRING_FIELD_TYPE); + assertEquals(3, result.getHits().totalHits); + reader.close(); + directory.close(); + } } diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index bc9de3e06aa..d3e83f03d3a 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -91,6 +91,7 @@ import static org.mockito.Mockito.when; public abstract class AggregatorTestCase extends ESTestCase { private static final String NESTEDFIELD_PREFIX = "nested_"; private List releasables = new ArrayList<>(); + private static final String TYPE_NAME = "type"; /** Create a factory for the given aggregation builder. */ protected AggregatorFactory createAggregatorFactory(AggregationBuilder aggregationBuilder, @@ -104,6 +105,7 @@ public abstract class AggregatorTestCase extends ESTestCase { MapperService mapperService = mapperServiceMock(); when(mapperService.getIndexSettings()).thenReturn(indexSettings); when(mapperService.hasNested()).thenReturn(false); + when(mapperService.types()).thenReturn(Collections.singleton(TYPE_NAME)); when(searchContext.mapperService()).thenReturn(mapperService); IndexFieldDataService ifds = new IndexFieldDataService(indexSettings, new IndicesFieldDataCache(Settings.EMPTY, new IndexFieldDataCache.Listener() { @@ -115,7 +117,7 @@ public abstract class AggregatorTestCase extends ESTestCase { } }); - SearchLookup searchLookup = new SearchLookup(mapperService, ifds::getForField, new String[]{"type"}); + SearchLookup searchLookup = new SearchLookup(mapperService, ifds::getForField, new String[]{TYPE_NAME}); when(searchContext.lookup()).thenReturn(searchLookup); QueryShardContext queryShardContext = queryShardContextMock(mapperService, fieldTypes, circuitBreakerService);