add assertion for cast double->float

ScoreFunction scoring might result in under or overflow, for example if a user
decides to use the timestamp as a boost in the script scorer. Therefore, check
if cast causes a huge precision loss. Note that this does not always detect
casting issues. For example in
ScriptFunction.score()
the function
SearchScript.runAsDouble()
is called. AbstractFloatSearchScript implements it as follows:
@Override
    public double runAsDouble() {
        return runAsFloat();
    }
In this case the cast happens before the assertion and therfore precision
lossor over/underflows cannot be detected by the assertion.
This commit is contained in:
Britta Weber 2013-08-06 16:21:59 +02:00
parent e707308f1f
commit a938bd57a9
6 changed files with 35 additions and 22 deletions

View File

@ -33,7 +33,6 @@ public class BoostScoreFunction implements ScoreFunction {
this.boost = boost; this.boost = boost;
} }
public float getBoost() { public float getBoost() {
return boost; return boost;
} }
@ -44,7 +43,7 @@ public class BoostScoreFunction implements ScoreFunction {
} }
@Override @Override
public float score(int docId, float subQueryScore) { public double score(int docId, float subQueryScore) {
return subQueryScore * boost; return subQueryScore * boost;
} }
@ -68,12 +67,15 @@ public class BoostScoreFunction implements ScoreFunction {
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) return true; if (this == o)
if (o == null || getClass() != o.getClass()) return false; return true;
if (o == null || getClass() != o.getClass())
return false;
BoostScoreFunction that = (BoostScoreFunction) o; BoostScoreFunction that = (BoostScoreFunction) o;
if (Float.compare(that.boost, boost) != 0) return false; if (Float.compare(that.boost, boost) != 0)
return false;
return true; return true;
} }

View File

@ -164,11 +164,11 @@ public class FiltersFunctionScoreQuery extends Query {
if (docSet.get(doc)) { if (docSet.get(doc)) {
filterFunction.function.setNextReader(context); filterFunction.function.setNextReader(context);
Explanation functionExplanation = filterFunction.function.explainFactor(doc); Explanation functionExplanation = filterFunction.function.explainFactor(doc);
float factor = functionExplanation.getValue(); double factor = functionExplanation.getValue();
if (factor > maxBoost) { if (factor > maxBoost) {
factor = maxBoost; factor = maxBoost;
} }
float sc = getBoost() * factor; float sc = toFloat(getBoost() * factor);
Explanation filterExplanation = new ComplexExplanation(true, sc, "custom score, product of:"); Explanation filterExplanation = new ComplexExplanation(true, sc, "custom score, product of:");
filterExplanation.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); filterExplanation.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString()));
filterExplanation.addDetail(functionExplanation); filterExplanation.addDetail(functionExplanation);
@ -186,21 +186,21 @@ public class FiltersFunctionScoreQuery extends Query {
int count = 0; int count = 0;
float total = 0; float total = 0;
float multiply = 1; float multiply = 1;
float max = Float.NEGATIVE_INFINITY; double max = Double.NEGATIVE_INFINITY;
float min = Float.POSITIVE_INFINITY; double min = Double.POSITIVE_INFINITY;
ArrayList<Explanation> filtersExplanations = new ArrayList<Explanation>(); ArrayList<Explanation> filtersExplanations = new ArrayList<Explanation>();
for (FilterFunction filterFunction : filterFunctions) { for (FilterFunction filterFunction : filterFunctions) {
Bits docSet = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, context.reader().getLiveDocs())); Bits docSet = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, context.reader().getLiveDocs()));
if (docSet.get(doc)) { if (docSet.get(doc)) {
filterFunction.function.setNextReader(context); filterFunction.function.setNextReader(context);
Explanation functionExplanation = filterFunction.function.explainFactor(doc); Explanation functionExplanation = filterFunction.function.explainFactor(doc);
float factor = functionExplanation.getValue(); double factor = functionExplanation.getValue();
count++; count++;
total += factor; total += factor;
multiply *= factor; multiply *= factor;
max = Math.max(factor, max); max = Math.max(factor, max);
min = Math.min(factor, min); min = Math.min(factor, min);
Explanation res = new ComplexExplanation(true, factor, "custom score, product of:"); Explanation res = new ComplexExplanation(true, toFloat(factor), "custom score, product of:");
res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString())); res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString()));
res.addDetail(functionExplanation); res.addDetail(functionExplanation);
res.addDetail(new Explanation(getBoost(), "queryBoost")); res.addDetail(new Explanation(getBoost(), "queryBoost"));
@ -208,7 +208,7 @@ public class FiltersFunctionScoreQuery extends Query {
} }
} }
if (count > 0) { if (count > 0) {
float factor = 0; double factor = 0;
switch (scoreMode) { switch (scoreMode) {
case Avg: case Avg:
factor = total / count; factor = total / count;
@ -230,7 +230,7 @@ public class FiltersFunctionScoreQuery extends Query {
if (factor > maxBoost) { if (factor > maxBoost) {
factor = maxBoost; factor = maxBoost;
} }
float sc = factor * subQueryExpl.getValue() * getBoost(); float sc = toFloat(factor * subQueryExpl.getValue() * getBoost());
Explanation res = new ComplexExplanation(true, sc, "custom score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]"); Explanation res = new ComplexExplanation(true, sc, "custom score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]");
res.addDetail(subQueryExpl); res.addDetail(subQueryExpl);
for (Explanation explanation : filtersExplanations) { for (Explanation explanation : filtersExplanations) {
@ -341,7 +341,7 @@ public class FiltersFunctionScoreQuery extends Query {
factor = maxBoost; factor = maxBoost;
} }
float score = scorer.score(); float score = scorer.score();
return (float)(subQueryBoost * score * factor); return toFloat(subQueryBoost * score * factor);
} }
@Override @Override
@ -381,5 +381,10 @@ public class FiltersFunctionScoreQuery extends Query {
public int hashCode() { public int hashCode() {
return subQuery.hashCode() + 31 * Arrays.hashCode(filterFunctions) ^ Float.floatToIntBits(getBoost()); return subQuery.hashCode() + 31 * Arrays.hashCode(filterFunctions) ^ Float.floatToIntBits(getBoost());
} }
public static float toFloat(double input) {
assert Double.compare(((float) input), input) == 0 || (Math.abs(((float) input) - input) <= 0.001);
return (float) input;
}
} }

View File

@ -62,7 +62,7 @@ public class FunctionScoreQuery extends Query {
@Override @Override
public Query rewrite(IndexReader reader) throws IOException { public Query rewrite(IndexReader reader) throws IOException {
Query newQ = subQuery.rewrite(reader); Query newQ = subQuery.rewrite(reader);
if (newQ == subQuery){ if (newQ == subQuery) {
return this; return this;
} }
FunctionScoreQuery bq = (FunctionScoreQuery) this.clone(); FunctionScoreQuery bq = (FunctionScoreQuery) this.clone();
@ -165,8 +165,8 @@ public class FunctionScoreQuery extends Query {
@Override @Override
public float score() throws IOException { public float score() throws IOException {
float factor = (float)function.score(scorer.docID(), scorer.score()); double factor = function.score(scorer.docID(), scorer.score());
return subQueryBoost * Math.min(maxBoost, factor); return toFloat(subQueryBoost * Math.min(maxBoost, factor));
} }
@Override @Override
@ -198,4 +198,10 @@ public class FunctionScoreQuery extends Query {
public int hashCode() { public int hashCode() {
return subQuery.hashCode() + 31 * function.hashCode() ^ Float.floatToIntBits(getBoost()); return subQuery.hashCode() + 31 * function.hashCode() ^ Float.floatToIntBits(getBoost());
} }
public static float toFloat(double input) {
assert Double.compare(((float) input), input) == 0 || (Math.abs(((float) input) - input) <= 0.001);
return (float) input;
}
} }

View File

@ -29,7 +29,7 @@ public interface ScoreFunction {
void setNextReader(AtomicReaderContext context); void setNextReader(AtomicReaderContext context);
float score(int docId, float subQueryScore); double score(int docId, float subQueryScore);
double factor(int docId); double factor(int docId);

View File

@ -46,10 +46,10 @@ public class ScriptScoreFunction implements ScoreFunction {
} }
@Override @Override
public float score(int docId, float subQueryScore) { public double score(int docId, float subQueryScore) {
script.setNextDocId(docId); script.setNextDocId(docId);
script.setNextScore(subQueryScore); script.setNextScore(subQueryScore);
return script.runAsFloat(); return script.runAsDouble();
} }
@Override @Override

View File

@ -359,8 +359,8 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
} }
@Override @Override
public float score(int docId, float subQueryScore) { public double score(int docId, float subQueryScore) {
return (float) (subQueryScore * factor(docId)); return (subQueryScore * factor(docId));
} }
@Override @Override