TopHitsAggregator must propagate calls to `setScorer`. (#27138)

It is required in order to work correctly with bulk scorer implementations
that change the scorer during the collection process. Otherwise sub collectors
might call `Scorer.score()` on the wrong scorer.

Closes #27131
This commit is contained in:
Adrien Grand 2017-10-31 09:59:06 +01:00 committed by GitHub
parent a566942219
commit 3812d3cb43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 1 deletions

View File

@ -20,6 +20,8 @@
package org.elasticsearch.search.aggregations.metrics.tophits; package org.elasticsearch.search.aggregations.metrics.tophits;
import com.carrotsearch.hppc.LongObjectHashMap; import com.carrotsearch.hppc.LongObjectHashMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.LeafCollector;
@ -93,6 +95,9 @@ public class TopHitsAggregator extends MetricsAggregator {
public void setScorer(Scorer scorer) throws IOException { public void setScorer(Scorer scorer) throws IOException {
this.scorer = scorer; this.scorer = scorer;
super.setScorer(scorer); super.setScorer(scorer);
for (ObjectCursor<LeafCollector> cursor : leafCollectors.values()) {
cursor.value.setScorer(scorer);
}
} }
@Override @Override

View File

@ -21,15 +21,22 @@ package org.elasticsearch.search.aggregations.metrics.tophits;
import org.apache.lucene.analysis.core.KeywordAnalyzer; import org.apache.lucene.analysis.core.KeywordAnalyzer;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.queryparser.classic.QueryParser; import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper;
@ -39,6 +46,7 @@ import org.elasticsearch.index.mapper.UidFieldMapper;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortOrder;
@ -148,4 +156,47 @@ public class TopHitsAggregatorTests extends AggregatorTestCase {
} }
return document; return document;
} }
public void testSetScorer() throws Exception {
Directory directory = newDirectory();
IndexWriter w = new IndexWriter(directory, newIndexWriterConfig()
// only merge adjacent segments
.setMergePolicy(newLogMergePolicy()));
// first window (see BooleanScorer) has matches on one clause only
for (int i = 0; i < 2048; ++i) {
Document doc = new Document();
doc.add(new StringField("_id", Uid.encodeId(Integer.toString(i)), Store.YES));
if (i == 1000) { // any doc in 0..2048
doc.add(new StringField("string", "bar", Store.NO));
}
w.addDocument(doc);
}
// second window has matches in two clauses
for (int i = 0; i < 2048; ++i) {
Document doc = new Document();
doc.add(new StringField("_id", Uid.encodeId(Integer.toString(2048 + i)), Store.YES));
if (i == 500) { // any doc in 0..2048
doc.add(new StringField("string", "baz", Store.NO));
} else if (i == 1500) {
doc.add(new StringField("string", "bar", Store.NO));
}
w.addDocument(doc);
}
w.forceMerge(1); // we need all docs to be in the same segment
IndexReader reader = DirectoryReader.open(w);
w.close();
IndexSearcher searcher = new IndexSearcher(reader);
Query query = new BooleanQuery.Builder()
.add(new TermQuery(new Term("string", "bar")), Occur.SHOULD)
.add(new TermQuery(new Term("string", "baz")), Occur.SHOULD)
.build();
AggregationBuilder agg = AggregationBuilders.topHits("top_hits");
TopHits result = searchAndReduce(searcher, query, agg, STRING_FIELD_TYPE);
assertEquals(3, result.getHits().totalHits);
reader.close();
directory.close();
}
} }

View File

@ -91,6 +91,7 @@ import static org.mockito.Mockito.when;
public abstract class AggregatorTestCase extends ESTestCase { public abstract class AggregatorTestCase extends ESTestCase {
private static final String NESTEDFIELD_PREFIX = "nested_"; private static final String NESTEDFIELD_PREFIX = "nested_";
private List<Releasable> releasables = new ArrayList<>(); private List<Releasable> releasables = new ArrayList<>();
private static final String TYPE_NAME = "type";
/** Create a factory for the given aggregation builder. */ /** Create a factory for the given aggregation builder. */
protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder, protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder,
@ -104,6 +105,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
MapperService mapperService = mapperServiceMock(); MapperService mapperService = mapperServiceMock();
when(mapperService.getIndexSettings()).thenReturn(indexSettings); when(mapperService.getIndexSettings()).thenReturn(indexSettings);
when(mapperService.hasNested()).thenReturn(false); when(mapperService.hasNested()).thenReturn(false);
when(mapperService.types()).thenReturn(Collections.singleton(TYPE_NAME));
when(searchContext.mapperService()).thenReturn(mapperService); when(searchContext.mapperService()).thenReturn(mapperService);
IndexFieldDataService ifds = new IndexFieldDataService(indexSettings, IndexFieldDataService ifds = new IndexFieldDataService(indexSettings,
new IndicesFieldDataCache(Settings.EMPTY, new IndexFieldDataCache.Listener() { new IndicesFieldDataCache(Settings.EMPTY, new IndexFieldDataCache.Listener() {
@ -115,7 +117,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
} }
}); });
SearchLookup searchLookup = new SearchLookup(mapperService, ifds::getForField, new String[]{"type"}); SearchLookup searchLookup = new SearchLookup(mapperService, ifds::getForField, new String[]{TYPE_NAME});
when(searchContext.lookup()).thenReturn(searchLookup); when(searchContext.lookup()).thenReturn(searchLookup);
QueryShardContext queryShardContext = queryShardContextMock(mapperService, fieldTypes, circuitBreakerService); QueryShardContext queryShardContext = queryShardContextMock(mapperService, fieldTypes, circuitBreakerService);