Rescorer should be applied in the TopHits aggregation (#20978)
When using a top hits aggregation the rescorer are ignored. This change applies the rescorer to the top hits of each bucket. Fixes #19317
This commit is contained in:
parent
adb30ac091
commit
1b822cc7ef
|
@ -29,6 +29,7 @@ import org.apache.lucene.search.TopDocsCollector;
|
|||
import org.apache.lucene.search.TopFieldCollector;
|
||||
import org.apache.lucene.search.TopFieldDocs;
|
||||
import org.apache.lucene.search.TopScoreDocCollector;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.common.lease.Releasables;
|
||||
import org.elasticsearch.common.lucene.Lucene;
|
||||
import org.elasticsearch.common.util.LongObjectPagedHashMap;
|
||||
|
@ -44,6 +45,7 @@ import org.elasticsearch.search.fetch.FetchSearchResult;
|
|||
import org.elasticsearch.search.internal.InternalSearchHit;
|
||||
import org.elasticsearch.search.internal.InternalSearchHits;
|
||||
import org.elasticsearch.search.internal.SubSearchContext;
|
||||
import org.elasticsearch.search.rescore.RescoreSearchContext;
|
||||
import org.elasticsearch.search.sort.SortAndFormats;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -112,6 +114,11 @@ public class TopHitsAggregator extends MetricsAggregator {
|
|||
if (collectors == null) {
|
||||
SortAndFormats sort = subSearchContext.sort();
|
||||
int topN = subSearchContext.from() + subSearchContext.size();
|
||||
if (sort == null) {
|
||||
for (RescoreSearchContext rescoreContext : context.searchContext().rescore()) {
|
||||
topN = Math.max(rescoreContext.window(), topN);
|
||||
}
|
||||
}
|
||||
// In the QueryPhase we don't need this protection, because it is build into the IndexSearcher,
|
||||
// but here we create collectors ourselves and we need prevent OOM because of crazy an offset and size.
|
||||
topN = Math.min(topN, subSearchContext.searcher().getIndexReader().maxDoc());
|
||||
|
@ -133,9 +140,18 @@ public class TopHitsAggregator extends MetricsAggregator {
|
|||
if (topDocsCollector == null) {
|
||||
topHits = buildEmptyAggregation();
|
||||
} else {
|
||||
final TopDocs topDocs = topDocsCollector.topLevelCollector.topDocs();
|
||||
|
||||
TopDocs topDocs = topDocsCollector.topLevelCollector.topDocs();
|
||||
subSearchContext.queryResult().topDocs(topDocs, subSearchContext.sort() == null ? null : subSearchContext.sort().formats);
|
||||
if (subSearchContext.sort() == null) {
|
||||
for (RescoreSearchContext ctx : context().searchContext().rescore()) {
|
||||
try {
|
||||
topDocs = ctx.rescorer().rescore(topDocs, context.searchContext(), ctx);
|
||||
} catch (IOException e) {
|
||||
throw new ElasticsearchException("Rescore TopHits Failed", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int[] docIdsToLoad = new int[topDocs.scoreDocs.length];
|
||||
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
|
||||
docIdsToLoad[i] = topDocs.scoreDocs[i].doc;
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.elasticsearch.action.search.SearchType;
|
|||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.support.XContentMapValues;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.script.MockScriptEngine;
|
||||
|
@ -49,6 +50,7 @@ import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
|
|||
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
|
||||
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
|
||||
import org.elasticsearch.search.sort.ScriptSortBuilder.ScriptSortType;
|
||||
import org.elasticsearch.search.rescore.RescoreBuilder;
|
||||
import org.elasticsearch.search.sort.SortBuilders;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
|
@ -1043,4 +1045,100 @@ public class TopHitsIT extends ESIntegTestCase {
|
|||
assertThat(client().admin().indices().prepareStats("cache_test_idx").setRequestCache(true).get().getTotal().getRequestCache()
|
||||
.getMissCount(), equalTo(1L));
|
||||
}
|
||||
|
||||
public void testWithRescore() {
|
||||
// Rescore with default sort on relevancy (score)
|
||||
{
|
||||
SearchResponse response = client()
|
||||
.prepareSearch("idx")
|
||||
.addRescorer(
|
||||
RescoreBuilder.queryRescorer(new MatchAllQueryBuilder().boost(3.0f))
|
||||
)
|
||||
.setTypes("type")
|
||||
.addAggregation(terms("terms")
|
||||
.field(TERMS_AGGS_FIELD)
|
||||
.subAggregation(
|
||||
topHits("hits")
|
||||
)
|
||||
)
|
||||
.get();
|
||||
Terms terms = response.getAggregations().get("terms");
|
||||
for (Terms.Bucket bucket : terms.getBuckets()) {
|
||||
TopHits topHits = bucket.getAggregations().get("hits");
|
||||
for (SearchHit hit : topHits.getHits().getHits()) {
|
||||
assertThat(hit.score(), equalTo(4.0f));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
SearchResponse response = client()
|
||||
.prepareSearch("idx")
|
||||
.addRescorer(
|
||||
RescoreBuilder.queryRescorer(new MatchAllQueryBuilder().boost(3.0f))
|
||||
)
|
||||
.setTypes("type")
|
||||
.addAggregation(terms("terms")
|
||||
.field(TERMS_AGGS_FIELD)
|
||||
.subAggregation(
|
||||
topHits("hits").sort(SortBuilders.scoreSort())
|
||||
)
|
||||
)
|
||||
.get();
|
||||
Terms terms = response.getAggregations().get("terms");
|
||||
for (Terms.Bucket bucket : terms.getBuckets()) {
|
||||
TopHits topHits = bucket.getAggregations().get("hits");
|
||||
for (SearchHit hit : topHits.getHits().getHits()) {
|
||||
assertThat(hit.score(), equalTo(4.0f));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rescore should not be applied if the sort order is not relevancy
|
||||
{
|
||||
SearchResponse response = client()
|
||||
.prepareSearch("idx")
|
||||
.addRescorer(
|
||||
RescoreBuilder.queryRescorer(new MatchAllQueryBuilder().boost(3.0f))
|
||||
)
|
||||
.setTypes("type")
|
||||
.addAggregation(terms("terms")
|
||||
.field(TERMS_AGGS_FIELD)
|
||||
.subAggregation(
|
||||
topHits("hits").sort(SortBuilders.fieldSort("_type"))
|
||||
)
|
||||
)
|
||||
.get();
|
||||
Terms terms = response.getAggregations().get("terms");
|
||||
for (Terms.Bucket bucket : terms.getBuckets()) {
|
||||
TopHits topHits = bucket.getAggregations().get("hits");
|
||||
for (SearchHit hit : topHits.getHits().getHits()) {
|
||||
assertThat(hit.score(), equalTo(Float.NaN));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
SearchResponse response = client()
|
||||
.prepareSearch("idx")
|
||||
.addRescorer(
|
||||
RescoreBuilder.queryRescorer(new MatchAllQueryBuilder().boost(3.0f))
|
||||
)
|
||||
.setTypes("type")
|
||||
.addAggregation(terms("terms")
|
||||
.field(TERMS_AGGS_FIELD)
|
||||
.subAggregation(
|
||||
topHits("hits").sort(SortBuilders.scoreSort()).sort(SortBuilders.fieldSort("_type"))
|
||||
)
|
||||
)
|
||||
.get();
|
||||
Terms terms = response.getAggregations().get("terms");
|
||||
for (Terms.Bucket bucket : terms.getBuckets()) {
|
||||
TopHits topHits = bucket.getAggregations().get("hits");
|
||||
for (SearchHit hit : topHits.getHits().getHits()) {
|
||||
assertThat(hit.score(), equalTo(Float.NaN));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue