fix numerical issue in function score query
we should sum the weights as double to not lose precision. also, the tests should simulate exactly what function score does and then test for equality of scores.
This commit is contained in:
parent
7557eae9e0
commit
17ce5d5242
|
@ -341,14 +341,14 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
break;
|
break;
|
||||||
default: // Avg / Total
|
default: // Avg / Total
|
||||||
double totalFactor = 0.0f;
|
double totalFactor = 0.0f;
|
||||||
float weightSum = 0;
|
double weightSum = 0;
|
||||||
for (int i = 0; i < filterFunctions.length; i++) {
|
for (int i = 0; i < filterFunctions.length; i++) {
|
||||||
if (docSets[i].get(docId)) {
|
if (docSets[i].get(docId)) {
|
||||||
totalFactor += functions[i].score(docId, subQueryScore);
|
totalFactor += functions[i].score(docId, subQueryScore);
|
||||||
if (filterFunctions[i].function instanceof WeightFactorFunction) {
|
if (filterFunctions[i].function instanceof WeightFactorFunction) {
|
||||||
weightSum+= ((WeightFactorFunction)filterFunctions[i].function).getWeight();
|
weightSum += ((WeightFactorFunction) filterFunctions[i].function).getWeight();
|
||||||
} else {
|
} else {
|
||||||
weightSum++;
|
weightSum += 1.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,6 +48,7 @@ import java.util.Collection;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.closeTo;
|
import static org.hamcrest.Matchers.closeTo;
|
||||||
|
import static org.hamcrest.core.Is.is;
|
||||||
import static org.hamcrest.core.IsEqual.equalTo;
|
import static org.hamcrest.core.IsEqual.equalTo;
|
||||||
|
|
||||||
public class FunctionScoreTests extends ESTestCase {
|
public class FunctionScoreTests extends ESTestCase {
|
||||||
|
@ -396,17 +397,25 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static float[] randomFloats(int size) {
|
private static float[] randomFloats(int size) {
|
||||||
float[] weights = new float[size];
|
float[] values = new float[size];
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for (int i = 0; i < values.length; i++) {
|
||||||
weights[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
|
values[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
|
||||||
}
|
}
|
||||||
return weights;
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static double[] randomDoubles(int size) {
|
||||||
|
double[] values = new double[size];
|
||||||
|
for (int i = 0; i < values.length; i++) {
|
||||||
|
values[i] = randomDouble() * (randomBoolean() ? 1.0d : -1.0d) * randomInt(100) + 1.e-5d;
|
||||||
|
}
|
||||||
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class ScoreFunctionStub extends ScoreFunction {
|
private static class ScoreFunctionStub extends ScoreFunction {
|
||||||
private float score;
|
private double score;
|
||||||
|
|
||||||
ScoreFunctionStub(float score) {
|
ScoreFunctionStub(double score) {
|
||||||
super(CombineFunction.REPLACE);
|
super(CombineFunction.REPLACE);
|
||||||
this.score = score;
|
this.score = score;
|
||||||
}
|
}
|
||||||
|
@ -441,7 +450,7 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException {
|
public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException {
|
||||||
int numFunctions = randomIntBetween(1, 3);
|
int numFunctions = randomIntBetween(1, 3);
|
||||||
float[] weights = randomFloats(numFunctions);
|
float[] weights = randomFloats(numFunctions);
|
||||||
float[] scores = randomFloats(numFunctions);
|
double[] scores = randomDoubles(numFunctions);
|
||||||
ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions];
|
ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions];
|
||||||
for (int i = 0; i < numFunctions; i++) {
|
for (int i = 0; i < numFunctions; i++) {
|
||||||
scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]);
|
scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]);
|
||||||
|
@ -463,7 +472,7 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for (int i = 0; i < weights.length; i++) {
|
||||||
score *= weights[i] * scores[i];
|
score *= weights[i] * scores[i];
|
||||||
}
|
}
|
||||||
assertThat(scoreWithWeight / score, closeTo(1, 1.e-5d));
|
assertThat(scoreWithWeight / (float) score, is(1f));
|
||||||
|
|
||||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||||
FiltersFunctionScoreQuery.ScoreMode.SUM
|
FiltersFunctionScoreQuery.ScoreMode.SUM
|
||||||
|
@ -477,7 +486,7 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for (int i = 0; i < weights.length; i++) {
|
||||||
sum += weights[i] * scores[i];
|
sum += weights[i] * scores[i];
|
||||||
}
|
}
|
||||||
assertThat(scoreWithWeight / sum, closeTo(1, 1.e-5d));
|
assertThat(scoreWithWeight / (float) sum, is(1f));
|
||||||
|
|
||||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||||
FiltersFunctionScoreQuery.ScoreMode.AVG
|
FiltersFunctionScoreQuery.ScoreMode.AVG
|
||||||
|
@ -493,7 +502,7 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
norm += weights[i];
|
norm += weights[i];
|
||||||
sum += weights[i] * scores[i];
|
sum += weights[i] * scores[i];
|
||||||
}
|
}
|
||||||
assertThat(scoreWithWeight * norm / sum, closeTo(1, 1.e-5d));
|
assertThat(scoreWithWeight / (float) (sum / norm), is(1f));
|
||||||
|
|
||||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||||
FiltersFunctionScoreQuery.ScoreMode.MIN
|
FiltersFunctionScoreQuery.ScoreMode.MIN
|
||||||
|
@ -507,7 +516,7 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for (int i = 0; i < weights.length; i++) {
|
||||||
min = Math.min(min, weights[i] * scores[i]);
|
min = Math.min(min, weights[i] * scores[i]);
|
||||||
}
|
}
|
||||||
assertThat(scoreWithWeight / min, closeTo(1, 1.e-5d));
|
assertThat(scoreWithWeight / (float) min, is(1f));
|
||||||
|
|
||||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||||
FiltersFunctionScoreQuery.ScoreMode.MAX
|
FiltersFunctionScoreQuery.ScoreMode.MAX
|
||||||
|
@ -521,7 +530,7 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for (int i = 0; i < weights.length; i++) {
|
||||||
max = Math.max(max, weights[i] * scores[i]);
|
max = Math.max(max, weights[i] * scores[i]);
|
||||||
}
|
}
|
||||||
assertThat(scoreWithWeight / max, closeTo(1, 1.e-5d));
|
assertThat(scoreWithWeight / (float) max, is(1f));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
Loading…
Reference in New Issue