From e5f2ce0fd644412298b5ef35ebcdcc074b23795c Mon Sep 17 00:00:00 2001 From: Shay Banon Date: Wed, 4 Jan 2012 21:53:26 +0200 Subject: [PATCH] use factor in scripts, so custom score function will work correctly when it multiplies --- .../search/function/BoostScoreFunction.java | 17 ++++++- .../function/FiltersFunctionScoreQuery.java | 51 ++++++++++--------- .../search/function/FunctionScoreQuery.java | 4 +- .../lucene/search/function/ScoreFunction.java | 6 ++- .../index/query/CustomScoreQueryParser.java | 14 ++++- .../customscore/CustomScoreSearchTests.java | 4 +- 6 files changed, 63 insertions(+), 33 deletions(-) diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java index d4fd30c3cf4..99b2bcbdbed 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java @@ -49,13 +49,23 @@ public class BoostScoreFunction implements ScoreFunction { } @Override - public Explanation explain(int docId, Explanation subQueryExpl) { + public float factor(int docId) { + return boost; + } + + @Override + public Explanation explainScore(int docId, Explanation subQueryExpl) { Explanation exp = new Explanation(boost * subQueryExpl.getValue(), "static boost function: product of:"); exp.addDetail(subQueryExpl); exp.addDetail(new Explanation(boost, "boostFactor")); return exp; } + @Override + public Explanation explainFactor(int docId) { + return new Explanation(boost, "boostFactor"); + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -72,4 +82,9 @@ public class BoostScoreFunction implements ScoreFunction { public int hashCode() { return (boost != +0.0f ? Float.floatToIntBits(boost) : 0); } + + @Override + public String toString() { + return "boost[" + boost + "]"; + } } diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java b/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java index 2c682f556c0..43263a466a7 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java @@ -34,8 +34,6 @@ import java.util.Set; /** * A query that allows for a pluggable boost function / filter. If it matches the filter, it will * be boosted by the formula. - * - * */ public class FiltersFunctionScoreQuery extends Query { @@ -166,7 +164,7 @@ public class FiltersFunctionScoreQuery extends Query { DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader)); if (docSet.get(doc)) { filterFunction.function.setNextReader(reader); - Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl); + Explanation functionExplanation = filterFunction.function.explainFactor(doc); float sc = getValue() * functionExplanation.getValue(); Explanation res = new ComplexExplanation(true, sc, "custom score, product of:"); res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); @@ -186,7 +184,7 @@ public class FiltersFunctionScoreQuery extends Query { DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader)); if (docSet.get(doc)) { filterFunction.function.setNextReader(reader); - Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl); + Explanation functionExplanation = filterFunction.function.explainFactor(doc); float sc = functionExplanation.getValue(); count++; total += sc; @@ -221,6 +219,7 @@ public class FiltersFunctionScoreQuery extends Query { } sc *= getValue(); Explanation res = new ComplexExplanation(true, sc, "custom score, score mode [" + scoreMode.toString().toLowerCase() + "]"); + res.addDetail(subQueryExpl); for (Explanation explanation : filtersExplanations) { res.addDetail(explanation); } @@ -272,56 +271,58 @@ public class FiltersFunctionScoreQuery extends Query { @Override public float score() throws IOException { int docId = scorer.docID(); - float score = scorer.score(); + float factor = 1.0f; if (scoreMode == ScoreMode.First) { for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - return subQueryWeight * filterFunctions[i].function.score(docId, score); + factor = filterFunctions[i].function.factor(docId); + break; } } } else if (scoreMode == ScoreMode.Max) { - float maxScore = Float.NEGATIVE_INFINITY; + float maxFactor = Float.NEGATIVE_INFINITY; for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - maxScore = Math.max(filterFunctions[i].function.score(docId, score), maxScore); + maxFactor = Math.max(filterFunctions[i].function.factor(docId), maxFactor); } } - if (maxScore != Float.NEGATIVE_INFINITY) { - score = maxScore; + if (maxFactor != Float.NEGATIVE_INFINITY) { + factor = maxFactor; } } else if (scoreMode == ScoreMode.Min) { - float minScore = Float.POSITIVE_INFINITY; + float minFactor = Float.POSITIVE_INFINITY; for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - minScore = Math.min(filterFunctions[i].function.score(docId, score), minScore); + minFactor = Math.min(filterFunctions[i].function.factor(docId), minFactor); } } - if (minScore != Float.POSITIVE_INFINITY) { - score = minScore; + if (minFactor != Float.POSITIVE_INFINITY) { + factor = minFactor; + } + } else if (scoreMode == ScoreMode.Multiply) { + for (int i = 0; i < filterFunctions.length; i++) { + if (docSets[i].get(docId)) { + factor *= filterFunctions[i].function.factor(docId); + } } } else { // Avg / Total - float totalScore = 0.0f; - float multiplicativeScore = 1.0f; + float totalFactor = 0.0f; int count = 0; for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - float tempScore = filterFunctions[i].function.score(docId, score); - totalScore += tempScore; - multiplicativeScore *= tempScore; + totalFactor += filterFunctions[i].function.factor(docId); count++; } } if (count != 0) { - score = totalScore; + factor = totalFactor; if (scoreMode == ScoreMode.Avg) { - score /= count; - } - else if (scoreMode == ScoreMode.Multiply) { - score = multiplicativeScore; + factor /= count; } } } - return subQueryWeight * score; + float score = scorer.score(); + return subQueryWeight * score * factor; } } diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java b/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java index 11e2e055ef4..02b59404899 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java @@ -29,8 +29,6 @@ import java.util.Set; /** * A query that allows for a pluggable boost function to be applied to it. - * - * */ public class FunctionScoreQuery extends Query { @@ -117,7 +115,7 @@ public class FunctionScoreQuery extends Query { } function.setNextReader(reader); - Explanation functionExplanation = function.explain(doc, subQueryExpl); + Explanation functionExplanation = function.explainScore(doc, subQueryExpl); float sc = getValue() * functionExplanation.getValue(); Explanation res = new ComplexExplanation(true, sc, "custom score, product of:"); res.addDetail(functionExplanation); diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java index f5dd7437766..b54f140fbfa 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java @@ -31,5 +31,9 @@ public interface ScoreFunction { float score(int docId, float subQueryScore); - Explanation explain(int docId, Explanation subQueryExpl); + float factor(int docId); + + Explanation explainScore(int docId, Explanation subQueryExpl); + + Explanation explainFactor(int docId); } diff --git a/src/main/java/org/elasticsearch/index/query/CustomScoreQueryParser.java b/src/main/java/org/elasticsearch/index/query/CustomScoreQueryParser.java index 3741c71afe8..e866f9e8615 100644 --- a/src/main/java/org/elasticsearch/index/query/CustomScoreQueryParser.java +++ b/src/main/java/org/elasticsearch/index/query/CustomScoreQueryParser.java @@ -125,13 +125,25 @@ public class CustomScoreQueryParser implements QueryParser { } @Override - public Explanation explain(int docId, Explanation subQueryExpl) { + public float factor(int docId) { + // just the factor, so don't provide _score + script.setNextDocId(docId); + return script.runAsFloat(); + } + + @Override + public Explanation explainScore(int docId, Explanation subQueryExpl) { float score = score(docId, subQueryExpl.getValue()); Explanation exp = new Explanation(score, "script score function: product of:"); exp.addDetail(subQueryExpl); return exp; } + @Override + public Explanation explainFactor(int docId) { + return new Explanation(factor(docId), "scriptFactor"); + } + @Override public String toString() { return "script[" + sScript + "], params [" + params + "]"; diff --git a/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java b/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java index 33b9a4b840d..ea9239a8b96 100644 --- a/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java +++ b/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java @@ -164,8 +164,8 @@ public class CustomScoreSearchTests extends AbstractNodesTests { SearchResponse searchResponse = client.prepareSearch("test") .setQuery(customFiltersScoreQuery(matchAllQuery()) - .add(termFilter("field", "value4"), "_score * 2") - .add(termFilter("field", "value2"), "_score * 3")) + .add(termFilter("field", "value4"), "2") + .add(termFilter("field", "value2"), "3")) .setExplain(true) .execute().actionGet();