Max score should be updated when a rescorer is used (#20977)

The max score returned in the response of a query does not take rescorer into account.
This change updates the max_score when a rescorer is used in a query.
Fixes #20651
This commit is contained in:
Jim Ferenczi 2016-10-20 12:38:28 +02:00 committed by GitHub
parent d0bbe89c16
commit adb30ac091
2 changed files with 15 additions and 4 deletions

View File

@ -24,10 +24,6 @@ import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.internal.ContextIndexSearcher; import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
@ -159,6 +155,8 @@ public final class QueryRescorer implements Rescorer {
// incoming first pass hits, instead of allowing recoring of just the top subset: // incoming first pass hits, instead of allowing recoring of just the top subset:
Arrays.sort(in.scoreDocs, SCORE_DOC_COMPARATOR); Arrays.sort(in.scoreDocs, SCORE_DOC_COMPARATOR);
} }
// update the max score after the resort
in.setMaxScore(in.scoreDocs[0].score);
return in; return in;
} }

View File

@ -66,6 +66,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertThir
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasId; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasId;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasScore; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasScore;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.notNullValue;
@ -97,6 +98,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
numDocsWith100AsAScore += 1; numDocsWith100AsAScore += 1;
} }
} }
assertThat(searchResponse.getHits().maxScore(), equalTo(searchResponse.getHits().getHits()[0].score()));
// we cannot assert that they are equal since some shards might not have docs at all // we cannot assert that they are equal since some shards might not have docs at all
assertThat(numDocsWith100AsAScore, lessThanOrEqualTo(numShards)); assertThat(numDocsWith100AsAScore, lessThanOrEqualTo(numShards));
} }
@ -122,6 +124,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.setRescoreQueryWeight(2), 5).execute().actionGet(); .setRescoreQueryWeight(2), 5).execute().actionGet();
assertThat(searchResponse.getHits().totalHits(), equalTo(3L)); assertThat(searchResponse.getHits().totalHits(), equalTo(3L));
assertThat(searchResponse.getHits().maxScore(), equalTo(searchResponse.getHits().getHits()[0].score()));
assertThat(searchResponse.getHits().getHits()[0].getId(), equalTo("1")); assertThat(searchResponse.getHits().getHits()[0].getId(), equalTo("1"));
assertThat(searchResponse.getHits().getHits()[1].getId(), equalTo("3")); assertThat(searchResponse.getHits().getHits()[1].getId(), equalTo("3"));
assertThat(searchResponse.getHits().getHits()[2].getId(), equalTo("2")); assertThat(searchResponse.getHits().getHits()[2].getId(), equalTo("2"));
@ -142,6 +145,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.actionGet(); .actionGet();
assertHitCount(searchResponse, 3); assertHitCount(searchResponse, 3);
assertThat(searchResponse.getHits().maxScore(), equalTo(searchResponse.getHits().getHits()[0].score()));
assertFirstHit(searchResponse, hasId("1")); assertFirstHit(searchResponse, hasId("1"));
assertSecondHit(searchResponse, hasId("2")); assertSecondHit(searchResponse, hasId("2"));
assertThirdHit(searchResponse, hasId("3")); assertThirdHit(searchResponse, hasId("3"));
@ -203,6 +207,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
assertThat(searchResponse.getHits().hits().length, equalTo(5)); assertThat(searchResponse.getHits().hits().length, equalTo(5));
assertHitCount(searchResponse, 9); assertHitCount(searchResponse, 9);
assertThat(searchResponse.getHits().maxScore(), equalTo(searchResponse.getHits().getHits()[0].score()));
assertFirstHit(searchResponse, hasId("2")); assertFirstHit(searchResponse, hasId("2"));
assertSecondHit(searchResponse, hasId("6")); assertSecondHit(searchResponse, hasId("6"));
assertThirdHit(searchResponse, hasId("3")); assertThirdHit(searchResponse, hasId("3"));
@ -219,6 +224,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
assertThat(searchResponse.getHits().hits().length, equalTo(5)); assertThat(searchResponse.getHits().hits().length, equalTo(5));
assertHitCount(searchResponse, 9); assertHitCount(searchResponse, 9);
assertThat(searchResponse.getHits().maxScore(), greaterThan(searchResponse.getHits().getHits()[0].score()));
assertFirstHit(searchResponse, hasId("3")); assertFirstHit(searchResponse, hasId("3"));
} }
@ -252,6 +258,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.setSize(5).execute().actionGet(); .setSize(5).execute().actionGet();
assertThat(searchResponse.getHits().hits().length, equalTo(4)); assertThat(searchResponse.getHits().hits().length, equalTo(4));
assertHitCount(searchResponse, 4); assertHitCount(searchResponse, 4);
assertThat(searchResponse.getHits().maxScore(), equalTo(searchResponse.getHits().getHits()[0].score()));
assertFirstHit(searchResponse, hasId("3")); assertFirstHit(searchResponse, hasId("3"));
assertSecondHit(searchResponse, hasId("6")); assertSecondHit(searchResponse, hasId("6"));
assertThirdHit(searchResponse, hasId("1")); assertThirdHit(searchResponse, hasId("1"));
@ -268,6 +275,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
// Only top 2 hits were re-ordered: // Only top 2 hits were re-ordered:
assertThat(searchResponse.getHits().hits().length, equalTo(4)); assertThat(searchResponse.getHits().hits().length, equalTo(4));
assertHitCount(searchResponse, 4); assertHitCount(searchResponse, 4);
assertThat(searchResponse.getHits().maxScore(), equalTo(searchResponse.getHits().getHits()[0].score()));
assertFirstHit(searchResponse, hasId("6")); assertFirstHit(searchResponse, hasId("6"));
assertSecondHit(searchResponse, hasId("3")); assertSecondHit(searchResponse, hasId("3"));
assertThirdHit(searchResponse, hasId("1")); assertThirdHit(searchResponse, hasId("1"));
@ -285,6 +293,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
// Only top 3 hits were re-ordered: // Only top 3 hits were re-ordered:
assertThat(searchResponse.getHits().hits().length, equalTo(4)); assertThat(searchResponse.getHits().hits().length, equalTo(4));
assertHitCount(searchResponse, 4); assertHitCount(searchResponse, 4);
assertThat(searchResponse.getHits().maxScore(), equalTo(searchResponse.getHits().getHits()[0].score()));
assertFirstHit(searchResponse, hasId("6")); assertFirstHit(searchResponse, hasId("6"));
assertSecondHit(searchResponse, hasId("1")); assertSecondHit(searchResponse, hasId("1"));
assertThirdHit(searchResponse, hasId("3")); assertThirdHit(searchResponse, hasId("3"));
@ -321,6 +330,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.setSize(5).execute().actionGet(); .setSize(5).execute().actionGet();
assertThat(searchResponse.getHits().hits().length, equalTo(4)); assertThat(searchResponse.getHits().hits().length, equalTo(4));
assertHitCount(searchResponse, 4); assertHitCount(searchResponse, 4);
assertThat(searchResponse.getHits().maxScore(), equalTo(searchResponse.getHits().getHits()[0].score()));
assertFirstHit(searchResponse, hasId("3")); assertFirstHit(searchResponse, hasId("3"));
assertSecondHit(searchResponse, hasId("6")); assertSecondHit(searchResponse, hasId("6"));
assertThirdHit(searchResponse, hasId("1")); assertThirdHit(searchResponse, hasId("1"));
@ -336,6 +346,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.setQueryWeight(1.0f).setRescoreQueryWeight(-1f), 3).execute().actionGet(); .setQueryWeight(1.0f).setRescoreQueryWeight(-1f), 3).execute().actionGet();
// 6 and 1 got worse, and then the hit (2) outside the rescore window were sorted ahead: // 6 and 1 got worse, and then the hit (2) outside the rescore window were sorted ahead:
assertThat(searchResponse.getHits().maxScore(), equalTo(searchResponse.getHits().getHits()[0].score()));
assertFirstHit(searchResponse, hasId("3")); assertFirstHit(searchResponse, hasId("3"));
assertSecondHit(searchResponse, hasId("2")); assertSecondHit(searchResponse, hasId("2"));
assertThirdHit(searchResponse, hasId("6")); assertThirdHit(searchResponse, hasId("6"));
@ -595,6 +606,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
assertHitCount(rescored, 4); assertHitCount(rescored, 4);
assertThat(rescored.getHits().maxScore(), equalTo(rescored.getHits().getHits()[0].score()));
if ("total".equals(scoreMode) || "".equals(scoreMode)) { if ("total".equals(scoreMode) || "".equals(scoreMode)) {
assertFirstHit(rescored, hasId(String.valueOf(i + 1))); assertFirstHit(rescored, hasId(String.valueOf(i + 1)));
assertSecondHit(rescored, hasId(String.valueOf(i))); assertSecondHit(rescored, hasId(String.valueOf(i)));
@ -672,6 +684,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total); .boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total);
request.clearRescorers().addRescorer(ninetyIsGood, numDocs).addRescorer(oneToo, 10); request.clearRescorers().addRescorer(ninetyIsGood, numDocs).addRescorer(oneToo, 10);
response = request.setSize(2).get(); response = request.setSize(2).get();
assertThat(response.getHits().maxScore(), equalTo(response.getHits().getHits()[0].score()));
assertFirstHit(response, hasId("91")); assertFirstHit(response, hasId("91"));
assertFirstHit(response, hasScore(2001.0f)); assertFirstHit(response, hasScore(2001.0f));
assertSecondHit(response, hasScore(1001.0f)); // Not sure which one it is but it is ninety something assertSecondHit(response, hasScore(1001.0f)); // Not sure which one it is but it is ninety something