From 5022c4659c406569977dd7a66aa33c2570ac949b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20HOURCADE?= Date: Sat, 20 Jul 2013 17:00:04 +0200 Subject: [PATCH] Add a "score_mode" parameter to rescoring query. Default value is "total", possible values are: "max", "min", "avg", "multiply" and "total". - "total": the final score of a document is the sum of the original query score with the rescore query score. - "max": only the highest score count. - "min": only the lowest score is kept (if the document doesn't match the rescore query, the original query score is used). - "avg": average of both scores - "multiply": product of both scores Closes #3258 --- .../search/rescore/QueryRescorer.java | 108 +++++++++-- .../search/rescore/RescoreBuilder.java | 12 ++ .../search/rescore/QueryRescorerTests.java | 180 +++++++++++++++--- 3 files changed, 266 insertions(+), 34 deletions(-) diff --git a/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java b/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java index 57d71333d95..6549064fa2b 100644 --- a/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java +++ b/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java @@ -32,11 +32,68 @@ import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; import java.util.Arrays; import java.util.Set; +import java.lang.Math; final class QueryRescorer implements Rescorer { - + + private static enum ScoreMode { + Avg { + @Override + public float combine(float primary, float secondary) { + return (primary + secondary) / 2; + } + @Override + public String toString() { + return "avg"; + } + }, + Max { + @Override + public float combine(float primary, float secondary) { + return Math.max(primary, secondary); + } + @Override + public String toString() { + return "max"; + } + }, + Min { + @Override + public float combine(float primary, float secondary) { + return Math.min(primary, secondary); + } + @Override + public String toString() { + return "min"; + } + }, + Total { + @Override + public float combine(float primary, float secondary) { + return primary + secondary; + } + @Override + public String toString() { + return "sum"; + } + }, + Multiply { + @Override + public float combine(float primary, float secondary) { + return primary * secondary; + } + @Override + public String toString() { + return "product"; + } + }; + + public abstract float combine(float primary, float secondary); + } + public static final Rescorer INSTANCE = new QueryRescorer(); public static final String NAME = "query"; + @Override public String name() { return NAME; @@ -72,20 +129,21 @@ final class QueryRescorer implements Rescorer { "product of:"); prim.addDetail(primaryExplain); prim.addDetail(new Explanation(primaryWeight, "primaryWeight")); - if (rescoreExplain != null) { - ComplexExplanation sumExpl = new ComplexExplanation(); - sumExpl.setDescription("sum of:"); - sumExpl.addDetail(prim); - sumExpl.setMatch(prim.isMatch()); + if (rescoreExplain != null && rescoreExplain.isMatch()) { float secondaryWeight = rescore.rescoreQueryWeight(); ComplexExplanation sec = new ComplexExplanation(rescoreExplain.isMatch(), rescoreExplain.getValue() * secondaryWeight, "product of:"); sec.addDetail(rescoreExplain); sec.addDetail(new Explanation(secondaryWeight, "secondaryWeight")); - sumExpl.addDetail(sec); - sumExpl.setValue(prim.getValue() + sec.getValue()); - return sumExpl; + ScoreMode scoreMode = rescore.scoreMode(); + ComplexExplanation calcExpl = new ComplexExplanation(); + calcExpl.setDescription(scoreMode + " of:"); + calcExpl.addDetail(prim); + calcExpl.setMatch(prim.isMatch()); + calcExpl.addDetail(sec); + calcExpl.setValue(scoreMode.combine(prim.getValue(), sec.getValue())); + return calcExpl; } else { return prim; } @@ -108,6 +166,21 @@ final class QueryRescorer implements Rescorer { rescoreContext.setQueryWeight(parser.floatValue()); } else if("rescore_query_weight".equals(fieldName)) { rescoreContext.setRescoreQueryWeight(parser.floatValue()); + } else if ("score_mode".equals(fieldName)) { + String sScoreMode = parser.text(); + if ("avg".equals(sScoreMode)) { + rescoreContext.setScoreMode(ScoreMode.Avg); + } else if ("max".equals(sScoreMode)) { + rescoreContext.setScoreMode(ScoreMode.Max); + } else if ("min".equals(sScoreMode)) { + rescoreContext.setScoreMode(ScoreMode.Min); + } else if ("total".equals(sScoreMode)) { + rescoreContext.setScoreMode(ScoreMode.Total); + } else if ("multiply".equals(sScoreMode)) { + rescoreContext.setScoreMode(ScoreMode.Multiply); + } else { + throw new ElasticSearchIllegalArgumentException("[rescore] illegal score_mode [" + sScoreMode + "]"); + } } else { throw new ElasticSearchIllegalArgumentException("rescore doesn't support [" + fieldName + "]"); } @@ -120,11 +193,13 @@ final class QueryRescorer implements Rescorer { public QueryRescoreContext(QueryRescorer rescorer) { super(NAME, 10, rescorer); + this.scoreMode = ScoreMode.Total; } private ParsedQuery parsedQuery; private float queryWeight = 1.0f; private float rescoreQueryWeight = 1.0f; + private ScoreMode scoreMode; public void setParsedQuery(ParsedQuery parsedQuery) { this.parsedQuery = parsedQuery; @@ -142,6 +217,10 @@ final class QueryRescorer implements Rescorer { return rescoreQueryWeight; } + public ScoreMode scoreMode() { + return scoreMode; + } + public void setRescoreQueryWeight(float rescoreQueryWeight) { this.rescoreQueryWeight = rescoreQueryWeight; } @@ -149,6 +228,10 @@ final class QueryRescorer implements Rescorer { public void setQueryWeight(float queryWeight) { this.queryWeight = queryWeight; } + + public void setScoreMode(ScoreMode scoreMode) { + this.scoreMode = scoreMode; + } } @@ -164,9 +247,10 @@ final class QueryRescorer implements Rescorer { int j = 0; float primaryWeight = context.queryWeight(); float secondaryWeight = context.rescoreQueryWeight(); - for (int i = 0; i < primaryDocs.length && j < secondaryDocs.length; i++) { - if (primaryDocs[i].doc == secondaryDocs[j].doc) { - primaryDocs[i].score = (primaryDocs[i].score * primaryWeight) + (secondaryDocs[j++].score * secondaryWeight); + ScoreMode scoreMode = context.scoreMode(); + for (int i = 0; i < primaryDocs.length; i++) { + if (j < secondaryDocs.length && primaryDocs[i].doc == secondaryDocs[j].doc) { + primaryDocs[i].score = scoreMode.combine(primaryDocs[i].score * primaryWeight, secondaryDocs[j++].score * secondaryWeight); } else { primaryDocs[i].score *= primaryWeight; } diff --git a/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java b/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java index 69049332dcc..e9d4d994491 100644 --- a/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java +++ b/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java @@ -84,6 +84,7 @@ public class RescoreBuilder implements ToXContent { private QueryBuilder queryBuilder; private Float rescoreQueryWeight; private Float queryWeight; + private String scoreMode; /** * Creates a new {@link QueryRescorer} instance @@ -109,6 +110,14 @@ public class RescoreBuilder implements ToXContent { return this; } + /** + * Sets the original query score mode. The default is total + */ + public QueryRescorer setScoreMode(String scoreMode) { + this.scoreMode = scoreMode; + return this; + } + @Override protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { builder.field("rescore_query", queryBuilder); @@ -118,6 +127,9 @@ public class RescoreBuilder implements ToXContent { if (rescoreQueryWeight != null) { builder.field("rescore_query_weight", rescoreQueryWeight); } + if (scoreMode != null) { + builder.field("score_mode", scoreMode); + } return builder; } } diff --git a/src/test/java/org/elasticsearch/test/integration/search/rescore/QueryRescorerTests.java b/src/test/java/org/elasticsearch/test/integration/search/rescore/QueryRescorerTests.java index f4bad945717..731ad8515ae 100644 --- a/src/test/java/org/elasticsearch/test/integration/search/rescore/QueryRescorerTests.java +++ b/src/test/java/org/elasticsearch/test/integration/search/rescore/QueryRescorerTests.java @@ -31,6 +31,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.rescore.RescoreBuilder; +import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer; import org.elasticsearch.test.integration.AbstractSharedClusterTest; import org.junit.Test; @@ -280,29 +281,164 @@ public class QueryRescorerTests extends AbstractSharedClusterTest { .setSource("field1", "quick huge brown", "field2", "the quick lazy huge brown fox jumps over the tree").execute() .actionGet(); refresh(); - SearchResponse searchResponse = client() - .prepareSearch() - .setSearchType(SearchType.DFS_QUERY_THEN_FETCH) - .setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(MatchQueryBuilder.Operator.OR)) - .setRescorer( - RescoreBuilder.queryRescorer(QueryBuilders.matchPhraseQuery("field1", "the quick brown").slop(2).boost(4.0f)) - .setQueryWeight(0.5f).setRescoreQueryWeight(0.4f)).setRescoreWindow(5).setExplain(true).execute() - .actionGet(); - assertHitCount(searchResponse, 3); - assertFirstHit(searchResponse, hasId("1")); - assertSecondHit(searchResponse, hasId("2")); - assertThirdHit(searchResponse, hasId("3")); - for (int i = 0; i < 3; i++) { - assertThat(searchResponse.getHits().getAt(i).explanation(), notNullValue()); - assertThat(searchResponse.getHits().getAt(i).explanation().isMatch(), equalTo(true)); - assertThat(searchResponse.getHits().getAt(i).explanation().getDetails().length, equalTo(2)); - assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].isMatch(), equalTo(true)); - assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].getDetails()[1].getValue(), equalTo(0.5f)); - assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getDetails()[1].getValue(), equalTo(0.4f)); - if (i == 2) { - assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].isMatch(), equalTo(false)); - assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getDetails()[0].getValue(), equalTo(0.0f)); + { + SearchResponse searchResponse = client() + .prepareSearch() + .setSearchType(SearchType.DFS_QUERY_THEN_FETCH) + .setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(MatchQueryBuilder.Operator.OR)) + .setRescorer( + RescoreBuilder.queryRescorer(QueryBuilders.matchPhraseQuery("field1", "the quick brown").slop(2).boost(4.0f)) + .setQueryWeight(0.5f).setRescoreQueryWeight(0.4f)).setRescoreWindow(5).setExplain(true).execute() + .actionGet(); + assertHitCount(searchResponse, 3); + assertFirstHit(searchResponse, hasId("1")); + assertSecondHit(searchResponse, hasId("2")); + assertThirdHit(searchResponse, hasId("3")); + + for (int i = 0; i < 3; i++) { + assertThat(searchResponse.getHits().getAt(i).explanation(), notNullValue()); + assertThat(searchResponse.getHits().getAt(i).explanation().isMatch(), equalTo(true)); + assertThat(searchResponse.getHits().getAt(i).explanation().getDetails().length, equalTo(2)); + assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].isMatch(), equalTo(true)); + if (i == 2) { + assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getValue(), equalTo(0.5f)); + } else { + assertThat(searchResponse.getHits().getAt(i).explanation().getDescription(), equalTo("sum of:")); + assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[0].getDetails()[1].getValue(), equalTo(0.5f)); + assertThat(searchResponse.getHits().getAt(i).explanation().getDetails()[1].getDetails()[1].getValue(), equalTo(0.4f)); + } + } + } + + String[] scoreModes = new String[]{ "max", "min", "avg", "total", "multiply", "" }; + String[] descriptionModes = new String[]{ "max of:", "min of:", "avg of:", "sum of:", "product of:", "sum of:" }; + for (int i = 0; i < scoreModes.length; i++) { + QueryRescorer rescoreQuery = RescoreBuilder.queryRescorer(QueryBuilders.matchQuery("field1", "the quick brown").boost(4.0f)) + .setQueryWeight(0.5f).setRescoreQueryWeight(0.4f); + + if (!"".equals(scoreModes[i])) { + rescoreQuery.setScoreMode(scoreModes[i]); + } + + SearchResponse searchResponse = client() + .prepareSearch() + .setSearchType(SearchType.DFS_QUERY_THEN_FETCH) + .setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(MatchQueryBuilder.Operator.OR)) + .setRescorer(rescoreQuery).setRescoreWindow(5).setExplain(true).execute() + .actionGet(); + assertHitCount(searchResponse, 3); + assertFirstHit(searchResponse, hasId("1")); + assertSecondHit(searchResponse, hasId("2")); + assertThirdHit(searchResponse, hasId("3")); + + for (int j = 0; j < 3; j++) { + assertThat(searchResponse.getHits().getAt(j).explanation().getDescription(), equalTo(descriptionModes[i])); + } + } + } + + @Test + public void testScoring() throws Exception { + client().admin() + .indices() + .prepareCreate("test") + .addMapping( + "type1", + jsonBuilder().startObject().startObject("type1").startObject("properties").startObject("field1") + .field("index", "not_analyzed").field("type", "string").endObject().endObject().endObject().endObject()) + .setSettings(ImmutableSettings.settingsBuilder()).execute().actionGet(); + ensureGreen(); + int numDocs = 1000; + + for (int i = 0; i < numDocs; i++) { + client().prepareIndex("test", "type1", String.valueOf(i)).setSource("field1", English.intToEnglish(i)).execute().actionGet(); + } + + flush(); + optimize(); // make sure we don't have a background merge running + refresh(); + ensureGreen(); + + String[] scoreModes = new String[]{ "max", "min", "avg", "total", "multiply", "" }; + float primaryWeight = 1.1f; + float secondaryWeight = 1.6f; + + for (String scoreMode: scoreModes) { + for (int i = 0; i < numDocs - 4; i++) { + String[] intToEnglish = new String[] { English.intToEnglish(i), English.intToEnglish(i + 1), English.intToEnglish(i + 2), English.intToEnglish(i + 3) }; + + QueryRescorer rescoreQuery = RescoreBuilder + .queryRescorer( + QueryBuilders.boolQuery() + .disableCoord(true) + .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).script("5.0f")) + .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).script("7.0f")) + .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).script("0.0f"))) + .setQueryWeight(primaryWeight) + .setRescoreQueryWeight(secondaryWeight); + + if (!"".equals(scoreMode)) { + rescoreQuery.setScoreMode(scoreMode); + } + + SearchResponse rescored = client() + .prepareSearch() + .setPreference("test") // ensure we hit the same shards for tie-breaking + .setQuery(QueryBuilders.boolQuery() + .disableCoord(true) + .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).script("2.0f")) + .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).script("3.0f")) + .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[2])).script("5.0f")) + .should(QueryBuilders.customScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).script("0.2f"))) + .setFrom(0) + .setSize(10) + .setRescorer(rescoreQuery) + .setRescoreWindow(50).execute().actionGet(); + + assertHitCount(rescored, 4); + + if ("total".equals(scoreMode) || "".equals(scoreMode)) { + assertFirstHit(rescored, hasId(String.valueOf(i + 1))); + assertSecondHit(rescored, hasId(String.valueOf(i))); + assertThirdHit(rescored, hasId(String.valueOf(i + 2))); + assertThat(rescored.getHits().getHits()[0].getScore(), equalTo(3.0f * primaryWeight + 7.0f * secondaryWeight)); + assertThat(rescored.getHits().getHits()[1].getScore(), equalTo(2.0f * primaryWeight + 5.0f * secondaryWeight)); + assertThat(rescored.getHits().getHits()[2].getScore(), equalTo(5.0f * primaryWeight)); + assertThat(rescored.getHits().getHits()[3].getScore(), equalTo(0.2f * primaryWeight + 0.0f * secondaryWeight)); + } else if ("max".equals(scoreMode)) { + assertFirstHit(rescored, hasId(String.valueOf(i + 1))); + assertSecondHit(rescored, hasId(String.valueOf(i))); + assertThirdHit(rescored, hasId(String.valueOf(i + 2))); + assertThat(rescored.getHits().getHits()[0].getScore(), equalTo(7.0f * secondaryWeight)); + assertThat(rescored.getHits().getHits()[1].getScore(), equalTo(5.0f * secondaryWeight)); + assertThat(rescored.getHits().getHits()[2].getScore(), equalTo(5.0f * primaryWeight)); + assertThat(rescored.getHits().getHits()[3].getScore(), equalTo(0.2f * primaryWeight)); + } else if ("min".equals(scoreMode)) { + assertFirstHit(rescored, hasId(String.valueOf(i + 2))); + assertSecondHit(rescored, hasId(String.valueOf(i + 1))); + assertThirdHit(rescored, hasId(String.valueOf(i))); + assertThat(rescored.getHits().getHits()[0].getScore(), equalTo(5.0f * primaryWeight)); + assertThat(rescored.getHits().getHits()[1].getScore(), equalTo(3.0f * primaryWeight)); + assertThat(rescored.getHits().getHits()[2].getScore(), equalTo(2.0f * primaryWeight)); + assertThat(rescored.getHits().getHits()[3].getScore(), equalTo(0.0f * secondaryWeight)); + } else if ("avg".equals(scoreMode)) { + assertFirstHit(rescored, hasId(String.valueOf(i + 1))); + assertSecondHit(rescored, hasId(String.valueOf(i + 2))); + assertThirdHit(rescored, hasId(String.valueOf(i))); + assertThat(rescored.getHits().getHits()[0].getScore(), equalTo((3.0f * primaryWeight + 7.0f * secondaryWeight) / 2.0f)); + assertThat(rescored.getHits().getHits()[1].getScore(), equalTo(5.0f * primaryWeight)); + assertThat(rescored.getHits().getHits()[2].getScore(), equalTo((2.0f * primaryWeight + 5.0f * secondaryWeight) / 2.0f)); + assertThat(rescored.getHits().getHits()[3].getScore(), equalTo((0.2f * primaryWeight) / 2.0f)); + } else if ("multiply".equals(scoreMode)) { + assertFirstHit(rescored, hasId(String.valueOf(i + 1))); + assertSecondHit(rescored, hasId(String.valueOf(i))); + assertThirdHit(rescored, hasId(String.valueOf(i + 2))); + assertThat(rescored.getHits().getHits()[0].getScore(), equalTo(3.0f * primaryWeight * 7.0f * secondaryWeight)); + assertThat(rescored.getHits().getHits()[1].getScore(), equalTo(2.0f * primaryWeight * 5.0f * secondaryWeight)); + assertThat(rescored.getHits().getHits()[2].getScore(), equalTo(5.0f * primaryWeight)); + assertThat(rescored.getHits().getHits()[3].getScore(), equalTo(0.2f * primaryWeight * 0.0f * secondaryWeight)); + } } } }