diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java index c62605ec1f5..c06c50c96d2 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java @@ -25,11 +25,12 @@ import org.apache.lucene.search.Explanation; /** * */ -public class BoostScoreFunction implements ScoreFunction { +public class BoostScoreFunction extends ScoreFunction { private final float boost; public BoostScoreFunction(float boost) { + super(CombineFunction.MULT); this.boost = boost; } @@ -41,30 +42,19 @@ public class BoostScoreFunction implements ScoreFunction { public void setNextReader(AtomicReaderContext context) { // nothing to do here... } - + @Override public double score(int docId, float subQueryScore) { - return subQueryScore * boost; - } - - @Override - public double factor(int docId) { return boost; } @Override public Explanation explainScore(int docId, Explanation subQueryExpl) { - Explanation exp = new Explanation(boost * subQueryExpl.getValue(), "static boost function: product of:"); - exp.addDetail(subQueryExpl); + Explanation exp = new Explanation(boost, "static boost factor"); exp.addDetail(new Explanation(boost, "boostFactor")); return exp; } - @Override - public Explanation explainFactor(int docId) { - return new Explanation(boost, "boostFactor"); - } - @Override public boolean equals(Object o) { if (this == o) diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/CombineFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/CombineFunction.java new file mode 100644 index 00000000000..ec93746c879 --- /dev/null +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/CombineFunction.java @@ -0,0 +1,90 @@ +/* + * Licensed to ElasticSearch and Shay Banon under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. ElasticSearch licenses this + * file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.common.lucene.search.function; + +import org.apache.lucene.search.ComplexExplanation; +import org.apache.lucene.search.Explanation; + +public enum CombineFunction { + MULT { + @Override + public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) { + return toFloat(queryBoost * queryScore * Math.min(funcScore, maxBoost)); + } + + @Override + public String getName() { + return "mult"; + } + + @Override + public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) { + float score = queryBoost * Math.min(funcExpl.getValue(), maxBoost) * queryExpl.getValue(); + ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:"); + res.addDetail(queryExpl); + ComplexExplanation minExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of"); + minExpl.addDetail(funcExpl); + minExpl.addDetail(new Explanation(maxBoost, "maxBoost")); + res.addDetail(minExpl); + res.addDetail(new Explanation(queryBoost, "queryBoost")); + return res; + } + }, + PLAIN { + @Override + public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) { + return toFloat(queryBoost * Math.min(funcScore, maxBoost)); + } + + @Override + public String getName() { + return "plain"; + } + + @Override + public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) { + float score = queryBoost * Math.min(funcExpl.getValue(), maxBoost); + ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:"); + ComplexExplanation minExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of"); + minExpl.addDetail(funcExpl); + minExpl.addDetail(new Explanation(maxBoost, "maxBoost")); + res.addDetail(minExpl); + res.addDetail(new Explanation(queryBoost, "queryBoost")); + return res; + } + + }; + + public abstract float combine(double queryBoost, double queryScore, double funcScore, double maxBoost); + + public abstract String getName(); + + public static float toFloat(double input) { + assert deviation(input) <= 0.001 : "input " + input + " out of float scope for function score deviation: " + deviation(input); + return (float) input; + } + + private static double deviation(double input) { // only with assert! + float floatVersion = (float)input; + return Double.compare(floatVersion, input) == 0 || input == 0.0d ? 0 : 1.d-(floatVersion) / input; + } + + public abstract ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost); +} diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java b/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java index 708c3aa0b8b..b4ba7fb9c4b 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java @@ -28,14 +28,11 @@ import org.apache.lucene.util.ToStringUtils; import org.elasticsearch.common.lucene.docset.DocIdSets; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Locale; -import java.util.Set; +import java.util.*; /** - * A query that allows for a pluggable boost function / filter. If it matches the filter, it will - * be boosted by the formula. + * A query that allows for a pluggable boost function / filter. If it matches + * the filter, it will be boosted by the formula. */ public class FiltersFunctionScoreQuery extends Query { @@ -50,13 +47,17 @@ public class FiltersFunctionScoreQuery extends Query { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; FilterFunction that = (FilterFunction) o; - if (filter != null ? !filter.equals(that.filter) : that.filter != null) return false; - if (function != null ? !function.equals(that.function) : that.function != null) return false; + if (filter != null ? !filter.equals(that.filter) : that.filter != null) + return false; + if (function != null ? !function.equals(that.function) : that.function != null) + return false; return true; } @@ -69,20 +70,29 @@ public class FiltersFunctionScoreQuery extends Query { } } - public static enum ScoreMode {First, Avg, Max, Total, Min, Multiply} + public static enum ScoreMode { + First, Avg, Max, Total, Min, Multiply + } Query subQuery; final FilterFunction[] filterFunctions; final ScoreMode scoreMode; final float maxBoost; + protected CombineFunction combineFunction; + public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions, float maxBoost) { this.subQuery = subQuery; this.scoreMode = scoreMode; this.filterFunctions = filterFunctions; this.maxBoost = maxBoost; + combineFunction = CombineFunction.MULT; } + public FiltersFunctionScoreQuery setCombineFunction(CombineFunction combineFunction){ + this.combineFunction = combineFunction; + return this; + } public Query getSubQuery() { return subQuery; } @@ -94,7 +104,8 @@ public class FiltersFunctionScoreQuery extends Query { @Override public Query rewrite(IndexReader reader) throws IOException { Query newQ = subQuery.rewrite(reader); - if (newQ == subQuery) return this; + if (newQ == subQuery) + return this; FiltersFunctionScoreQuery bq = (FiltersFunctionScoreQuery) this.clone(); bq.subQuery = newQ; return bq; @@ -148,107 +159,88 @@ public class FiltersFunctionScoreQuery extends Query { filterFunction.function.setNextReader(context); docSets[i] = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, acceptDocs)); } - return new CustomBoostFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, docSets); + return new CustomBoostFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, docSets, combineFunction); } @Override public Explanation explain(AtomicReaderContext context, int doc) throws IOException { + Explanation subQueryExpl = subQueryWeight.explain(context, doc); if (!subQueryExpl.isMatch()) { return subQueryExpl; } - - if (scoreMode == ScoreMode.First) { - for (FilterFunction filterFunction : filterFunctions) { - Bits docSet = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, context.reader().getLiveDocs())); - if (docSet.get(doc)) { - filterFunction.function.setNextReader(context); - Explanation functionExplanation = filterFunction.function.explainFactor(doc); - double factor = functionExplanation.getValue(); - if (factor > maxBoost) { - factor = maxBoost; - } - float sc = FunctionScoreQuery.toFloat(getBoost() * factor); - Explanation filterExplanation = new ComplexExplanation(true, sc, "function score, product of:"); - filterExplanation.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); - filterExplanation.addDetail(functionExplanation); - filterExplanation.addDetail(new Explanation(getBoost(), "queryBoost")); - - // top level score = subquery.score * filter.score (this already has the query boost) - float topLevelScore = subQueryExpl.getValue() * sc; - Explanation topLevel = new ComplexExplanation(true, topLevelScore, "function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]"); - topLevel.addDetail(subQueryExpl); - topLevel.addDetail(filterExplanation); - return topLevel; - } - } - } else { - int count = 0; - float total = 0; - float multiply = 1; - double max = Double.NEGATIVE_INFINITY; - double min = Double.POSITIVE_INFINITY; - ArrayList filtersExplanations = new ArrayList(); - for (FilterFunction filterFunction : filterFunctions) { - Bits docSet = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, context.reader().getLiveDocs())); - if (docSet.get(doc)) { - filterFunction.function.setNextReader(context); - Explanation functionExplanation = filterFunction.function.explainFactor(doc); - double factor = functionExplanation.getValue(); - count++; - total += factor; - multiply *= factor; - max = Math.max(factor, max); - min = Math.min(factor, min); - Explanation res = new ComplexExplanation(true, FunctionScoreQuery.toFloat(factor), "function score, product of:"); - res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); - res.addDetail(functionExplanation); - res.addDetail(new Explanation(getBoost(), "queryBoost")); - filtersExplanations.add(res); - } - } - if (count > 0) { - double factor = 0; - switch (scoreMode) { - case Avg: - factor = total / count; - break; - case Max: - factor = max; - break; - case Min: - factor = min; - break; - case Total: - factor = total; - break; - case Multiply: - factor = multiply; - break; - } - - if (factor > maxBoost) { - factor = maxBoost; - } - float sc = FunctionScoreQuery.toFloat(factor * subQueryExpl.getValue() * getBoost()); - Explanation res = new ComplexExplanation(true, sc, "function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]"); - res.addDetail(subQueryExpl); - for (Explanation explanation : filtersExplanations) { - res.addDetail(explanation); - } - return res; + // First: Gather explanations for all filters + List filterExplanations = new ArrayList(); + for (FilterFunction filterFunction : filterFunctions) { + Bits docSet = DocIdSets.toSafeBits(context.reader(), + filterFunction.filter.getDocIdSet(context, context.reader().getLiveDocs())); + if (docSet.get(doc)) { + filterFunction.function.setNextReader(context); + Explanation functionExplanation = filterFunction.function.explainScore(doc, subQueryExpl); + double factor = functionExplanation.getValue(); + float sc = CombineFunction.toFloat(factor); + ComplexExplanation filterExplanation = new ComplexExplanation(true, sc, "function score, product of:"); + filterExplanation.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); + filterExplanation.addDetail(functionExplanation); + filterExplanations.add(filterExplanation); } } + if (filterExplanations.size() == 0) { + float sc = getBoost() * subQueryExpl.getValue(); + Explanation res = new ComplexExplanation(true, sc, "function score, no filter match, product of:"); + res.addDetail(subQueryExpl); + res.addDetail(new Explanation(getBoost(), "queryBoost")); + return res; + } - float sc = getBoost() * subQueryExpl.getValue(); - Explanation res = new ComplexExplanation(true, sc, "custom score, no filter match, product of:"); - res.addDetail(subQueryExpl); - res.addDetail(new Explanation(getBoost(), "queryBoost")); - return res; + // Second: Compute the factor that would have been computed by the + // filters + double factor = 1.0; + switch (scoreMode) { + case First: + + factor = filterExplanations.get(0).getValue(); + break; + case Max: + double maxFactor = Double.NEGATIVE_INFINITY; + for (int i = 0; i < filterExplanations.size(); i++) { + factor = Math.max(filterExplanations.get(i).getValue(), maxFactor); + } + break; + case Min: + double minFactor = Double.POSITIVE_INFINITY; + for (int i = 0; i < filterExplanations.size(); i++) { + factor = Math.min(filterExplanations.get(i).getValue(), minFactor); + } + break; + case Multiply: + for (int i = 0; i < filterExplanations.size(); i++) { + factor *= filterExplanations.get(i).getValue(); + } + break; + default: // Avg / Total + double totalFactor = 0.0f; + int count = 0; + for (int i = 0; i < filterExplanations.size(); i++) { + totalFactor += filterExplanations.get(i).getValue(); + count++; + } + if (count != 0) { + factor = totalFactor; + if (scoreMode == ScoreMode.Avg) { + factor /= count; + } + } + } + ComplexExplanation factorExplanaition = new ComplexExplanation(true, CombineFunction.toFloat(factor), + "function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]"); + for (int i = 0; i < filterExplanations.size(); i++) { + factorExplanaition.addDetail(filterExplanations.get(i)); + } + return combineFunction.explain(getBoost(), subQueryExpl, factorExplanaition, maxBoost); } } - static class CustomBoostFactorScorer extends Scorer { private final float subQueryBoost; @@ -257,9 +249,10 @@ public class FiltersFunctionScoreQuery extends Query { private final ScoreMode scoreMode; private final float maxBoost; private final Bits[] docSets; + private final CombineFunction scoreCombiner; - private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreMode scoreMode, - FilterFunction[] filterFunctions, float maxBoost, Bits[] docSets) throws IOException { + private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreMode scoreMode, FilterFunction[] filterFunctions, + float maxBoost, Bits[] docSets, CombineFunction scoreCombiner) throws IOException { super(w); this.subQueryBoost = w.getQuery().getBoost(); this.scorer = scorer; @@ -267,6 +260,7 @@ public class FiltersFunctionScoreQuery extends Query { this.filterFunctions = filterFunctions; this.maxBoost = maxBoost; this.docSets = docSets; + this.scoreCombiner = scoreCombiner; } @Override @@ -288,10 +282,11 @@ public class FiltersFunctionScoreQuery extends Query { public float score() throws IOException { int docId = scorer.docID(); double factor = 1.0f; + float subQueryScore = scorer.score(); if (scoreMode == ScoreMode.First) { for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - factor = filterFunctions[i].function.factor(docId); + factor = filterFunctions[i].function.score(docId, subQueryScore); break; } } @@ -299,7 +294,7 @@ public class FiltersFunctionScoreQuery extends Query { double maxFactor = Double.NEGATIVE_INFINITY; for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - maxFactor = Math.max(filterFunctions[i].function.factor(docId), maxFactor); + maxFactor = Math.max(filterFunctions[i].function.score(docId, subQueryScore), maxFactor); } } if (maxFactor != Float.NEGATIVE_INFINITY) { @@ -309,7 +304,7 @@ public class FiltersFunctionScoreQuery extends Query { double minFactor = Double.POSITIVE_INFINITY; for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - minFactor = Math.min(filterFunctions[i].function.factor(docId), minFactor); + minFactor = Math.min(filterFunctions[i].function.score(docId, subQueryScore), minFactor); } } if (minFactor != Float.POSITIVE_INFINITY) { @@ -318,7 +313,7 @@ public class FiltersFunctionScoreQuery extends Query { } else if (scoreMode == ScoreMode.Multiply) { for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - factor *= filterFunctions[i].function.factor(docId); + factor *= filterFunctions[i].function.score(docId, subQueryScore); } } } else { // Avg / Total @@ -326,7 +321,7 @@ public class FiltersFunctionScoreQuery extends Query { int count = 0; for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { - totalFactor += filterFunctions[i].function.factor(docId); + totalFactor += filterFunctions[i].function.score(docId, subQueryScore); count++; } } @@ -337,11 +332,7 @@ public class FiltersFunctionScoreQuery extends Query { } } } - if (factor > maxBoost) { - factor = maxBoost; - } - float score = scorer.score(); - return FunctionScoreQuery.toFloat(subQueryBoost * score * factor); + return scoreCombiner.combine(subQueryBoost, subQueryScore, factor, maxBoost); } @Override @@ -355,10 +346,9 @@ public class FiltersFunctionScoreQuery extends Query { } } - public String toString(String field) { StringBuilder sb = new StringBuilder(); - sb.append("custom score (").append(subQuery.toString(field)).append(", functions: ["); + sb.append("function score (").append(subQuery.toString(field)).append(", functions: ["); for (FilterFunction filterFunction : filterFunctions) { sb.append("{filter(").append(filterFunction.filter).append("), function [").append(filterFunction.function).append("]}"); } @@ -368,7 +358,8 @@ public class FiltersFunctionScoreQuery extends Query { } public boolean equals(Object o) { - if (getClass() != o.getClass()) return false; + if (getClass() != o.getClass()) + return false; FiltersFunctionScoreQuery other = (FiltersFunctionScoreQuery) o; if (this.getBoost() != other.getBoost()) return false; @@ -382,4 +373,3 @@ public class FiltersFunctionScoreQuery extends Query { return subQuery.hashCode() + 31 * Arrays.hashCode(filterFunctions) ^ Float.floatToIntBits(getBoost()); } } - diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java b/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java index b2ef9ca9c0a..251fc80ec54 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java @@ -37,12 +37,18 @@ public class FunctionScoreQuery extends Query { Query subQuery; final ScoreFunction function; float maxBoost = Float.MAX_VALUE; - + CombineFunction combineFunction; + public FunctionScoreQuery(Query subQuery, ScoreFunction function) { this.subQuery = subQuery; this.function = function; + this.combineFunction = function.getDefaultScoreCombiner(); } + public void setCombineFunction(CombineFunction combineFunction) { + this.combineFunction = combineFunction; + } + public void setMaxBoost(float maxBoost) { this.maxBoost = maxBoost; } @@ -112,7 +118,7 @@ public class FunctionScoreQuery extends Query { return null; } function.setNextReader(context); - return new CustomBoostFactorScorer(this, subQueryScorer, function, maxBoost); + return new CustomBoostFactorScorer(this, subQueryScorer, function, maxBoost, combineFunction); } @Override @@ -121,14 +127,9 @@ public class FunctionScoreQuery extends Query { if (!subQueryExpl.isMatch()) { return subQueryExpl; } - function.setNextReader(context); Explanation functionExplanation = function.explainScore(doc, subQueryExpl); - float sc = getBoost() * functionExplanation.getValue(); - Explanation res = new ComplexExplanation(true, sc, "function score, product of:"); - res.addDetail(functionExplanation); - res.addDetail(new Explanation(getBoost(), "queryBoost")); - return res; + return combineFunction.explain(getBoost(), subQueryExpl, functionExplanation, maxBoost); } } @@ -138,14 +139,16 @@ public class FunctionScoreQuery extends Query { private final Scorer scorer; private final ScoreFunction function; private final float maxBoost; + private final CombineFunction scoreCombiner; - private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreFunction function, float maxBoost) + private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreFunction function, float maxBoost, CombineFunction scoreCombiner) throws IOException { super(w); this.subQueryBoost = w.getQuery().getBoost(); this.scorer = scorer; this.function = function; this.maxBoost = maxBoost; + this.scoreCombiner = scoreCombiner; } @Override @@ -165,8 +168,8 @@ public class FunctionScoreQuery extends Query { @Override public float score() throws IOException { - double factor = function.score(scorer.docID(), scorer.score()); - return toFloat(subQueryBoost * Math.min(maxBoost, factor)); + return scoreCombiner.combine(subQueryBoost, scorer.score(), + function.score(scorer.docID(), scorer.score()), maxBoost); } @Override @@ -182,7 +185,7 @@ public class FunctionScoreQuery extends Query { public String toString(String field) { StringBuilder sb = new StringBuilder(); - sb.append("custom score (").append(subQuery.toString(field)).append(",function=").append(function).append(')'); + sb.append("function score (").append(subQuery.toString(field)).append(",function=").append(function).append(')'); sb.append(ToStringUtils.boost(getBoost())); return sb.toString(); } @@ -198,15 +201,4 @@ public class FunctionScoreQuery extends Query { public int hashCode() { return subQuery.hashCode() + 31 * function.hashCode() ^ Float.floatToIntBits(getBoost()); } - - public static float toFloat(double input) { - assert deviation(input) <= 0.001 : "input " + input + " out of float scope for function score deviation: " + deviation(input); - return (float) input; - } - - private static double deviation(double input) { // only with assert! - float floatVersion = (float)input; - return Double.compare(floatVersion, input) == 0 || input == 0.0d ? 0 : 1.d-(floatVersion) / input; - } - } diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/RandomScoreFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/RandomScoreFunction.java index 332f9d67a11..c56ce8649f8 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/RandomScoreFunction.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/RandomScoreFunction.java @@ -25,12 +25,13 @@ import org.apache.lucene.search.Explanation; /** * */ -public class RandomScoreFunction implements ScoreFunction { +public class RandomScoreFunction extends ScoreFunction { private final PRNG prng; private int docBase; public RandomScoreFunction(long seed) { + super(CombineFunction.MULT); this.prng = new PRNG(seed); } @@ -44,11 +45,6 @@ public class RandomScoreFunction implements ScoreFunction { return prng.random(docBase + docId); } - @Override - public double factor(int docId) { - return prng.seed; - } - @Override public Explanation explainScore(int docId, Explanation subQueryExpl) { Explanation exp = new Explanation(); @@ -57,14 +53,6 @@ public class RandomScoreFunction implements ScoreFunction { return exp; } - @Override - public Explanation explainFactor(int docId) { - Explanation exp = new Explanation(); - exp.setDescription("seed: " + prng.originalSeed + ")"); - return exp; - } - - /** * Algorithm largely based on {@link java.util.Random} except this one is not * thread safe and it incorporates the doc id on next(); diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java index b420213a14d..25c0149213b 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java @@ -25,15 +25,22 @@ import org.apache.lucene.search.Explanation; /** * */ -public interface ScoreFunction { +public abstract class ScoreFunction { - void setNextReader(AtomicReaderContext context); + private final CombineFunction scoreCombiner; + + public abstract void setNextReader(AtomicReaderContext context); - double score(int docId, float subQueryScore); + public abstract double score(int docId, float subQueryScore); - double factor(int docId); + public abstract Explanation explainScore(int docId, Explanation subQueryExpl); - Explanation explainScore(int docId, Explanation subQueryExpl); + public CombineFunction getDefaultScoreCombiner() { + return scoreCombiner; + } + + protected ScoreFunction(CombineFunction scoreCombiner) { + this.scoreCombiner = scoreCombiner; + } - Explanation explainFactor(int docId); } diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java b/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java index 7fd53205820..579fdfc86c7 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java @@ -26,15 +26,17 @@ import org.elasticsearch.script.SearchScript; import java.util.Map; -public class ScriptScoreFunction implements ScoreFunction { +public class ScriptScoreFunction extends ScoreFunction { private final String sScript; private final Map params; private final SearchScript script; + public ScriptScoreFunction(String sScript, Map params, SearchScript script) { + super(CombineFunction.PLAIN); this.sScript = sScript; this.params = params; this.script = script; @@ -52,13 +54,6 @@ public class ScriptScoreFunction implements ScoreFunction { return script.runAsDouble(); } - @Override - public double factor(int docId) { - // just the factor, so don't provide _score - script.setNextDocId(docId); - return script.runAsFloat(); - } - @Override public Explanation explainScore(int docId, Explanation subQueryExpl) { Explanation exp; @@ -68,19 +63,15 @@ public class ScriptScoreFunction implements ScoreFunction { exp = ((ExplainableSearchScript) script).explain(subQueryExpl); } else { double score = score(docId, subQueryExpl.getValue()); - exp = new Explanation((float)score, "script score function: composed of:"); + exp = new Explanation(CombineFunction.toFloat(score), "script score function: composed of:"); exp.addDetail(subQueryExpl); } return exp; } - @Override - public Explanation explainFactor(int docId) { - return new Explanation((float)factor(docId), "script_factor"); - } - @Override public String toString() { return "script[" + sScript + "], params [" + params + "]"; } + } \ No newline at end of file diff --git a/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java b/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java index 686c6402573..207e89dc2ff 100644 --- a/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java +++ b/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java @@ -26,6 +26,7 @@ import org.elasticsearch.ElasticSearchIllegalArgumentException; import org.elasticsearch.ElasticSearchParseException; import org.elasticsearch.common.geo.GeoDistance; import org.elasticsearch.common.geo.GeoPoint; +import org.elasticsearch.common.lucene.search.function.CombineFunction; import org.elasticsearch.common.lucene.search.function.ScoreFunction; import org.elasticsearch.common.unit.DistanceUnit; import org.elasticsearch.common.unit.TimeValue; @@ -329,12 +330,13 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { * This is the base class for scoring a single field. * * */ - public static abstract class AbstractDistanceScoreFunction implements ScoreFunction { + public static abstract class AbstractDistanceScoreFunction extends ScoreFunction { private final double scale; private final DecayFunction func; public AbstractDistanceScoreFunction(double userSuppiedScale, double userSuppliedScaleWeight, DecayFunction func) { + super(CombineFunction.MULT); if (userSuppiedScale <= 0.0) { throw new ElasticSearchIllegalArgumentException(FunctionScoreQueryParser.NAME + " : scale must be > 0.0."); } @@ -348,11 +350,6 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { @Override public double score(int docId, float subQueryScore) { - return (subQueryScore * factor(docId)); - } - - @Override - public double factor(int docId) { double value = distance(docId); return func.evaluate(value, scale); } @@ -372,19 +369,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { @Override public Explanation explainScore(int docId, Explanation subQueryExpl) { ComplexExplanation ce = new ComplexExplanation(); - ce.setValue((float) score(docId, subQueryExpl.getValue())); + ce.setValue(CombineFunction.toFloat(score(docId, subQueryExpl.getValue()))); ce.setMatch(true); - ce.setDescription("subQueryScore * Function for field " + getFieldName() + ":"); - ce.addDetail(func.explainFunction(getDistanceString(docId), distance(docId), scale)); - return ce; - } - - @Override - public Explanation explainFactor(int docId) { - ComplexExplanation ce = new ComplexExplanation(); - ce.setValue((float) factor(docId)); - ce.setMatch(true); - ce.setDescription("subQueryScore * Function for field " + getFieldName() + ":"); + ce.setDescription("Function for field " + getFieldName() + ":"); ce.addDetail(func.explainFunction(getDistanceString(docId), distance(docId), scale)); return ce; } diff --git a/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java b/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java index 731b5b2cec8..4cc6b2a8414 100644 --- a/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java +++ b/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java @@ -84,15 +84,15 @@ public class CustomScoreSearchTests extends AbstractSharedClusterTest { assertNotNull(explanation); assertThat(explanation.isMatch(), equalTo(true)); assertThat(explanation.getValue(), equalTo(3f)); - assertThat(explanation.getDescription(), equalTo("function score, score mode [first]")); + assertThat(explanation.getDescription(), equalTo("function score, product of:")); - assertThat(explanation.getDetails().length, equalTo(2)); + assertThat(explanation.getDetails().length, equalTo(3)); assertThat(explanation.getDetails()[0].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[0].getValue(), equalTo(1f)); assertThat(explanation.getDetails()[0].getDetails().length, equalTo(2)); assertThat(explanation.getDetails()[1].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[1].getValue(), equalTo(3f)); - assertThat(explanation.getDetails()[1].getDetails().length, equalTo(3)); + assertThat(explanation.getDetails()[1].getDetails().length, equalTo(2)); // Same query but with boost searchResponse = client().prepareSearch("test") @@ -114,17 +114,17 @@ public class CustomScoreSearchTests extends AbstractSharedClusterTest { assertNotNull(explanation); assertThat(explanation.isMatch(), equalTo(true)); assertThat(explanation.getValue(), equalTo(6f)); - assertThat(explanation.getDescription(), equalTo("function score, score mode [first]")); + assertThat(explanation.getDescription(), equalTo("function score, product of:")); - assertThat(explanation.getDetails().length, equalTo(2)); + assertThat(explanation.getDetails().length, equalTo(3)); assertThat(explanation.getDetails()[0].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[0].getValue(), equalTo(1f)); assertThat(explanation.getDetails()[0].getDetails().length, equalTo(2)); assertThat(explanation.getDetails()[1].isMatch(), equalTo(true)); - assertThat(explanation.getDetails()[1].getValue(), equalTo(6f)); - assertThat(explanation.getDetails()[1].getDetails().length, equalTo(3)); - assertThat(explanation.getDetails()[1].getDetails()[2].getDescription(), equalTo("queryBoost")); - assertThat(explanation.getDetails()[1].getDetails()[2].getValue(), equalTo(2f)); + assertThat(explanation.getDetails()[1].getValue(), equalTo(3f)); + assertThat(explanation.getDetails()[1].getDetails().length, equalTo(2)); + assertThat(explanation.getDetails()[2].getDescription(), equalTo("queryBoost")); + assertThat(explanation.getDetails()[2].getValue(), equalTo(2f)); } @@ -157,15 +157,14 @@ public class CustomScoreSearchTests extends AbstractSharedClusterTest { assertNotNull(explanation); assertThat(explanation.isMatch(), equalTo(true)); assertThat(explanation.getValue(), equalTo(3f)); - assertThat(explanation.getDescription(), equalTo("function score, score mode [first]")); - - assertThat(explanation.getDetails().length, equalTo(2)); + assertThat(explanation.getDescription(), equalTo("function score, product of:")); + assertThat(explanation.getDetails().length, equalTo(3)); assertThat(explanation.getDetails()[0].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[0].getValue(), equalTo(1f)); assertThat(explanation.getDetails()[0].getDetails().length, equalTo(2)); assertThat(explanation.getDetails()[1].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[1].getValue(), equalTo(3f)); - assertThat(explanation.getDetails()[1].getDetails().length, equalTo(3)); + assertThat(explanation.getDetails()[1].getDetails().length, equalTo(2)); // Same query but with boost searchResponse = client().prepareSearch("test") @@ -183,17 +182,17 @@ public class CustomScoreSearchTests extends AbstractSharedClusterTest { assertNotNull(explanation); assertThat(explanation.isMatch(), equalTo(true)); assertThat(explanation.getValue(), equalTo(6f)); - assertThat(explanation.getDescription(), equalTo("function score, score mode [first]")); + assertThat(explanation.getDescription(), equalTo("function score, product of:")); - assertThat(explanation.getDetails().length, equalTo(2)); + assertThat(explanation.getDetails().length, equalTo(3)); assertThat(explanation.getDetails()[0].isMatch(), equalTo(true)); assertThat(explanation.getDetails()[0].getValue(), equalTo(1f)); assertThat(explanation.getDetails()[0].getDetails().length, equalTo(2)); assertThat(explanation.getDetails()[1].isMatch(), equalTo(true)); - assertThat(explanation.getDetails()[1].getValue(), equalTo(6f)); - assertThat(explanation.getDetails()[1].getDetails().length, equalTo(3)); - assertThat(explanation.getDetails()[1].getDetails()[2].getDescription(), equalTo("queryBoost")); - assertThat(explanation.getDetails()[1].getDetails()[2].getValue(), equalTo(2f)); + assertThat(explanation.getDetails()[1].getValue(), equalTo(3f)); + assertThat(explanation.getDetails()[1].getDetails().length, equalTo(2)); + assertThat(explanation.getDetails()[2].getDescription(), equalTo("queryBoost")); + assertThat(explanation.getDetails()[2].getValue(), equalTo(2f)); } @Test