From b007af1f464f49849705356a2eea212ed68643ac Mon Sep 17 00:00:00 2001 From: Britta Weber Date: Wed, 7 Aug 2013 17:39:51 +0200 Subject: [PATCH] Fix inconsistent usage of ScriptScoreFunction in FiltersFunctionScoreQuery This commit fixes inconsistencies in `function_score` and `filters_function_score` using scripts, see issue #3464 The method 'ScoreFunction.factor(docId)' is removed completely, since the name suggests that this method actually computes a factor which was not the case. Multiplying the computed score is now handled by 'FiltersFunctionScoreQuery' and 'FunctionScoreQuery' and not implicitely performed in 'ScoreFunction.factor(docId, subQueryScore)' as was the case for 'BoostScoreFunction' and 'DecayScoreFunctions'. This commit also fixes the explain function for FiltersFunctionScoreQuery. Here, the influence of the maxBoost was never printed. Furthermore, the queryBoost was printed as beeing multiplied to the filter score. Closes #3464 --- .../search/function/BoostScoreFunction.java | 18 +- .../search/function/CombineFunction.java | 90 +++++++ .../function/FiltersFunctionScoreQuery.java | 220 +++++++++--------- .../search/function/FunctionScoreQuery.java | 38 ++- .../search/function/RandomScoreFunction.java | 16 +- .../lucene/search/function/ScoreFunction.java | 19 +- .../search/function/ScriptScoreFunction.java | 19 +- .../functionscore/DecayFunctionParser.java | 23 +- .../customscore/CustomScoreSearchTests.java | 37 ++- 9 files changed, 257 insertions(+), 223 deletions(-) create mode 100644 src/main/java/org/elasticsearch/common/lucene/search/function/CombineFunction.java diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java index c62605ec1f5..c06c50c96d2 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java @@ -25,11 +25,12 @@ import org.apache.lucene.search.Explanation; /** * */ -public class BoostScoreFunction implements ScoreFunction { +public class BoostScoreFunction extends ScoreFunction { private final float boost; public BoostScoreFunction(float boost) { + super(CombineFunction.MULT); this.boost = boost; } @@ -41,30 +42,19 @@ public class BoostScoreFunction implements ScoreFunction { public void setNextReader(AtomicReaderContext context) { // nothing to do here... } - + @Override public double score(int docId, float subQueryScore) { - return subQueryScore * boost; - } - - @Override - public double factor(int docId) { return boost; } @Override public Explanation explainScore(int docId, Explanation subQueryExpl) { - Explanation exp = new Explanation(boost * subQueryExpl.getValue(), "static boost function: product of:"); - exp.addDetail(subQueryExpl); + Explanation exp = new Explanation(boost, "static boost factor"); exp.addDetail(new Explanation(boost, "boostFactor")); return exp; } - @Override - public Explanation explainFactor(int docId) { - return new Explanation(boost, "boostFactor"); - } - @Override public boolean equals(Object o) { if (this == o) diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/CombineFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/CombineFunction.java new file mode 100644 index 00000000000..ec93746c879 --- /dev/null +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/CombineFunction.java @@ -0,0 +1,90 @@ +/* + * Licensed to ElasticSearch and Shay Banon under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. ElasticSearch licenses this + * file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.common.lucene.search.function; + +import org.apache.lucene.search.ComplexExplanation; +import org.apache.lucene.search.Explanation; + +public enum CombineFunction { + MULT { + @Override + public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) { + return toFloat(queryBoost * queryScore * Math.min(funcScore, maxBoost)); + } + + @Override + public String getName() { + return "mult"; + } + + @Override + public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) { + float score = queryBoost * Math.min(funcExpl.getValue(), maxBoost) * queryExpl.getValue(); + ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:"); + res.addDetail(queryExpl); + ComplexExplanation minExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of"); + minExpl.addDetail(funcExpl); + minExpl.addDetail(new Explanation(maxBoost, "maxBoost")); + res.addDetail(minExpl); + res.addDetail(new Explanation(queryBoost, "queryBoost")); + return res; + } + }, + PLAIN { + @Override + public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) { + return toFloat(queryBoost * Math.min(funcScore, maxBoost)); + } + + @Override + public String getName() { + return "plain"; + } + + @Override + public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) { + float score = queryBoost * Math.min(funcExpl.getValue(), maxBoost); + ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:"); + ComplexExplanation minExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of"); + minExpl.addDetail(funcExpl); + minExpl.addDetail(new Explanation(maxBoost, "maxBoost")); + res.addDetail(minExpl); + res.addDetail(new Explanation(queryBoost, "queryBoost")); + return res; + } + + }; + + public abstract float combine(double queryBoost, double queryScore, double funcScore, double maxBoost); + + public abstract String getName(); + + public static float toFloat(double input) { + assert deviation(input) <= 0.001 : "input " + input + " out of float scope for function score deviation: " + deviation(input); + return (float) input; + } + + private static double deviation(double input) { // only with assert! + float floatVersion = (float)input; + return Double.compare(floatVersion, input) == 0 || input == 0.0d ? 0 : 1.d-(floatVersion) / input; + } + + public abstract ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost); +} diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java b/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java index 708c3aa0b8b..b4ba7fb9c4b 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java @@ -28,14 +28,11 @@ import org.apache.lucene.util.ToStringUtils; import org.elasticsearch.common.lucene.docset.DocIdSets; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Locale; -import java.util.Set; +import java.util.*; /** - * A query that allows for a pluggable boost function / filter. If it matches the filter, it will - * be boosted by the formula. + * A query that allows for a pluggable boost function / filter. If it matches + * the filter, it will be boosted by the formula. */ public class FiltersFunctionScoreQuery extends Query { @@ -50,13 +47,17 @@ public class FiltersFunctionScoreQuery extends Query { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; FilterFunction that = (FilterFunction) o; - if (filter != null ? !filter.equals(that.filter) : that.filter != null) return false; - if (function != null ? !function.equals(that.function) : that.function != null) return false; + if (filter != null ? !filter.equals(that.filter) : that.filter != null) + return false; + if (function != null ? !function.equals(that.function) : that.function != null) + return false; return true; } @@ -69,20 +70,29 @@ public class FiltersFunctionScoreQuery extends Query { } } - public static enum ScoreMode {First, Avg, Max, Total, Min, Multiply} + public static enum ScoreMode { + First, Avg, Max, Total, Min, Multiply + } Query subQuery; final FilterFunction[] filterFunctions; final ScoreMode scoreMode; final float maxBoost; + protected CombineFunction combineFunction; + public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions, float maxBoost) { this.subQuery = subQuery; this.scoreMode = scoreMode; this.filterFunctions = filterFunctions; this.maxBoost = maxBoost; + combineFunction = CombineFunction.MULT; } + public FiltersFunctionScoreQuery setCombineFunction(CombineFunction combineFunction){ + this.combineFunction = combineFunction; + return this; + } public Query getSubQuery() { return subQuery; } @@ -94,7 +104,8 @@ public class FiltersFunctionScoreQuery extends Query { @Override public Query rewrite(IndexReader reader) throws IOException { Query newQ = subQuery.rewrite(reader); - if (newQ == subQuery) return this; + if (newQ == subQuery) + return this; FiltersFunctionScoreQuery bq = (FiltersFunctionScoreQuery) this.clone(); bq.subQuery = newQ; return bq; @@ -148,107 +159,88 @@ public class FiltersFunctionScoreQuery extends Query { filterFunction.function.setNextReader(context); docSets[i] = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, acceptDocs)); } - return new CustomBoostFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, docSets); + return new CustomBoostFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, docSets, combineFunction); } @Override public Explanation explain(AtomicReaderContext context, int doc) throws IOException { + Explanation subQueryExpl = subQueryWeight.explain(context, doc); if (!subQueryExpl.isMatch()) { return subQueryExpl; } - - if (scoreMode == ScoreMode.First) { - for (FilterFunction filterFunction : filterFunctions) { - Bits docSet = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, context.reader().getLiveDocs())); - if (docSet.get(doc)) { - filterFunction.function.setNextReader(context); - Explanation functionExplanation = filterFunction.function.explainFactor(doc); - double factor = functionExplanation.getValue(); - if (factor > maxBoost) { - factor = maxBoost; - } - float sc = FunctionScoreQuery.toFloat(getBoost() * factor); - Explanation filterExplanation = new ComplexExplanation(true, sc, "function score, product of:"); - filterExplanation.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); - filterExplanation.addDetail(functionExplanation); - filterExplanation.addDetail(new Explanation(getBoost(), "queryBoost")); - - // top level score = subquery.score * filter.score (this already has the query boost) - float topLevelScore = subQueryExpl.getValue() * sc; - Explanation topLevel = new ComplexExplanation(true, topLevelScore, "function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]"); - topLevel.addDetail(subQueryExpl); - topLevel.addDetail(filterExplanation); - return topLevel; - } - } - } else { - int count = 0; - float total = 0; - float multiply = 1; - double max = Double.NEGATIVE_INFINITY; - double min = Double.POSITIVE_INFINITY; - ArrayList filtersExplanations = new ArrayList(); - for (FilterFunction filterFunction : filterFunctions) { - Bits docSet = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, context.reader().getLiveDocs())); - if (docSet.get(doc)) { - filterFunction.function.setNextReader(context); - Explanation functionExplanation = filterFunction.function.explainFactor(doc); - double factor = functionExplanation.getValue(); - count++; - total += factor; - multiply *= factor; - max = Math.max(factor, max); - min = Math.min(factor, min); - Explanation res = new ComplexExplanation(true, FunctionScoreQuery.toFloat(factor), "function score, product of:"); - res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); - res.addDetail(functionExplanation); - res.addDetail(new Explanation(getBoost(), "queryBoost")); - filtersExplanations.add(res); - } - } - if (count > 0) { - double factor = 0; - switch (scoreMode) { - case Avg: - factor = total / count; - break; - case Max: - factor = max; - break; - case Min: - factor = min; - break; - case Total: - factor = total; - break; - case Multiply: - factor = multiply; - break; - } - - if (factor > maxBoost) { - factor = maxBoost; - } - float sc = FunctionScoreQuery.toFloat(factor * subQueryExpl.getValue() * getBoost()); - Explanation res = new ComplexExplanation(true, sc, "function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]"); - res.addDetail(subQueryExpl); - for (Explanation explanation : filtersExplanations) { - res.addDetail(explanation); - } - return res; + // First: Gather explanations for all filters + List filterExplanations = new ArrayList(); + for (FilterFunction filterFunction : filterFunctions) { + Bits docSet = DocIdSets.toSafeBits(context.reader(), + filterFunction.filter.getDocIdSet(context, context.reader().getLiveDocs())); + if (docSet.get(doc)) { + filterFunction.function.setNextReader(context); + Explanation functionExplanation = filterFunction.function.explainScore(doc, subQueryExpl); + double factor = functionExplanation.getValue(); + float sc = CombineFunction.toFloat(factor); + ComplexExplanation filterExplanation = new ComplexExplanation(true, sc, "function score, product of:"); + filterExplanation.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); + filterExplanation.addDetail(functionExplanation); + filterExplanations.add(filterExplanation); } } + if (filterExplanations.size() == 0) { + float sc = getBoost() * subQueryExpl.getValue(); + Explanation res = new ComplexExplanation(true, sc, "function score, no filter match, product of:"); + res.addDetail(subQueryExpl); + res.addDetail(new Explanation(getBoost(), "queryBoost")); + return res; + } - float sc = getBoost() * subQueryExpl.getValue(); - Explanation res = new ComplexExplanation(true, sc, "custom score, no filter match, product of:"); - res.addDetail(subQueryExpl); - res.addDetail(new Explanation(getBoost(), "queryBoost")); - return res; + // Second: Compute the factor that would have been computed by the + // filters + double factor = 1.0; + switch (scoreMode) { + case First: + + factor = filterExplanations.get(0).getValue(); + break; + case Max: + double maxFactor = Double.NEGATIVE_INFINITY; + for (int i = 0; i < filterExplanations.size(); i++) { + factor = Math.max(filterExplanations.get(i).getValue(), maxFactor); + } + break; + case Min: + double minFactor = Double.POSITIVE_INFINITY; + for (int i = 0; i < filterExplanations.size(); i++) { + factor = Math.min(filterExplanations.get(i).getValue(), minFactor); + } + break; + case Multiply: + for (int i = 0; i < filterExplanations.size(); i++) { + factor *= filterExplanations.get(i).getValue(); + } + break; + default: // Avg / Total + double totalFactor = 0.0f; + int count = 0; + for (int i = 0; i < filterExplanations.size(); i++) { + totalFactor += filterExplanations.get(i).getValue(); + count++; + } + if (count != 0) { + factor = totalFactor; + if (scoreMode == ScoreMode.Avg) { + factor /= count; + } + } + } + ComplexExplanation factorExplanaition = new ComplexExplanation(true, CombineFunction.toFloat(factor), + "function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]"); + for (int i = 0; i < filterExplanations.size(); i++) { + factorExplanaition.addDetail(filterExplanations.get(i)); + } + return combineFunction.explain(getBoost(), subQueryExpl, factorExplanaition, maxBoost); } } - static class CustomBoostFactorScorer extends Scorer { private final float subQueryBoost; @@ -257,9 +249,10 @@ public class FiltersFunctionScoreQuery extends Query { private final ScoreMode scoreMode; private final float maxBoost; private final Bits[] docSets; + private final CombineFunction scoreCombiner; - private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreMode scoreMode, - FilterFunction[] filterFunctions, float maxBoost, Bits[] docSets) throws IOException { + private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreMode scoreMode, FilterFunction[] filterFunctions, + float maxBoost, Bits[] docSets, CombineFunction scoreCombiner) throws IOException { super(w); this.subQueryBoost = w.getQuery().getBoost(); this.scorer = scorer; @@ -267,6 +260,7 @@ public class FiltersFunctionScoreQuery extends Query { this.filterFunctions = filterFunctions; this.maxBoost = maxBoost; this.docSets = docSets; + this.scoreCombiner = scoreCombiner; } @Override @@ -288,10 +282,11 @@ public class FiltersFunctionScoreQuery extends Query { public float score() throws IOException { int docId = scorer.docID(); double factor = 1.0f; + float subQueryScore = scorer.score(); if (scoreMode == ScoreMode.First) { for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - factor = filterFunctions[i].function.factor(docId); + factor = filterFunctions[i].function.score(docId, subQueryScore); break; } } @@ -299,7 +294,7 @@ public class FiltersFunctionScoreQuery extends Query { double maxFactor = Double.NEGATIVE_INFINITY; for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - maxFactor = Math.max(filterFunctions[i].function.factor(docId), maxFactor); + maxFactor = Math.max(filterFunctions[i].function.score(docId, subQueryScore), maxFactor); } } if (maxFactor != Float.NEGATIVE_INFINITY) { @@ -309,7 +304,7 @@ public class FiltersFunctionScoreQuery extends Query { double minFactor = Double.POSITIVE_INFINITY; for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - minFactor = Math.min(filterFunctions[i].function.factor(docId), minFactor); + minFactor = Math.min(filterFunctions[i].function.score(docId, subQueryScore), minFactor); } } if (minFactor != Float.POSITIVE_INFINITY) { @@ -318,7 +313,7 @@ public class FiltersFunctionScoreQuery extends Query { } else if (scoreMode == ScoreMode.Multiply) { for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - factor *= filterFunctions[i].function.factor(docId); + factor *= filterFunctions[i].function.score(docId, subQueryScore); } } } else { // Avg / Total @@ -326,7 +321,7 @@ public class FiltersFunctionScoreQuery extends Query { int count = 0; for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - totalFactor += filterFunctions[i].function.factor(docId); + totalFactor += filterFunctions[i].function.score(docId, subQueryScore); count++; } } @@ -337,11 +332,7 @@ public class FiltersFunctionScoreQuery extends Query { } } } - if (factor > maxBoost) { - factor = maxBoost; - } - float score = scorer.score(); - return FunctionScoreQuery.toFloat(subQueryBoost * score * factor); + return scoreCombiner.combine(subQueryBoost, subQueryScore, factor, maxBoost); } @Override @@ -355,10 +346,9 @@ public class FiltersFunctionScoreQuery extends Query { } } - public String toString(String field) { StringBuilder sb = new StringBuilder(); - sb.append("custom score (").append(subQuery.toString(field)).append(", functions: ["); + sb.append("function score (").append(subQuery.toString(field)).append(", functions: ["); for (FilterFunction filterFunction : filterFunctions) { sb.append("{filter(").append(filterFunction.filter).append("), function [").append(filterFunction.function).append("]}"); } @@ -368,7 +358,8 @@ public class FiltersFunctionScoreQuery extends Query { } public boolean equals(Object o) { - if (getClass() != o.getClass()) return false; + if (getClass() != o.getClass()) + return false; FiltersFunctionScoreQuery other = (FiltersFunctionScoreQuery) o; if (this.getBoost() != other.getBoost()) return false; @@ -382,4 +373,3 @@ public class FiltersFunctionScoreQuery extends Query { return subQuery.hashCode() + 31 * Arrays.hashCode(filterFunctions) ^ Float.floatToIntBits(getBoost()); } } - diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java b/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java index b2ef9ca9c0a..251fc80ec54 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java @@ -37,12 +37,18 @@ public class FunctionScoreQuery extends Query { Query subQuery; final ScoreFunction function; float maxBoost = Float.MAX_VALUE; - + CombineFunction combineFunction; + public FunctionScoreQuery(Query subQuery, ScoreFunction function) { this.subQuery = subQuery; this.function = function; + this.combineFunction = function.getDefaultScoreCombiner(); } + public void setCombineFunction(CombineFunction combineFunction) { + this.combineFunction = combineFunction; + } + public void setMaxBoost(float maxBoost) { this.maxBoost = maxBoost; } @@ -112,7 +118,7 @@ public class FunctionScoreQuery extends Query { return null; } function.setNextReader(context); - return new CustomBoostFactorScorer(this, subQueryScorer, function, maxBoost); + return new CustomBoostFactorScorer(this, subQueryScorer, function, maxBoost, combineFunction); } @Override @@ -121,14 +127,9 @@ public class FunctionScoreQuery extends Query { if (!subQueryExpl.isMatch()) { return subQueryExpl; } - function.setNextReader(context); Explanation functionExplanation = function.explainScore(doc, subQueryExpl); - float sc = getBoost() * functionExplanation.getValue(); - Explanation res = new ComplexExplanation(true, sc, "function score, product of:"); - res.addDetail(functionExplanation); - res.addDetail(new Explanation(getBoost(), "queryBoost")); - return res; + return combineFunction.explain(getBoost(), subQueryExpl, functionExplanation, maxBoost); } } @@ -138,14 +139,16 @@ public class FunctionScoreQuery extends Query { private final Scorer scorer; private final ScoreFunction function; private final float maxBoost; + private final CombineFunction scoreCombiner; - private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreFunction function, float maxBoost) + private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreFunction function, float maxBoost, CombineFunction scoreCombiner) throws IOException { super(w); this.subQueryBoost = w.getQuery().getBoost(); this.scorer = scorer; this.function = function; this.maxBoost = maxBoost; + this.scoreCombiner = scoreCombiner; } @Override @@ -165,8 +168,8 @@ public class FunctionScoreQuery extends Query { @Override public float score() throws IOException { - double factor = function.score(scorer.docID(), scorer.score()); - return toFloat(subQueryBoost * Math.min(maxBoost, factor)); + return scoreCombiner.combine(subQueryBoost, scorer.score(), + function.score(scorer.docID(), scorer.score()), maxBoost); } @Override @@ -182,7 +185,7 @@ public class FunctionScoreQuery extends Query { public String toString(String field) { StringBuilder sb = new StringBuilder(); - sb.append("custom score (").append(subQuery.toString(field)).append(",function=").append(function).append(')'); + sb.append("function score (").append(subQuery.toString(field)).append(",function=").append(function).append(')'); sb.append(ToStringUtils.boost(getBoost())); return sb.toString(); } @@ -198,15 +201,4 @@ public class FunctionScoreQuery extends Query { public int hashCode() { return subQuery.hashCode() + 31 * function.hashCode() ^ Float.floatToIntBits(getBoost()); } - - public static float toFloat(double input) { - assert deviation(input) <= 0.001 : "input " + input + " out of float scope for function score deviation: " + deviation(input); - return (float) input; - } - - private static double deviation(double input) { // only with assert! - float floatVersion = (float)input; - return Double.compare(floatVersion, input) == 0 || input == 0.0d ? 0 : 1.d-(floatVersion) / input; - } - } diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/RandomScoreFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/RandomScoreFunction.java index 332f9d67a11..c56ce8649f8 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/RandomScoreFunction.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/RandomScoreFunction.java @@ -25,12 +25,13 @@ import org.apache.lucene.search.Explanation; /** * */ -public class RandomScoreFunction implements ScoreFunction { +public class RandomScoreFunction extends ScoreFunction { private final PRNG prng; private int docBase; public RandomScoreFunction(long seed) { + super(CombineFunction.MULT); this.prng = new PRNG(seed); } @@ -44,11 +45,6 @@ public class RandomScoreFunction implements ScoreFunction { return prng.random(docBase + docId); } - @Override - public double factor(int docId) { - return prng.seed; - } - @Override public Explanation explainScore(int docId, Explanation subQueryExpl) { Explanation exp = new Explanation(); @@ -57,14 +53,6 @@ public class RandomScoreFunction implements ScoreFunction { return exp; } - @Override - public Explanation explainFactor(int docId) { - Explanation exp = new Explanation(); - exp.setDescription("seed: " + prng.originalSeed + ")"); - return exp; - } - - /** * Algorithm largely based on {@link java.util.Random} except this one is not * thread safe and it incorporates the doc id on next(); diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java index b420213a14d..25c0149213b 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java @@ -25,15 +25,22 @@ import org.apache.lucene.search.Explanation; /** * */ -public interface ScoreFunction { +public abstract class ScoreFunction { - void setNextReader(AtomicReaderContext context); + private final CombineFunction scoreCombiner; + + public abstract void setNextReader(AtomicReaderContext context); - double score(int docId, float subQueryScore); + public abstract double score(int docId, float subQueryScore); - double factor(int docId); + public abstract Explanation explainScore(int docId, Explanation subQueryExpl); - Explanation explainScore(int docId, Explanation subQueryExpl); + public CombineFunction getDefaultScoreCombiner() { + return scoreCombiner; + } + + protected ScoreFunction(CombineFunction scoreCombiner) { + this.scoreCombiner = scoreCombiner; + } - Explanation explainFactor(int docId); } diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java index 7fd53205820..579fdfc86c7 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java @@ -26,15 +26,17 @@ import org.elasticsearch.script.SearchScript; import java.util.Map; -public class ScriptScoreFunction implements ScoreFunction { +public class ScriptScoreFunction extends ScoreFunction { private final String sScript; private final Map params; private final SearchScript script; + public ScriptScoreFunction(String sScript, Map params, SearchScript script) { + super(CombineFunction.PLAIN); this.sScript = sScript; this.params = params; this.script = script; @@ -52,13 +54,6 @@ public class ScriptScoreFunction implements ScoreFunction { return script.runAsDouble(); } - @Override - public double factor(int docId) { - // just the factor, so don't provide _score - script.setNextDocId(docId); - return script.runAsFloat(); - } - @Override public Explanation explainScore(int docId, Explanation subQueryExpl) { Explanation exp; @@ -68,19 +63,15 @@ public class ScriptScoreFunction implements ScoreFunction { exp = ((ExplainableSearchScript) script).explain(subQueryExpl); } else { double score = score(docId, subQueryExpl.getValue()); - exp = new Explanation((float)score, "script score function: composed of:"); + exp = new Explanation(CombineFunction.toFloat(score), "script score function: composed of:"); exp.addDetail(subQueryExpl); } return exp; } - @Override - public Explanation explainFactor(int docId) { - return new Explanation((float)factor(docId), "script_factor"); - } - @Override public String toString() { return "script[" + sScript + "], params [" + params + "]"; } + } \ No newline at end of file diff --git a/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java b/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java index 686c6402573..207e89dc2ff 100644 --- a/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java +++ b/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java @@ -26,6 +26,7 @@ import org.elasticsearch.ElasticSearchIllegalArgumentException; import org.elasticsearch.ElasticSearchParseException; import org.elasticsearch.common.geo.GeoDistance; import org.elasticsearch.common.geo.GeoPoint; +import org.elasticsearch.common.lucene.search.function.CombineFunction; import org.elasticsearch.common.lucene.search.function.ScoreFunction; import org.elasticsearch.common.unit.DistanceUnit; import org.elasticsearch.common.unit.TimeValue; @@ -329,12 +330,13 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { * This is the base class for scoring a single field. * * */ - public static abstract class AbstractDistanceScoreFunction implements ScoreFunction { + public static abstract class AbstractDistanceScoreFunction extends ScoreFunction { private final double scale; private final DecayFunction func; public AbstractDistanceScoreFunction(double userSuppiedScale, double userSuppliedScaleWeight, DecayFunction func) { + super(CombineFunction.MULT); if (userSuppiedScale <= 0.0) { throw new ElasticSearchIllegalArgumentException(FunctionScoreQueryParser.NAME + " : scale must be > 0.0."); } @@ -348,11 +350,6 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { @Override public double score(int docId, float subQueryScore) { - return (subQueryScore * factor(docId)); - } - - @Override - public double factor(int docId) { double value = distance(docId); return func.evaluate(value, scale); } @@ -372,19 +369,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { @Override public Explanation explainScore(int docId, Explanation subQueryExpl) { ComplexExplanation ce = new ComplexExplanation(); - ce.setValue((float) score(docId, subQueryExpl.getValue())); + ce.setValue(CombineFunction.toFloat(score(docId, subQueryExpl.getValue()))); ce.setMatch(true); - ce.setDescription("subQueryScore * Function for field " + getFieldName() + ":"); - ce.addDetail(func.explainFunction(getDistanceString(docId), distance(docId), scale)); - return ce; - } - - @Override - public Explanation explainFactor(int docId) { - ComplexExplanation ce = new ComplexExplanation(); - ce.setValue((float) factor(docId)); - ce.setMatch(true); - ce.setDescription("subQueryScore * Function for field " + getFieldName() + ":"); + ce.setDescription("Function for field " + getFieldName() + ":"); ce.addDetail(func.explainFunction(getDistanceString(docId), distance(docId), scale)); return ce; } diff --git a/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java b/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java index 731b5b2cec8..4cc6b2a8414 100644 --- a/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java +++ b/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java @@ -84,15 +84,15 @@ public class CustomScoreSearchTests extends AbstractSharedClusterTest { assertNotNull(explanation); assertThat(explanation.isMatch(), equalTo(true)); assertThat(explanation.getValue(), equalTo(3f)); - assertThat(explanation.getDescription(), equalTo("function score, score mode [first]")); + assertThat(explanation.getDescription(), equalTo("function score, product of:")); - assertThat(explanation.getDetails().length, equalTo(2)); + assertThat(explanation.getDetails().length, equalTo(3)); assertThat(explanation.getDetails()[0].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[0].getValue(), equalTo(1f)); assertThat(explanation.getDetails()[0].getDetails().length, equalTo(2)); assertThat(explanation.getDetails()[1].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[1].getValue(), equalTo(3f)); - assertThat(explanation.getDetails()[1].getDetails().length, equalTo(3)); + assertThat(explanation.getDetails()[1].getDetails().length, equalTo(2)); // Same query but with boost searchResponse = client().prepareSearch("test") @@ -114,17 +114,17 @@ public class CustomScoreSearchTests extends AbstractSharedClusterTest { assertNotNull(explanation); assertThat(explanation.isMatch(), equalTo(true)); assertThat(explanation.getValue(), equalTo(6f)); - assertThat(explanation.getDescription(), equalTo("function score, score mode [first]")); + assertThat(explanation.getDescription(), equalTo("function score, product of:")); - assertThat(explanation.getDetails().length, equalTo(2)); + assertThat(explanation.getDetails().length, equalTo(3)); assertThat(explanation.getDetails()[0].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[0].getValue(), equalTo(1f)); assertThat(explanation.getDetails()[0].getDetails().length, equalTo(2)); assertThat(explanation.getDetails()[1].isMatch(), equalTo(true)); - assertThat(explanation.getDetails()[1].getValue(), equalTo(6f)); - assertThat(explanation.getDetails()[1].getDetails().length, equalTo(3)); - assertThat(explanation.getDetails()[1].getDetails()[2].getDescription(), equalTo("queryBoost")); - assertThat(explanation.getDetails()[1].getDetails()[2].getValue(), equalTo(2f)); + assertThat(explanation.getDetails()[1].getValue(), equalTo(3f)); + assertThat(explanation.getDetails()[1].getDetails().length, equalTo(2)); + assertThat(explanation.getDetails()[2].getDescription(), equalTo("queryBoost")); + assertThat(explanation.getDetails()[2].getValue(), equalTo(2f)); } @@ -157,15 +157,14 @@ public class CustomScoreSearchTests extends AbstractSharedClusterTest { assertNotNull(explanation); assertThat(explanation.isMatch(), equalTo(true)); assertThat(explanation.getValue(), equalTo(3f)); - assertThat(explanation.getDescription(), equalTo("function score, score mode [first]")); - - assertThat(explanation.getDetails().length, equalTo(2)); + assertThat(explanation.getDescription(), equalTo("function score, product of:")); + assertThat(explanation.getDetails().length, equalTo(3)); assertThat(explanation.getDetails()[0].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[0].getValue(), equalTo(1f)); assertThat(explanation.getDetails()[0].getDetails().length, equalTo(2)); assertThat(explanation.getDetails()[1].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[1].getValue(), equalTo(3f)); - assertThat(explanation.getDetails()[1].getDetails().length, equalTo(3)); + assertThat(explanation.getDetails()[1].getDetails().length, equalTo(2)); // Same query but with boost searchResponse = client().prepareSearch("test") @@ -183,17 +182,17 @@ public class CustomScoreSearchTests extends AbstractSharedClusterTest { assertNotNull(explanation); assertThat(explanation.isMatch(), equalTo(true)); assertThat(explanation.getValue(), equalTo(6f)); - assertThat(explanation.getDescription(), equalTo("function score, score mode [first]")); + assertThat(explanation.getDescription(), equalTo("function score, product of:")); - assertThat(explanation.getDetails().length, equalTo(2)); + assertThat(explanation.getDetails().length, equalTo(3)); assertThat(explanation.getDetails()[0].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[0].getValue(), equalTo(1f)); assertThat(explanation.getDetails()[0].getDetails().length, equalTo(2)); assertThat(explanation.getDetails()[1].isMatch(), equalTo(true)); - assertThat(explanation.getDetails()[1].getValue(), equalTo(6f)); - assertThat(explanation.getDetails()[1].getDetails().length, equalTo(3)); - assertThat(explanation.getDetails()[1].getDetails()[2].getDescription(), equalTo("queryBoost")); - assertThat(explanation.getDetails()[1].getDetails()[2].getValue(), equalTo(2f)); + assertThat(explanation.getDetails()[1].getValue(), equalTo(3f)); + assertThat(explanation.getDetails()[1].getDetails().length, equalTo(2)); + assertThat(explanation.getDetails()[2].getDescription(), equalTo("queryBoost")); + assertThat(explanation.getDetails()[2].getValue(), equalTo(2f)); } @Test