Fix numerical issue in function score query

We should sum the weights as a double so that we do not lose precision. Also,
the tests should simulate exactly what function score does and then test
for equality of scores.
This commit is contained in:
Britta Weber 2015-10-13 11:54:37 +02:00
parent 7557eae9e0
commit 17ce5d5242
2 changed files with 24 additions and 15 deletions

View File

@ -341,14 +341,14 @@ public class FiltersFunctionScoreQuery extends Query {
break;
default: // Avg / Total
double totalFactor = 0.0f;
float weightSum = 0;
double weightSum = 0;
for (int i = 0; i < filterFunctions.length; i++) {
if (docSets[i].get(docId)) {
totalFactor += functions[i].score(docId, subQueryScore);
if (filterFunctions[i].function instanceof WeightFactorFunction) {
weightSum+= ((WeightFactorFunction)filterFunctions[i].function).getWeight();
weightSum += ((WeightFactorFunction) filterFunctions[i].function).getWeight();
} else {
weightSum++;
weightSum += 1.0;
}
}
}

View File

@ -48,6 +48,7 @@ import java.util.Collection;
import java.util.concurrent.ExecutionException;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.core.Is.is;
import static org.hamcrest.core.IsEqual.equalTo;
public class FunctionScoreTests extends ESTestCase {
@ -396,17 +397,25 @@ public class FunctionScoreTests extends ESTestCase {
}
private static float[] randomFloats(int size) {
float[] weights = new float[size];
for (int i = 0; i < weights.length; i++) {
weights[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
float[] values = new float[size];
for (int i = 0; i < values.length; i++) {
values[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
}
return weights;
return values;
}
/**
 * Builds an array of {@code size} random doubles for use as test input.
 *
 * Each element is {@code randomDouble()} multiplied by a randomly chosen sign
 * ({@code 1.0d} or {@code -1.0d}) and by {@code randomInt(100)}, then shifted by {@code 1.e-5d}.
 *
 * @param size length of the returned array
 * @return a newly allocated array holding one generated value per slot
 */
private static double[] randomDoubles(int size) {
double[] values = new double[size];
for (int i = 0; i < values.length; i++) {
// NOTE(review): the 1.e-5d offset presumably keeps values away from exactly zero, since the
// assertions in this test class divide by the expected score — confirm.
values[i] = randomDouble() * (randomBoolean() ? 1.0d : -1.0d) * randomInt(100) + 1.e-5d;
}
return values;
}
private static class ScoreFunctionStub extends ScoreFunction {
private float score;
private double score;
ScoreFunctionStub(float score) {
ScoreFunctionStub(double score) {
super(CombineFunction.REPLACE);
this.score = score;
}
@ -441,7 +450,7 @@ public class FunctionScoreTests extends ESTestCase {
public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException {
int numFunctions = randomIntBetween(1, 3);
float[] weights = randomFloats(numFunctions);
float[] scores = randomFloats(numFunctions);
double[] scores = randomDoubles(numFunctions);
ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions];
for (int i = 0; i < numFunctions; i++) {
scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]);
@ -463,7 +472,7 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) {
score *= weights[i] * scores[i];
}
assertThat(scoreWithWeight / score, closeTo(1, 1.e-5d));
assertThat(scoreWithWeight / (float) score, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.SUM
@ -477,7 +486,7 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) {
sum += weights[i] * scores[i];
}
assertThat(scoreWithWeight / sum, closeTo(1, 1.e-5d));
assertThat(scoreWithWeight / (float) sum, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.AVG
@ -493,7 +502,7 @@ public class FunctionScoreTests extends ESTestCase {
norm += weights[i];
sum += weights[i] * scores[i];
}
assertThat(scoreWithWeight * norm / sum, closeTo(1, 1.e-5d));
assertThat(scoreWithWeight / (float) (sum / norm), is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MIN
@ -507,7 +516,7 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) {
min = Math.min(min, weights[i] * scores[i]);
}
assertThat(scoreWithWeight / min, closeTo(1, 1.e-5d));
assertThat(scoreWithWeight / (float) min, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MAX
@ -521,7 +530,7 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) {
max = Math.max(max, weights[i] * scores[i]);
}
assertThat(scoreWithWeight / max, closeTo(1, 1.e-5d));
assertThat(scoreWithWeight / (float) max, is(1f));
}
@Test