Add a "score_mode" parameter to rescoring query.

Default value is "total", possible values are: "max", "min", "avg",
"multiply" and "total".

- "total": the final score of a document is the sum of the original
  query score and the rescore query score.
- "max": only the higher of the two scores is kept.
- "min": only the lower of the two scores is kept (if the document doesn't
  match the rescore query, the original query score is used).
- "avg": the average of both scores.
- "multiply": the product of both scores.

Closes #3258
This commit is contained in:
Cédric HOURCADE 2013-07-20 17:00:04 +02:00 committed by Simon Willnauer
parent 29d337c44b
commit 5022c4659c
3 changed files with 266 additions and 34 deletions

View File

@ -32,11 +32,68 @@ import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.Arrays;
import java.util.Set;
import java.lang.Math;
final class QueryRescorer implements Rescorer {
/**
 * Strategies for merging the (already weighted) score of the primary query
 * with the (already weighted) score of the rescore query.
 *
 * <p>{@link #toString()} returns the short label used when building the
 * explanation description (<code>label + " of:"</code>). Note that the label
 * is not always the request value: <code>Total</code> is labelled "sum" and
 * <code>Multiply</code> is labelled "product".
 *
 * <p>The type is package-private rather than private because it appears in
 * the signatures of public accessors on the rescore context; a nested enum
 * is implicitly static, so that modifier is omitted.
 */
enum ScoreMode {
    /** Arithmetic mean of both scores. */
    Avg("avg") {
        @Override
        public float combine(float primary, float secondary) {
            return (primary + secondary) / 2;
        }
    },
    /** The larger of both scores. */
    Max("max") {
        @Override
        public float combine(float primary, float secondary) {
            return Math.max(primary, secondary);
        }
    },
    /** The smaller of both scores. */
    Min("min") {
        @Override
        public float combine(float primary, float secondary) {
            return Math.min(primary, secondary);
        }
    },
    /** Sum of both scores. */
    Total("sum") {
        @Override
        public float combine(float primary, float secondary) {
            return primary + secondary;
        }
    },
    /** Product of both scores. */
    Multiply("product") {
        @Override
        public float combine(float primary, float secondary) {
            return primary * secondary;
        }
    };

    /** Short label returned by {@link #toString()}. */
    private final String label;

    ScoreMode(String label) {
        this.label = label;
    }

    /**
     * Merges two scores into one.
     *
     * @param primary   score contributed by the primary query
     * @param secondary score contributed by the rescore query
     * @return the combined score according to this mode
     */
    public abstract float combine(float primary, float secondary);

    /**
     * @return the short label of this mode ("avg", "max", "min", "sum" or "product")
     */
    @Override
    public String toString() {
        return label;
    }
}
public static final Rescorer INSTANCE = new QueryRescorer();
public static final String NAME = "query";
@Override
public String name() {
return NAME;
@ -72,20 +129,21 @@ final class QueryRescorer implements Rescorer {
"product of:");
prim.addDetail(primaryExplain);
prim.addDetail(new Explanation(primaryWeight, "primaryWeight"));
if (rescoreExplain != null) {
ComplexExplanation sumExpl = new ComplexExplanation();
sumExpl.setDescription("sum of:");
sumExpl.addDetail(prim);
sumExpl.setMatch(prim.isMatch());
if (rescoreExplain != null && rescoreExplain.isMatch()) {
float secondaryWeight = rescore.rescoreQueryWeight();
ComplexExplanation sec = new ComplexExplanation(rescoreExplain.isMatch(),
rescoreExplain.getValue() * secondaryWeight,
"product of:");
sec.addDetail(rescoreExplain);
sec.addDetail(new Explanation(secondaryWeight, "secondaryWeight"));
sumExpl.addDetail(sec);
sumExpl.setValue(prim.getValue() + sec.getValue());
return sumExpl;
ScoreMode scoreMode = rescore.scoreMode();
ComplexExplanation calcExpl = new ComplexExplanation();
calcExpl.setDescription(scoreMode + " of:");
calcExpl.addDetail(prim);
calcExpl.setMatch(prim.isMatch());
calcExpl.addDetail(sec);
calcExpl.setValue(scoreMode.combine(prim.getValue(), sec.getValue()));
return calcExpl;
} else {
return prim;
}
@ -108,6 +166,21 @@ final class QueryRescorer implements Rescorer {
rescoreContext.setQueryWeight(parser.floatValue());
} else if("rescore_query_weight".equals(fieldName)) {
rescoreContext.setRescoreQueryWeight(parser.floatValue());
} else if ("score_mode".equals(fieldName)) {
String sScoreMode = parser.text();
if ("avg".equals(sScoreMode)) {
rescoreContext.setScoreMode(ScoreMode.Avg);
} else if ("max".equals(sScoreMode)) {
rescoreContext.setScoreMode(ScoreMode.Max);
} else if ("min".equals(sScoreMode)) {
rescoreContext.setScoreMode(ScoreMode.Min);
} else if ("total".equals(sScoreMode)) {
rescoreContext.setScoreMode(ScoreMode.Total);
} else if ("multiply".equals(sScoreMode)) {
rescoreContext.setScoreMode(ScoreMode.Multiply);
} else {
throw new ElasticSearchIllegalArgumentException("[rescore] illegal score_mode [" + sScoreMode + "]");
}
} else {
throw new ElasticSearchIllegalArgumentException("rescore doesn't support [" + fieldName + "]");
}
@ -120,11 +193,13 @@ final class QueryRescorer implements Rescorer {
/**
 * Creates a context bound to the given rescorer.
 *
 * @param rescorer the rescorer this context is handed to the superclass with
 */
public QueryRescoreContext(QueryRescorer rescorer) {
    // NOTE(review): 10 is presumably the default rescore window size - confirm against the superclass.
    super(NAME, 10, rescorer);
}

// The rescore query; no default, assigned through setParsedQuery.
private ParsedQuery parsedQuery;
// Multiplier applied to the primary query score; defaults to 1.0.
private float queryWeight = 1.0f;
// Multiplier applied to the rescore query score; defaults to 1.0.
private float rescoreQueryWeight = 1.0f;
// How the two weighted scores are merged; defaults to Total. Initialised here,
// like the weights above, so that every default is declared in one place.
private ScoreMode scoreMode = ScoreMode.Total;
public void setParsedQuery(ParsedQuery parsedQuery) {
this.parsedQuery = parsedQuery;
@ -142,6 +217,10 @@ final class QueryRescorer implements Rescorer {
return rescoreQueryWeight;
}
/**
 * Returns the mode currently stored in this context for combining the
 * primary query score with the rescore query score.
 *
 * @return the configured score mode
 */
public ScoreMode scoreMode() {
return scoreMode;
}
/**
 * Stores the weight associated with the rescore query.
 *
 * @param rescoreQueryWeight the new rescore query weight
 */
public void setRescoreQueryWeight(float rescoreQueryWeight) {
this.rescoreQueryWeight = rescoreQueryWeight;
}
@ -149,6 +228,10 @@ final class QueryRescorer implements Rescorer {
/**
 * Stores the weight associated with the primary query.
 *
 * @param queryWeight the new primary query weight
 */
public void setQueryWeight(float queryWeight) {
this.queryWeight = queryWeight;
}
/**
 * Replaces the mode used to combine the primary query score with the
 * rescore query score. The value is stored as given, without validation.
 *
 * @param scoreMode the score mode to store
 */
public void setScoreMode(ScoreMode scoreMode) {
this.scoreMode = scoreMode;
}
}
@ -164,9 +247,10 @@ final class QueryRescorer implements Rescorer {
int j = 0;
float primaryWeight = context.queryWeight();
float secondaryWeight = context.rescoreQueryWeight();
for (int i = 0; i < primaryDocs.length && j < secondaryDocs.length; i++) {
if (primaryDocs[i].doc == secondaryDocs[j].doc) {
primaryDocs[i].score = (primaryDocs[i].score * primaryWeight) + (secondaryDocs[j++].score * secondaryWeight);
ScoreMode scoreMode = context.scoreMode();
for (int i = 0; i < primaryDocs.length; i++) {
if (j < secondaryDocs.length && primaryDocs[i].doc == secondaryDocs[j].doc) {
primaryDocs[i].score = scoreMode.combine(primaryDocs[i].score * primaryWeight, secondaryDocs[j++].score * secondaryWeight);
} else {
primaryDocs[i].score *= primaryWeight;
}

View File

@ -84,6 +84,7 @@ public class RescoreBuilder implements ToXContent {
// Query serialized under the "rescore_query" field.
private QueryBuilder queryBuilder;
// Optional weight; written as "rescore_query_weight" only when non-null.
private Float rescoreQueryWeight;
// Optional weight for the original query.
private Float queryWeight;
// Optional score mode; written as "score_mode" only when non-null.
private String scoreMode;
/**
* Creates a new {@link QueryRescorer} instance
@ -109,6 +110,14 @@ public class RescoreBuilder implements ToXContent {
return this;
}
/**
* Sets how the original query score and the rescore query score are
* combined. The value is stored without validation and is written to the
* request as <tt>score_mode</tt>; when it is never set, the field is
* omitted and the default, <tt>total</tt>, applies.
* <p>
* Values accepted by the query rescorer's parser are <tt>avg</tt>,
* <tt>max</tt>, <tt>min</tt>, <tt>total</tt> and <tt>multiply</tt>.
*
* @param scoreMode the score mode name
* @return this builder, to allow call chaining
*/
public QueryRescorer setScoreMode(String scoreMode) {
this.scoreMode = scoreMode;
return this;
}
@Override
protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("rescore_query", queryBuilder);
@ -118,6 +127,9 @@ public class RescoreBuilder implements ToXContent {
if (rescoreQueryWeight != null) {
builder.field("rescore_query_weight", rescoreQueryWeight);
}
if (scoreMode != null) {
builder.field("score_mode", scoreMode);
}
return builder;
}
}

View File

@ -31,6 +31,7 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.rescore.RescoreBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer;
import org.elasticsearch.test.integration.AbstractSharedClusterTest;
import org.junit.Test;
@ -280,29 +281,164 @@ public class QueryRescorerTests extends AbstractSharedClusterTest {
.setSource("field1", "quick huge brown", "field2", "the quick lazy huge brown fox jumps over the tree").execute()
.actionGet();
refresh();
SearchResponse searchResponse = client()
.prepareSearch()
.setSearchType(SearchType.DFS_QUERY_THEN_FETCH)
.setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(MatchQueryBuilder.Operator.OR))
.setRescorer(
RescoreBuilder.queryRescorer(QueryBuilders.matchPhraseQuery("field1", "the quick brown").slop(2).boost(4.0f))
.setQueryWeight(0.5f).setRescoreQueryWeight(0.4f)).setRescoreWindow(5).setExplain(true).execute()
.actionGet();
assertHitCount(searchResponse, 3);
assertFirstHit(searchResponse, hasId("1"));
assertSecondHit(searchResponse, hasId("2"));
assertThirdHit(searchResponse, hasId("3"));
for (int i = 0; i < 3; i++) {
assertThat(searchResponse.getHits().getAt(i).explanation(), notNullValue());
assertThat(searchResponse.getHits().getAt(i).explanation().isMatch(), equalTo(true));
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails().length, equalTo(2));
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].isMatch(), equalTo(true));
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].getDetails()[1].getValue(), equalTo(0.5f));
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getDetails()[1].getValue(), equalTo(0.4f));
if (i == 2) {
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].isMatch(), equalTo(false));
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getDetails()[0].getValue(), equalTo(0.0f));
{
SearchResponse searchResponse = client()
.prepareSearch()
.setSearchType(SearchType.DFS_QUERY_THEN_FETCH)
.setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(MatchQueryBuilder.Operator.OR))
.setRescorer(
RescoreBuilder.queryRescorer(QueryBuilders.matchPhraseQuery("field1", "the quick brown").slop(2).boost(4.0f))
.setQueryWeight(0.5f).setRescoreQueryWeight(0.4f)).setRescoreWindow(5).setExplain(true).execute()
.actionGet();
assertHitCount(searchResponse, 3);
assertFirstHit(searchResponse, hasId("1"));
assertSecondHit(searchResponse, hasId("2"));
assertThirdHit(searchResponse, hasId("3"));
for (int i = 0; i < 3; i++) {
assertThat(searchResponse.getHits().getAt(i).explanation(), notNullValue());
assertThat(searchResponse.getHits().getAt(i).explanation().isMatch(), equalTo(true));
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails().length, equalTo(2));
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].isMatch(), equalTo(true));
if (i == 2) {
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getValue(), equalTo(0.5f));
} else {
assertThat(searchResponse.getHits().getAt(i).explanation().getDescription(), equalTo("sum of:"));
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].getDetails()[1].getValue(), equalTo(0.5f));
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getDetails()[1].getValue(), equalTo(0.4f));
}
}
}
String[] scoreModes = new String[]{ "max", "min", "avg", "total", "multiply", "" };
String[] descriptionModes = new String[]{ "max of:", "min of:", "avg of:", "sum of:", "product of:", "sum of:" };
for (int i = 0; i < scoreModes.length; i++) {
QueryRescorer rescoreQuery = RescoreBuilder.queryRescorer(QueryBuilders.matchQuery("field1", "the quick brown").boost(4.0f))
.setQueryWeight(0.5f).setRescoreQueryWeight(0.4f);
if (!"".equals(scoreModes[i])) {
rescoreQuery.setScoreMode(scoreModes[i]);
}
SearchResponse searchResponse = client()
.prepareSearch()
.setSearchType(SearchType.DFS_QUERY_THEN_FETCH)
.setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(MatchQueryBuilder.Operator.OR))
.setRescorer(rescoreQuery).setRescoreWindow(5).setExplain(true).execute()
.actionGet();
assertHitCount(searchResponse, 3);
assertFirstHit(searchResponse, hasId("1"));
assertSecondHit(searchResponse, hasId("2"));
assertThirdHit(searchResponse, hasId("3"));
for (int j = 0; j < 3; j++) {
assertThat(searchResponse.getHits().getAt(j).explanation().getDescription(), equalTo(descriptionModes[i]));
}
}
}
/**
 * Indexes 1000 documents whose single not-analyzed field holds the English
 * spelling of the document id, then, for every score mode (plus the unset
 * default), rescoring a four-document window with fixed per-term scores and
 * verifying both the resulting order and the exact combined scores.
 */
@Test
public void testScoring() throws Exception {
    client().admin()
            .indices()
            .prepareCreate("test")
            .addMapping(
                    "type1",
                    jsonBuilder().startObject().startObject("type1").startObject("properties").startObject("field1")
                            .field("index", "not_analyzed").field("type", "string").endObject().endObject().endObject().endObject())
            .setSettings(ImmutableSettings.settingsBuilder()).execute().actionGet();
    ensureGreen();

    final int numDocs = 1000;
    int docId = 0;
    while (docId < numDocs) {
        client().prepareIndex("test", "type1", String.valueOf(docId)).setSource("field1", English.intToEnglish(docId)).execute().actionGet();
        docId++;
    }

    flush();
    optimize(); // make sure we don't have a background merge running
    refresh();
    ensureGreen();

    // The empty string stands for "do not set a score mode", i.e. the default.
    final String[] modes = { "max", "min", "avg", "total", "multiply", "" };
    float primaryWeight = 1.1f;
    float secondaryWeight = 1.6f;

    for (int m = 0; m < modes.length; m++) {
        final String mode = modes[m];
        final boolean explicitMode = mode.length() > 0;

        for (int base = 0; base + 4 < numDocs; base++) {
            String term0 = English.intToEnglish(base);
            String term1 = English.intToEnglish(base + 1);
            String term2 = English.intToEnglish(base + 2);
            String term3 = English.intToEnglish(base + 3);

            // The rescore query deliberately skips term2, so that document is never rescored.
            QueryRescorer rescorer = RescoreBuilder
                    .queryRescorer(
                            QueryBuilders.boolQuery()
                                    .disableCoord(true)
                                    .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", term0)).script("5.0f"))
                                    .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", term1)).script("7.0f"))
                                    .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", term3)).script("0.0f")))
                    .setQueryWeight(primaryWeight)
                    .setRescoreQueryWeight(secondaryWeight);
            if (explicitMode) {
                rescorer.setScoreMode(mode);
            }

            SearchResponse rescored = client()
                    .prepareSearch()
                    .setPreference("test") // ensure we hit the same shards for tie-breaking
                    .setQuery(QueryBuilders.boolQuery()
                            .disableCoord(true)
                            .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", term0)).script("2.0f"))
                            .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", term1)).script("3.0f"))
                            .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", term2)).script("5.0f"))
                            .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", term3)).script("0.2f")))
                    .setFrom(0)
                    .setSize(10)
                    .setRescorer(rescorer)
                    .setRescoreWindow(50).execute().actionGet();

            assertHitCount(rescored, 4);

            if ("max".equals(mode)) {
                assertTopThreeIds(rescored, base + 1, base, base + 2);
                assertHitScores(rescored,
                        7.0f * secondaryWeight,
                        5.0f * secondaryWeight,
                        5.0f * primaryWeight,
                        0.2f * primaryWeight);
            } else if ("min".equals(mode)) {
                assertTopThreeIds(rescored, base + 2, base + 1, base);
                assertHitScores(rescored,
                        5.0f * primaryWeight,
                        3.0f * primaryWeight,
                        2.0f * primaryWeight,
                        0.0f * secondaryWeight);
            } else if ("avg".equals(mode)) {
                assertTopThreeIds(rescored, base + 1, base + 2, base);
                assertHitScores(rescored,
                        (3.0f * primaryWeight + 7.0f * secondaryWeight) / 2.0f,
                        5.0f * primaryWeight,
                        (2.0f * primaryWeight + 5.0f * secondaryWeight) / 2.0f,
                        (0.2f * primaryWeight) / 2.0f);
            } else if ("multiply".equals(mode)) {
                assertTopThreeIds(rescored, base + 1, base, base + 2);
                assertHitScores(rescored,
                        3.0f * primaryWeight * 7.0f * secondaryWeight,
                        2.0f * primaryWeight * 5.0f * secondaryWeight,
                        5.0f * primaryWeight,
                        0.2f * primaryWeight * 0.0f * secondaryWeight);
            } else if (!explicitMode || "total".equals(mode)) {
                assertTopThreeIds(rescored, base + 1, base, base + 2);
                assertHitScores(rescored,
                        3.0f * primaryWeight + 7.0f * secondaryWeight,
                        2.0f * primaryWeight + 5.0f * secondaryWeight,
                        5.0f * primaryWeight,
                        0.2f * primaryWeight + 0.0f * secondaryWeight);
            }
        }
    }
}

/** Asserts the ids of the first three hits, in rank order. */
private void assertTopThreeIds(SearchResponse response, int first, int second, int third) {
    assertFirstHit(response, hasId(String.valueOf(first)));
    assertSecondHit(response, hasId(String.valueOf(second)));
    assertThirdHit(response, hasId(String.valueOf(third)));
}

/** Asserts the exact scores of the first four hits, in rank order. */
private void assertHitScores(SearchResponse response, float score0, float score1, float score2, float score3) {
    assertThat(response.getHits().getHits()[0].getScore(), equalTo(score0));
    assertThat(response.getHits().getHits()[1].getScore(), equalTo(score1));
    assertThat(response.getHits().getHits()[2].getScore(), equalTo(score2));
    assertThat(response.getHits().getHits()[3].getScore(), equalTo(score3));
}