Rescorer should be applied in the TopHits aggregation (#20978)

When using a top hits aggregation the rescorer are ignored.
This change applies the rescorer to the top hits of each bucket.

Fixes #19317
This commit is contained in:
Jim Ferenczi 2016-10-20 12:50:49 +02:00 committed by GitHub
parent adb30ac091
commit 1b822cc7ef
2 changed files with 117 additions and 3 deletions

View File

@ -29,6 +29,7 @@ import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TopScoreDocCollector;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.util.LongObjectPagedHashMap;
@ -44,6 +45,7 @@ import org.elasticsearch.search.fetch.FetchSearchResult;
import org.elasticsearch.search.internal.InternalSearchHit;
import org.elasticsearch.search.internal.InternalSearchHits;
import org.elasticsearch.search.internal.SubSearchContext;
import org.elasticsearch.search.rescore.RescoreSearchContext;
import org.elasticsearch.search.sort.SortAndFormats;
import java.io.IOException;
@ -112,6 +114,11 @@ public class TopHitsAggregator extends MetricsAggregator {
if (collectors == null) {
SortAndFormats sort = subSearchContext.sort();
int topN = subSearchContext.from() + subSearchContext.size();
if (sort == null) {
for (RescoreSearchContext rescoreContext : context.searchContext().rescore()) {
topN = Math.max(rescoreContext.window(), topN);
}
}
// In the QueryPhase we don't need this protection, because it is build into the IndexSearcher,
// but here we create collectors ourselves and we need prevent OOM because of crazy an offset and size.
topN = Math.min(topN, subSearchContext.searcher().getIndexReader().maxDoc());
@ -133,9 +140,18 @@ public class TopHitsAggregator extends MetricsAggregator {
if (topDocsCollector == null) {
topHits = buildEmptyAggregation();
} else {
final TopDocs topDocs = topDocsCollector.topLevelCollector.topDocs();
TopDocs topDocs = topDocsCollector.topLevelCollector.topDocs();
subSearchContext.queryResult().topDocs(topDocs, subSearchContext.sort() == null ? null : subSearchContext.sort().formats);
if (subSearchContext.sort() == null) {
for (RescoreSearchContext ctx : context().searchContext().rescore()) {
try {
topDocs = ctx.rescorer().rescore(topDocs, context.searchContext(), ctx);
} catch (IOException e) {
throw new ElasticsearchException("Rescore TopHits Failed", e);
}
}
}
int[] docIdsToLoad = new int[topDocs.scoreDocs.length];
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
docIdsToLoad[i] = topDocs.scoreDocs[i].doc;
@ -155,7 +171,7 @@ public class TopHitsAggregator extends MetricsAggregator {
}
}
topHits = new InternalTopHits(name, subSearchContext.from(), subSearchContext.size(), topDocs, fetchResult.hits(), pipelineAggregators(),
metaData());
metaData());
}
return topHits;
}

View File

@ -28,6 +28,7 @@ import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.MockScriptEngine;
@ -49,6 +50,7 @@ import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
import org.elasticsearch.search.sort.ScriptSortBuilder.ScriptSortType;
import org.elasticsearch.search.rescore.RescoreBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESIntegTestCase;
@ -1043,4 +1045,100 @@ public class TopHitsIT extends ESIntegTestCase {
assertThat(client().admin().indices().prepareStats("cache_test_idx").setRequestCache(true).get().getTotal().getRequestCache()
.getMissCount(), equalTo(1L));
}
public void testWithRescore() {
// Rescore with default sort on relevancy (score)
{
SearchResponse response = client()
.prepareSearch("idx")
.addRescorer(
RescoreBuilder.queryRescorer(new MatchAllQueryBuilder().boost(3.0f))
)
.setTypes("type")
.addAggregation(terms("terms")
.field(TERMS_AGGS_FIELD)
.subAggregation(
topHits("hits")
)
)
.get();
Terms terms = response.getAggregations().get("terms");
for (Terms.Bucket bucket : terms.getBuckets()) {
TopHits topHits = bucket.getAggregations().get("hits");
for (SearchHit hit : topHits.getHits().getHits()) {
assertThat(hit.score(), equalTo(4.0f));
}
}
}
{
SearchResponse response = client()
.prepareSearch("idx")
.addRescorer(
RescoreBuilder.queryRescorer(new MatchAllQueryBuilder().boost(3.0f))
)
.setTypes("type")
.addAggregation(terms("terms")
.field(TERMS_AGGS_FIELD)
.subAggregation(
topHits("hits").sort(SortBuilders.scoreSort())
)
)
.get();
Terms terms = response.getAggregations().get("terms");
for (Terms.Bucket bucket : terms.getBuckets()) {
TopHits topHits = bucket.getAggregations().get("hits");
for (SearchHit hit : topHits.getHits().getHits()) {
assertThat(hit.score(), equalTo(4.0f));
}
}
}
// Rescore should not be applied if the sort order is not relevancy
{
SearchResponse response = client()
.prepareSearch("idx")
.addRescorer(
RescoreBuilder.queryRescorer(new MatchAllQueryBuilder().boost(3.0f))
)
.setTypes("type")
.addAggregation(terms("terms")
.field(TERMS_AGGS_FIELD)
.subAggregation(
topHits("hits").sort(SortBuilders.fieldSort("_type"))
)
)
.get();
Terms terms = response.getAggregations().get("terms");
for (Terms.Bucket bucket : terms.getBuckets()) {
TopHits topHits = bucket.getAggregations().get("hits");
for (SearchHit hit : topHits.getHits().getHits()) {
assertThat(hit.score(), equalTo(Float.NaN));
}
}
}
{
SearchResponse response = client()
.prepareSearch("idx")
.addRescorer(
RescoreBuilder.queryRescorer(new MatchAllQueryBuilder().boost(3.0f))
)
.setTypes("type")
.addAggregation(terms("terms")
.field(TERMS_AGGS_FIELD)
.subAggregation(
topHits("hits").sort(SortBuilders.scoreSort()).sort(SortBuilders.fieldSort("_type"))
)
)
.get();
Terms terms = response.getAggregations().get("terms");
for (Terms.Bucket bucket : terms.getBuckets()) {
TopHits topHits = bucket.getAggregations().get("hits");
for (SearchHit hit : topHits.getHits().getHits()) {
assertThat(hit.score(), equalTo(Float.NaN));
}
}
}
}
}