From d93bc02309f2523e8ba39b3883d98d83a9eee3f8 Mon Sep 17 00:00:00 2001 From: Shay Banon Date: Thu, 4 Aug 2011 03:31:14 +0300 Subject: [PATCH] Query DSL: custom_filters_score - add score_mode to control filters matching scoring, closes #1205. --- .../function/FiltersFunctionScoreQuery.java | 109 +++++++++++++++--- .../query/CustomFiltersScoreQueryBuilder.java | 11 ++ .../query/CustomFiltersScoreQueryParser.java | 16 ++- .../customscore/CustomScoreSearchTests.java | 53 ++++++++- 4 files changed, 168 insertions(+), 21 deletions(-) diff --git a/modules/elasticsearch/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java b/modules/elasticsearch/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java index 1d313c41b0f..61adb6a8e22 100644 --- a/modules/elasticsearch/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java +++ b/modules/elasticsearch/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java @@ -34,6 +34,7 @@ import org.elasticsearch.common.lucene.docset.DocSet; import org.elasticsearch.common.lucene.docset.DocSets; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Set; @@ -73,12 +74,16 @@ public class FiltersFunctionScoreQuery extends Query { } } + public static enum ScoreMode {First, Avg, Max, Total} + Query subQuery; final FilterFunction[] filterFunctions; + final ScoreMode scoreMode; DocSet[] docSets; - public FiltersFunctionScoreQuery(Query subQuery, FilterFunction[] filterFunctions) { + public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions) { this.subQuery = subQuery; + this.scoreMode = scoreMode; this.filterFunctions = filterFunctions; this.docSets = new DocSet[filterFunctions.length]; } @@ -151,7 +156,7 @@ public class FiltersFunctionScoreQuery extends Query { filterFunction.function.setNextReader(reader); docSets[i] = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader)); } - return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, filterFunctions, docSets); + return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, scoreMode, filterFunctions, docSets); } @Override @@ -161,16 +166,59 @@ public class FiltersFunctionScoreQuery extends Query { return subQueryExpl; } - for (FilterFunction filterFunction : filterFunctions) { - DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader)); - if (docSet.get(doc)) { - filterFunction.function.setNextReader(reader); - Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl); - float sc = getValue() * functionExplanation.getValue(); - Explanation res = new ComplexExplanation(true, sc, "custom score, product of:"); - res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); - res.addDetail(functionExplanation); - res.addDetail(new Explanation(getValue(), "queryBoost")); + if (scoreMode == ScoreMode.First) { + for (FilterFunction filterFunction : filterFunctions) { + DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader)); + if (docSet.get(doc)) { + filterFunction.function.setNextReader(reader); + Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl); + float sc = getValue() * functionExplanation.getValue(); + Explanation res = new ComplexExplanation(true, sc, "custom score, product of:"); + res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); + res.addDetail(functionExplanation); + res.addDetail(new Explanation(getValue(), "queryBoost")); + return res; + } + } + } else { + int count = 0; + float total = 0; + float max = Float.NEGATIVE_INFINITY; + ArrayList filtersExplanations = new ArrayList(); + for (FilterFunction filterFunction : filterFunctions) { + DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader)); + if (docSet.get(doc)) { + filterFunction.function.setNextReader(reader); + Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl); + float sc = functionExplanation.getValue(); + count++; + total += sc; + max = Math.max(sc, max); + Explanation res = new ComplexExplanation(true, sc, "custom score, product of:"); + res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); + res.addDetail(functionExplanation); + res.addDetail(new Explanation(getValue(), "queryBoost")); + filtersExplanations.add(res); + } + } + if (count > 0) { + float sc = 0; + switch (scoreMode) { + case Avg: + sc = total / count; + break; + case Max: + sc = max; + break; + case Total: + sc = total; + break; + } + sc *= getValue(); + Explanation res = new ComplexExplanation(true, sc, "custom score, score mode [" + scoreMode.toString().toLowerCase() + "]"); + for (Explanation explanation : filtersExplanations) { + res.addDetail(explanation); + } return res; } } @@ -188,13 +236,15 @@ public class FiltersFunctionScoreQuery extends Query { private final float subQueryWeight; private final Scorer scorer; private final FilterFunction[] filterFunctions; + private final ScoreMode scoreMode; private final DocSet[] docSets; private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer, - FilterFunction[] filterFunctions, DocSet[] docSets) throws IOException { + ScoreMode scoreMode, FilterFunction[] filterFunctions, DocSet[] docSets) throws IOException { super(similarity); this.subQueryWeight = w.getValue(); this.scorer = scorer; + this.scoreMode = scoreMode; this.filterFunctions = filterFunctions; this.docSets = docSets; } @@ -218,9 +268,36 @@ public class FiltersFunctionScoreQuery extends Query { public float score() throws IOException { int docId = scorer.docID(); float score = scorer.score(); - for (int i = 0; i < filterFunctions.length; i++) { - if (docSets[i].get(docId)) { - return subQueryWeight * filterFunctions[i].function.score(docId, score); + if (scoreMode == ScoreMode.First) { + for (int i = 0; i < filterFunctions.length; i++) { + if (docSets[i].get(docId)) { + return subQueryWeight * filterFunctions[i].function.score(docId, score); + } + } + } else if (scoreMode == ScoreMode.Max) { + float maxScore = Float.NEGATIVE_INFINITY; + for (int i = 0; i < filterFunctions.length; i++) { + if (docSets[i].get(docId)) { + maxScore = Math.max(filterFunctions[i].function.score(docId, score), maxScore); + } + } + if (maxScore != Float.NEGATIVE_INFINITY) { + score = maxScore; + } + } else { // Avg / Total + float totalScore = 0.0f; + int count = 0; + for (int i = 0; i < filterFunctions.length; i++) { + if (docSets[i].get(docId)) { + totalScore += filterFunctions[i].function.score(docId, score); + count++; + } + } + if (count != 0) { + score = totalScore; + if (scoreMode == ScoreMode.Avg) { + score /= count; + } } } return subQueryWeight * score; diff --git a/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryBuilder.java b/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryBuilder.java index 798e0e59a06..c4bb3fdda0c 100644 --- a/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryBuilder.java +++ b/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryBuilder.java @@ -42,6 +42,8 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder { private Map params = null; + private String scoreMode; + private ArrayList filters = new ArrayList(); private ArrayList scripts = new ArrayList(); private TFloatArrayList boosts = new TFloatArrayList(); @@ -64,6 +66,11 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder { return this; } + public CustomFiltersScoreQueryBuilder scoreMode(String scoreMode) { + this.scoreMode = scoreMode; + return this; + } + /** * Sets the language of the script. */ @@ -124,6 +131,10 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder { } builder.endArray(); + if (scoreMode != null) { + builder.field("score_mode", scoreMode); + } + if (lang != null) { builder.field("lang", lang); } diff --git a/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java b/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java index c4d59e1dbbe..b32d5847a2e 100644 --- a/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java +++ b/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java @@ -58,6 +58,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser { String scriptLang = null; Map vars = null; + FiltersFunctionScoreQuery.ScoreMode scoreMode = FiltersFunctionScoreQuery.ScoreMode.First; ArrayList filters = new ArrayList(); ArrayList scripts = new ArrayList(); TFloatArrayList boosts = new TFloatArrayList(); @@ -110,6 +111,19 @@ public class CustomFiltersScoreQueryParser implements QueryParser { scriptLang = parser.text(); } else if ("boost".equals(currentFieldName)) { boost = parser.floatValue(); + } else if ("score_mode".equals(currentFieldName) || "scoreMode".equals(currentFieldName)) { + String sScoreMode = parser.text(); + if ("avg".equals(sScoreMode)) { + scoreMode = FiltersFunctionScoreQuery.ScoreMode.Avg; + } else if ("max".equals(sScoreMode)) { + scoreMode = FiltersFunctionScoreQuery.ScoreMode.Max; + } else if ("total".equals(sScoreMode)) { + scoreMode = FiltersFunctionScoreQuery.ScoreMode.Total; + } else if ("first".equals(sScoreMode)) { + scoreMode = FiltersFunctionScoreQuery.ScoreMode.First; + } else { + throw new QueryParsingException(parseContext.index(), "illegal score_mode for nested query [" + sScoreMode + "]"); + } } } } @@ -136,7 +150,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser { } filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), scoreFunction); } - FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, filterFunctions); + FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, scoreMode, filterFunctions); functionScoreQuery.setBoost(boost); return functionScoreQuery; } diff --git a/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java b/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java index 7439beea4a9..fbfd96e4e2e 100644 --- a/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java +++ b/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java @@ -148,10 +148,10 @@ public class CustomScoreSearchTests extends AbstractNodesTests { @Test public void testCustomFiltersScore() throws Exception { client.admin().indices().prepareDelete().execute().actionGet(); - client.prepareIndex("test", "type", "1").setSource("field", "value1").execute().actionGet(); - client.prepareIndex("test", "type", "2").setSource("field", "value2").execute().actionGet(); - client.prepareIndex("test", "type", "3").setSource("field", "value3").execute().actionGet(); - client.prepareIndex("test", "type", "4").setSource("field", "value4").execute().actionGet(); + client.prepareIndex("test", "type", "1").setSource("field", "value1", "color", "red").execute().actionGet(); + client.prepareIndex("test", "type", "2").setSource("field", "value2", "color", "blue").execute().actionGet(); + client.prepareIndex("test", "type", "3").setSource("field", "value3", "color", "red").execute().actionGet(); + client.prepareIndex("test", "type", "4").setSource("field", "value4", "color", "blue").execute().actionGet(); client.admin().indices().prepareRefresh().execute().actionGet(); @@ -194,5 +194,50 @@ public class CustomScoreSearchTests extends AbstractNodesTests { assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f)); assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3"))); assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f)); + + searchResponse = client.prepareSearch("test") + .setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("total") + .add(termFilter("field", "value4"), 2) + .add(termFilter("field", "value1"), 3) + .add(termFilter("color", "red"), 5)) + .setExplain(true) + .execute().actionGet(); + + assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0)); + assertThat(searchResponse.hits().totalHits(), equalTo(4l)); + assertThat(searchResponse.hits().getAt(0).id(), equalTo("1")); + assertThat(searchResponse.hits().getAt(0).score(), equalTo(8.0f)); + logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation()); + + searchResponse = client.prepareSearch("test") + .setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("max") + .add(termFilter("field", "value4"), 2) + .add(termFilter("field", "value1"), 3) + .add(termFilter("color", "red"), 5)) + .setExplain(true) + .execute().actionGet(); + + assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0)); + assertThat(searchResponse.hits().totalHits(), equalTo(4l)); + assertThat(searchResponse.hits().getAt(0).id(), equalTo("1")); + assertThat(searchResponse.hits().getAt(0).score(), equalTo(5.0f)); + logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation()); + + searchResponse = client.prepareSearch("test") + .setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("avg") + .add(termFilter("field", "value4"), 2) + .add(termFilter("field", "value1"), 3) + .add(termFilter("color", "red"), 5)) + .setExplain(true) + .execute().actionGet(); + + assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0)); + assertThat(searchResponse.hits().totalHits(), equalTo(4l)); + assertThat(searchResponse.hits().getAt(0).id(), equalTo("3")); + assertThat(searchResponse.hits().getAt(0).score(), equalTo(5.0f)); + logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation()); + assertThat(searchResponse.hits().getAt(1).id(), equalTo("1")); + assertThat(searchResponse.hits().getAt(1).score(), equalTo(4.0f)); + logger.info("--> Hit[1] {} Explanation {}", searchResponse.hits().getAt(1).id(), searchResponse.hits().getAt(1).explanation()); } } \ No newline at end of file