Add a "score_mode" parameter to rescoring query.
Default value is "total"; possible values are "max", "min", "avg", "multiply" and "total". - "total": the final score of a document is the sum of the original query score and the rescore query score. - "max": only the highest of the two scores counts. - "min": only the lowest of the two scores is kept (if the document doesn't match the rescore query, the original query score is used). - "avg": the average of both scores. - "multiply": the product of both scores. Closes #3258
This commit is contained in:
parent
29d337c44b
commit
5022c4659c
|
@ -32,11 +32,68 @@ import org.elasticsearch.search.internal.SearchContext;
|
|||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Set;
|
||||
import java.lang.Math;
|
||||
|
||||
final class QueryRescorer implements Rescorer {
|
||||
|
||||
|
||||
private static enum ScoreMode {
|
||||
Avg {
|
||||
@Override
|
||||
public float combine(float primary, float secondary) {
|
||||
return (primary + secondary) / 2;
|
||||
}
|
||||
@Override
|
||||
public String toString() {
|
||||
return "avg";
|
||||
}
|
||||
},
|
||||
Max {
|
||||
@Override
|
||||
public float combine(float primary, float secondary) {
|
||||
return Math.max(primary, secondary);
|
||||
}
|
||||
@Override
|
||||
public String toString() {
|
||||
return "max";
|
||||
}
|
||||
},
|
||||
Min {
|
||||
@Override
|
||||
public float combine(float primary, float secondary) {
|
||||
return Math.min(primary, secondary);
|
||||
}
|
||||
@Override
|
||||
public String toString() {
|
||||
return "min";
|
||||
}
|
||||
},
|
||||
Total {
|
||||
@Override
|
||||
public float combine(float primary, float secondary) {
|
||||
return primary + secondary;
|
||||
}
|
||||
@Override
|
||||
public String toString() {
|
||||
return "sum";
|
||||
}
|
||||
},
|
||||
Multiply {
|
||||
@Override
|
||||
public float combine(float primary, float secondary) {
|
||||
return primary * secondary;
|
||||
}
|
||||
@Override
|
||||
public String toString() {
|
||||
return "product";
|
||||
}
|
||||
};
|
||||
|
||||
public abstract float combine(float primary, float secondary);
|
||||
}
|
||||
|
||||
public static final Rescorer INSTANCE = new QueryRescorer();
|
||||
public static final String NAME = "query";
|
||||
|
||||
@Override
|
||||
public String name() {
|
||||
return NAME;
|
||||
|
@ -72,20 +129,21 @@ final class QueryRescorer implements Rescorer {
|
|||
"product of:");
|
||||
prim.addDetail(primaryExplain);
|
||||
prim.addDetail(new Explanation(primaryWeight, "primaryWeight"));
|
||||
if (rescoreExplain != null) {
|
||||
ComplexExplanation sumExpl = new ComplexExplanation();
|
||||
sumExpl.setDescription("sum of:");
|
||||
sumExpl.addDetail(prim);
|
||||
sumExpl.setMatch(prim.isMatch());
|
||||
if (rescoreExplain != null && rescoreExplain.isMatch()) {
|
||||
float secondaryWeight = rescore.rescoreQueryWeight();
|
||||
ComplexExplanation sec = new ComplexExplanation(rescoreExplain.isMatch(),
|
||||
rescoreExplain.getValue() * secondaryWeight,
|
||||
"product of:");
|
||||
sec.addDetail(rescoreExplain);
|
||||
sec.addDetail(new Explanation(secondaryWeight, "secondaryWeight"));
|
||||
sumExpl.addDetail(sec);
|
||||
sumExpl.setValue(prim.getValue() + sec.getValue());
|
||||
return sumExpl;
|
||||
ScoreMode scoreMode = rescore.scoreMode();
|
||||
ComplexExplanation calcExpl = new ComplexExplanation();
|
||||
calcExpl.setDescription(scoreMode + " of:");
|
||||
calcExpl.addDetail(prim);
|
||||
calcExpl.setMatch(prim.isMatch());
|
||||
calcExpl.addDetail(sec);
|
||||
calcExpl.setValue(scoreMode.combine(prim.getValue(), sec.getValue()));
|
||||
return calcExpl;
|
||||
} else {
|
||||
return prim;
|
||||
}
|
||||
|
@ -108,6 +166,21 @@ final class QueryRescorer implements Rescorer {
|
|||
rescoreContext.setQueryWeight(parser.floatValue());
|
||||
} else if("rescore_query_weight".equals(fieldName)) {
|
||||
rescoreContext.setRescoreQueryWeight(parser.floatValue());
|
||||
} else if ("score_mode".equals(fieldName)) {
|
||||
String sScoreMode = parser.text();
|
||||
if ("avg".equals(sScoreMode)) {
|
||||
rescoreContext.setScoreMode(ScoreMode.Avg);
|
||||
} else if ("max".equals(sScoreMode)) {
|
||||
rescoreContext.setScoreMode(ScoreMode.Max);
|
||||
} else if ("min".equals(sScoreMode)) {
|
||||
rescoreContext.setScoreMode(ScoreMode.Min);
|
||||
} else if ("total".equals(sScoreMode)) {
|
||||
rescoreContext.setScoreMode(ScoreMode.Total);
|
||||
} else if ("multiply".equals(sScoreMode)) {
|
||||
rescoreContext.setScoreMode(ScoreMode.Multiply);
|
||||
} else {
|
||||
throw new ElasticSearchIllegalArgumentException("[rescore] illegal score_mode [" + sScoreMode + "]");
|
||||
}
|
||||
} else {
|
||||
throw new ElasticSearchIllegalArgumentException("rescore doesn't support [" + fieldName + "]");
|
||||
}
|
||||
|
@ -120,11 +193,13 @@ final class QueryRescorer implements Rescorer {
|
|||
|
||||
public QueryRescoreContext(QueryRescorer rescorer) {
|
||||
super(NAME, 10, rescorer);
|
||||
this.scoreMode = ScoreMode.Total;
|
||||
}
|
||||
|
||||
private ParsedQuery parsedQuery;
|
||||
private float queryWeight = 1.0f;
|
||||
private float rescoreQueryWeight = 1.0f;
|
||||
private ScoreMode scoreMode;
|
||||
|
||||
public void setParsedQuery(ParsedQuery parsedQuery) {
|
||||
this.parsedQuery = parsedQuery;
|
||||
|
@ -142,6 +217,10 @@ final class QueryRescorer implements Rescorer {
|
|||
return rescoreQueryWeight;
|
||||
}
|
||||
|
||||
/**
 * Returns the score mode currently held by this rescore context.
 *
 * @return the value of the {@code scoreMode} field
 */
public ScoreMode scoreMode() {
|
||||
return scoreMode;
|
||||
}
|
||||
|
||||
public void setRescoreQueryWeight(float rescoreQueryWeight) {
|
||||
this.rescoreQueryWeight = rescoreQueryWeight;
|
||||
}
|
||||
|
@ -149,6 +228,10 @@ final class QueryRescorer implements Rescorer {
|
|||
public void setQueryWeight(float queryWeight) {
|
||||
this.queryWeight = queryWeight;
|
||||
}
|
||||
|
||||
/**
 * Replaces the score mode held by this rescore context.
 *
 * @param scoreMode the mode to store in the {@code scoreMode} field
 */
public void setScoreMode(ScoreMode scoreMode) {
|
||||
this.scoreMode = scoreMode;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -164,9 +247,10 @@ final class QueryRescorer implements Rescorer {
|
|||
int j = 0;
|
||||
float primaryWeight = context.queryWeight();
|
||||
float secondaryWeight = context.rescoreQueryWeight();
|
||||
for (int i = 0; i < primaryDocs.length && j < secondaryDocs.length; i++) {
|
||||
if (primaryDocs[i].doc == secondaryDocs[j].doc) {
|
||||
primaryDocs[i].score = (primaryDocs[i].score * primaryWeight) + (secondaryDocs[j++].score * secondaryWeight);
|
||||
ScoreMode scoreMode = context.scoreMode();
|
||||
for (int i = 0; i < primaryDocs.length; i++) {
|
||||
if (j < secondaryDocs.length && primaryDocs[i].doc == secondaryDocs[j].doc) {
|
||||
primaryDocs[i].score = scoreMode.combine(primaryDocs[i].score * primaryWeight, secondaryDocs[j++].score * secondaryWeight);
|
||||
} else {
|
||||
primaryDocs[i].score *= primaryWeight;
|
||||
}
|
||||
|
|
|
@ -84,6 +84,7 @@ public class RescoreBuilder implements ToXContent {
|
|||
private QueryBuilder queryBuilder;
|
||||
private Float rescoreQueryWeight;
|
||||
private Float queryWeight;
|
||||
private String scoreMode;
|
||||
|
||||
/**
|
||||
* Creates a new {@link QueryRescorer} instance
|
||||
|
@ -109,6 +110,14 @@ public class RescoreBuilder implements ToXContent {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
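* Sets the score mode, i.e. how the original query score and the rescore query score are combined: one of <tt>max</tt>, <tt>min</tt>, <tt>avg</tt>, <tt>total</tt> or <tt>multiply</tt>. The default is <tt>total</tt>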
|
||||
*/
|
||||
public QueryRescorer setScoreMode(String scoreMode) {
|
||||
this.scoreMode = scoreMode;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field("rescore_query", queryBuilder);
|
||||
|
@ -118,6 +127,9 @@ public class RescoreBuilder implements ToXContent {
|
|||
if (rescoreQueryWeight != null) {
|
||||
builder.field("rescore_query_weight", rescoreQueryWeight);
|
||||
}
|
||||
if (scoreMode != null) {
|
||||
builder.field("score_mode", scoreMode);
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.elasticsearch.index.query.QueryBuilders;
|
|||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchHits;
|
||||
import org.elasticsearch.search.rescore.RescoreBuilder;
|
||||
import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer;
|
||||
import org.elasticsearch.test.integration.AbstractSharedClusterTest;
|
||||
import org.junit.Test;
|
||||
|
||||
|
@ -280,29 +281,164 @@ public class QueryRescorerTests extends AbstractSharedClusterTest {
|
|||
.setSource("field1", "quick huge brown", "field2", "the quick lazy huge brown fox jumps over the tree").execute()
|
||||
.actionGet();
|
||||
refresh();
|
||||
SearchResponse searchResponse = client()
|
||||
.prepareSearch()
|
||||
.setSearchType(SearchType.DFS_QUERY_THEN_FETCH)
|
||||
.setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(MatchQueryBuilder.Operator.OR))
|
||||
.setRescorer(
|
||||
RescoreBuilder.queryRescorer(QueryBuilders.matchPhraseQuery("field1", "the quick brown").slop(2).boost(4.0f))
|
||||
.setQueryWeight(0.5f).setRescoreQueryWeight(0.4f)).setRescoreWindow(5).setExplain(true).execute()
|
||||
.actionGet();
|
||||
assertHitCount(searchResponse, 3);
|
||||
assertFirstHit(searchResponse, hasId("1"));
|
||||
assertSecondHit(searchResponse, hasId("2"));
|
||||
assertThirdHit(searchResponse, hasId("3"));
|
||||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation(), notNullValue());
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().isMatch(), equalTo(true));
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails().length, equalTo(2));
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].isMatch(), equalTo(true));
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].getDetails()[1].getValue(), equalTo(0.5f));
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getDetails()[1].getValue(), equalTo(0.4f));
|
||||
if (i == 2) {
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].isMatch(), equalTo(false));
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getDetails()[0].getValue(), equalTo(0.0f));
|
||||
{
|
||||
SearchResponse searchResponse = client()
|
||||
.prepareSearch()
|
||||
.setSearchType(SearchType.DFS_QUERY_THEN_FETCH)
|
||||
.setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(MatchQueryBuilder.Operator.OR))
|
||||
.setRescorer(
|
||||
RescoreBuilder.queryRescorer(QueryBuilders.matchPhraseQuery("field1", "the quick brown").slop(2).boost(4.0f))
|
||||
.setQueryWeight(0.5f).setRescoreQueryWeight(0.4f)).setRescoreWindow(5).setExplain(true).execute()
|
||||
.actionGet();
|
||||
assertHitCount(searchResponse, 3);
|
||||
assertFirstHit(searchResponse, hasId("1"));
|
||||
assertSecondHit(searchResponse, hasId("2"));
|
||||
assertThirdHit(searchResponse, hasId("3"));
|
||||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation(), notNullValue());
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().isMatch(), equalTo(true));
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails().length, equalTo(2));
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].isMatch(), equalTo(true));
|
||||
if (i == 2) {
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getValue(), equalTo(0.5f));
|
||||
} else {
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDescription(), equalTo("sum of:"));
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].getDetails()[1].getValue(), equalTo(0.5f));
|
||||
assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getDetails()[1].getValue(), equalTo(0.4f));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
String[] scoreModes = new String[]{ "max", "min", "avg", "total", "multiply", "" };
|
||||
String[] descriptionModes = new String[]{ "max of:", "min of:", "avg of:", "sum of:", "product of:", "sum of:" };
|
||||
for (int i = 0; i < scoreModes.length; i++) {
|
||||
QueryRescorer rescoreQuery = RescoreBuilder.queryRescorer(QueryBuilders.matchQuery("field1", "the quick brown").boost(4.0f))
|
||||
.setQueryWeight(0.5f).setRescoreQueryWeight(0.4f);
|
||||
|
||||
if (!"".equals(scoreModes[i])) {
|
||||
rescoreQuery.setScoreMode(scoreModes[i]);
|
||||
}
|
||||
|
||||
SearchResponse searchResponse = client()
|
||||
.prepareSearch()
|
||||
.setSearchType(SearchType.DFS_QUERY_THEN_FETCH)
|
||||
.setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(MatchQueryBuilder.Operator.OR))
|
||||
.setRescorer(rescoreQuery).setRescoreWindow(5).setExplain(true).execute()
|
||||
.actionGet();
|
||||
assertHitCount(searchResponse, 3);
|
||||
assertFirstHit(searchResponse, hasId("1"));
|
||||
assertSecondHit(searchResponse, hasId("2"));
|
||||
assertThirdHit(searchResponse, hasId("3"));
|
||||
|
||||
for (int j = 0; j < 3; j++) {
|
||||
assertThat(searchResponse.getHits().getAt(j).explanation().getDescription(), equalTo(descriptionModes[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScoring() throws Exception {
|
||||
client().admin()
|
||||
.indices()
|
||||
.prepareCreate("test")
|
||||
.addMapping(
|
||||
"type1",
|
||||
jsonBuilder().startObject().startObject("type1").startObject("properties").startObject("field1")
|
||||
.field("index", "not_analyzed").field("type", "string").endObject().endObject().endObject().endObject())
|
||||
.setSettings(ImmutableSettings.settingsBuilder()).execute().actionGet();
|
||||
ensureGreen();
|
||||
int numDocs = 1000;
|
||||
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
client().prepareIndex("test", "type1", String.valueOf(i)).setSource("field1", English.intToEnglish(i)).execute().actionGet();
|
||||
}
|
||||
|
||||
flush();
|
||||
optimize(); // make sure we don't have a background merge running
|
||||
refresh();
|
||||
ensureGreen();
|
||||
|
||||
String[] scoreModes = new String[]{ "max", "min", "avg", "total", "multiply", "" };
|
||||
float primaryWeight = 1.1f;
|
||||
float secondaryWeight = 1.6f;
|
||||
|
||||
for (String scoreMode: scoreModes) {
|
||||
for (int i = 0; i < numDocs - 4; i++) {
|
||||
String[] intToEnglish = new String[] { English.intToEnglish(i), English.intToEnglish(i + 1), English.intToEnglish(i + 2), English.intToEnglish(i + 3) };
|
||||
|
||||
QueryRescorer rescoreQuery = RescoreBuilder
|
||||
.queryRescorer(
|
||||
QueryBuilders.boolQuery()
|
||||
.disableCoord(true)
|
||||
.should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).script("5.0f"))
|
||||
.should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).script("7.0f"))
|
||||
.should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).script("0.0f")))
|
||||
.setQueryWeight(primaryWeight)
|
||||
.setRescoreQueryWeight(secondaryWeight);
|
||||
|
||||
if (!"".equals(scoreMode)) {
|
||||
rescoreQuery.setScoreMode(scoreMode);
|
||||
}
|
||||
|
||||
SearchResponse rescored = client()
|
||||
.prepareSearch()
|
||||
.setPreference("test") // ensure we hit the same shards for tie-breaking
|
||||
.setQuery(QueryBuilders.boolQuery()
|
||||
.disableCoord(true)
|
||||
.should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).script("2.0f"))
|
||||
.should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).script("3.0f"))
|
||||
.should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[2])).script("5.0f"))
|
||||
.should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).script("0.2f")))
|
||||
.setFrom(0)
|
||||
.setSize(10)
|
||||
.setRescorer(rescoreQuery)
|
||||
.setRescoreWindow(50).execute().actionGet();
|
||||
|
||||
assertHitCount(rescored, 4);
|
||||
|
||||
if ("total".equals(scoreMode) || "".equals(scoreMode)) {
|
||||
assertFirstHit(rescored, hasId(String.valueOf(i + 1)));
|
||||
assertSecondHit(rescored, hasId(String.valueOf(i)));
|
||||
assertThirdHit(rescored, hasId(String.valueOf(i + 2)));
|
||||
assertThat(rescored.getHits().getHits()[0].getScore(), equalTo(3.0f * primaryWeight + 7.0f * secondaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[1].getScore(), equalTo(2.0f * primaryWeight + 5.0f * secondaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[2].getScore(), equalTo(5.0f * primaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[3].getScore(), equalTo(0.2f * primaryWeight + 0.0f * secondaryWeight));
|
||||
} else if ("max".equals(scoreMode)) {
|
||||
assertFirstHit(rescored, hasId(String.valueOf(i + 1)));
|
||||
assertSecondHit(rescored, hasId(String.valueOf(i)));
|
||||
assertThirdHit(rescored, hasId(String.valueOf(i + 2)));
|
||||
assertThat(rescored.getHits().getHits()[0].getScore(), equalTo(7.0f * secondaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[1].getScore(), equalTo(5.0f * secondaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[2].getScore(), equalTo(5.0f * primaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[3].getScore(), equalTo(0.2f * primaryWeight));
|
||||
} else if ("min".equals(scoreMode)) {
|
||||
assertFirstHit(rescored, hasId(String.valueOf(i + 2)));
|
||||
assertSecondHit(rescored, hasId(String.valueOf(i + 1)));
|
||||
assertThirdHit(rescored, hasId(String.valueOf(i)));
|
||||
assertThat(rescored.getHits().getHits()[0].getScore(), equalTo(5.0f * primaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[1].getScore(), equalTo(3.0f * primaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[2].getScore(), equalTo(2.0f * primaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[3].getScore(), equalTo(0.0f * secondaryWeight));
|
||||
} else if ("avg".equals(scoreMode)) {
|
||||
assertFirstHit(rescored, hasId(String.valueOf(i + 1)));
|
||||
assertSecondHit(rescored, hasId(String.valueOf(i + 2)));
|
||||
assertThirdHit(rescored, hasId(String.valueOf(i)));
|
||||
assertThat(rescored.getHits().getHits()[0].getScore(), equalTo((3.0f * primaryWeight + 7.0f * secondaryWeight) / 2.0f));
|
||||
assertThat(rescored.getHits().getHits()[1].getScore(), equalTo(5.0f * primaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[2].getScore(), equalTo((2.0f * primaryWeight + 5.0f * secondaryWeight) / 2.0f));
|
||||
assertThat(rescored.getHits().getHits()[3].getScore(), equalTo((0.2f * primaryWeight) / 2.0f));
|
||||
} else if ("multiply".equals(scoreMode)) {
|
||||
assertFirstHit(rescored, hasId(String.valueOf(i + 1)));
|
||||
assertSecondHit(rescored, hasId(String.valueOf(i)));
|
||||
assertThirdHit(rescored, hasId(String.valueOf(i + 2)));
|
||||
assertThat(rescored.getHits().getHits()[0].getScore(), equalTo(3.0f * primaryWeight * 7.0f * secondaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[1].getScore(), equalTo(2.0f * primaryWeight * 5.0f * secondaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[2].getScore(), equalTo(5.0f * primaryWeight));
|
||||
assertThat(rescored.getHits().getHits()[3].getScore(), equalTo(0.2f * primaryWeight * 0.0f * secondaryWeight));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue