Merge pull request #14085 from brwe/fix-score-accuracy

fix numerical issue in function score query
This commit is contained in:
Britta Weber 2015-10-13 15:21:04 +02:00
commit 1634c8364e
2 changed files with 53 additions and 68 deletions

View File

@ -207,19 +207,11 @@ public class FiltersFunctionScoreQuery extends Query {
}
// First: Gather explanations for all filters
List<Explanation> filterExplanations = new ArrayList<>();
float weightSum = 0;
for (int i = 0; i < filterFunctions.length; ++i) {
FilterFunction filterFunction = filterFunctions[i];
if (filterFunction.function instanceof WeightFactorFunction) {
weightSum += ((WeightFactorFunction) filterFunction.function).getWeight();
} else {
weightSum++;
}
Bits docSet = Lucene.asSequentialAccessBits(context.reader().maxDoc(),
filterWeights[i].scorer(context));
if (docSet.get(doc)) {
FilterFunction filterFunction = filterFunctions[i];
Explanation functionExplanation = filterFunction.function.getLeafScoreFunction(context).explainScore(doc, subQueryExpl);
double factor = functionExplanation.getValue();
float sc = CombineFunction.toFloat(factor);
@ -232,44 +224,12 @@ public class FiltersFunctionScoreQuery extends Query {
return subQueryExpl;
}
// Second: Compute the factor that would have been computed by the
// filters
double factor = 1.0;
switch (scoreMode) {
case FIRST:
factor = filterExplanations.get(0).getValue();
break;
case MAX:
factor = Double.NEGATIVE_INFINITY;
for (Explanation filterExplanation : filterExplanations) {
factor = Math.max(filterExplanation.getValue(), factor);
}
break;
case MIN:
factor = Double.POSITIVE_INFINITY;
for (Explanation filterExplanation : filterExplanations) {
factor = Math.min(filterExplanation.getValue(), factor);
}
break;
case MULTIPLY:
for (Explanation filterExplanation : filterExplanations) {
factor *= filterExplanation.getValue();
}
break;
default:
double totalFactor = 0.0f;
for (Explanation filterExplanation : filterExplanations) {
totalFactor += filterExplanation.getValue();
}
if (weightSum != 0) {
factor = totalFactor;
if (scoreMode == ScoreMode.AVG) {
factor /= weightSum;
}
}
}
FiltersFunctionFactorScorer scorer = (FiltersFunctionFactorScorer)scorer(context);
int actualDoc = scorer.advance(doc);
assert (actualDoc == doc);
double score = scorer.computeScore(doc, subQueryExpl.getValue());
Explanation factorExplanation = Explanation.match(
CombineFunction.toFloat(factor),
CombineFunction.toFloat(score),
"function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]",
filterExplanations);
return combineFunction.explain(subQueryExpl, factorExplanation, maxBoost);
@ -296,11 +256,16 @@ public class FiltersFunctionScoreQuery extends Query {
@Override
public float innerScore() throws IOException {
int docId = scorer.docID();
double factor = 1.0f;
// Even if the weight is created with needsScores=false, it might
// be costly to call score(), so we explicitly check if scores
// are needed
float subQueryScore = needsScores ? scorer.score() : 0f;
double factor = computeScore(docId, subQueryScore);
return scoreCombiner.combine(subQueryScore, factor, maxBoost);
}
protected double computeScore(int docId, float subQueryScore) {
double factor = 1d;
switch(scoreMode) {
case FIRST:
for (int i = 0; i < filterFunctions.length; i++) {
@ -341,14 +306,14 @@ public class FiltersFunctionScoreQuery extends Query {
break;
default: // Avg / Total
double totalFactor = 0.0f;
float weightSum = 0;
double weightSum = 0;
for (int i = 0; i < filterFunctions.length; i++) {
if (docSets[i].get(docId)) {
totalFactor += functions[i].score(docId, subQueryScore);
if (filterFunctions[i].function instanceof WeightFactorFunction) {
weightSum+= ((WeightFactorFunction)filterFunctions[i].function).getWeight();
weightSum += ((WeightFactorFunction) filterFunctions[i].function).getWeight();
} else {
weightSum++;
weightSum += 1.0;
}
}
}
@ -360,7 +325,7 @@ public class FiltersFunctionScoreQuery extends Query {
}
break;
}
return scoreCombiner.combine(subQueryScore, factor, maxBoost);
return factor;
}
}

View File

@ -22,7 +22,6 @@ package org.elasticsearch.index.query.functionscore;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
@ -47,7 +46,7 @@ import java.io.IOException;
import java.util.Collection;
import java.util.concurrent.ExecutionException;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.core.Is.is;
import static org.hamcrest.core.IsEqual.equalTo;
public class FunctionScoreTests extends ESTestCase {
@ -363,9 +362,12 @@ public class FunctionScoreTests extends ESTestCase {
/**
 * Builds a filters function score query over the given functions, using AVG as both the
 * score mode and the combine function, and returns the second detail entry of the
 * explanation produced for it.
 *
 * NOTE(review): index 1 is presumably the function-score part of the combined explanation
 * (index 0 being the sub-query part) - confirm against CombineFunction.explain.
 *
 * @param searcher       searcher the explanation is computed against
 * @param scoreFunctions functions handed to the filters function score query
 * @return the detail at index 1 of the query's explanation
 * @throws IOException propagated from computing the explanation
 */
public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException {
    FiltersFunctionScoreQuery avgQuery = getFiltersFunctionScoreQuery(
            FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions);
    Explanation fullExplanation = getExplanation(searcher, avgQuery);
    Explanation[] details = fullExplanation.getDetails();
    return details[1];
}
protected Explanation getExplanation(IndexSearcher searcher, FiltersFunctionScoreQuery filtersFunctionScoreQuery) throws IOException {
Weight weight = searcher.createNormalizedWeight(filtersFunctionScoreQuery, true);
Explanation explanation = weight.explain(searcher.getIndexReader().leaves().get(0), 0);
return explanation.getDetails()[1];
return weight.explain(searcher.getIndexReader().leaves().get(0), 0);
}
public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) {
@ -386,17 +388,25 @@ public class FunctionScoreTests extends ESTestCase {
}
private static float[] randomFloats(int size) {
float[] weights = new float[size];
for (int i = 0; i < weights.length; i++) {
weights[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
float[] values = new float[size];
for (int i = 0; i < values.length; i++) {
values[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
}
return weights;
return values;
}
/**
 * Creates an array of {@code size} random doubles for use as test scores.
 *
 * Each entry is {@code randomDouble()} multiplied by a randomly chosen sign and by
 * {@code randomInt(100)}, then shifted by 1e-5. The three random helpers are invoked in
 * that order for every element.
 *
 * NOTE(review): the 1e-5 offset presumably keeps entries from being exactly zero, since
 * callers divide by products and sums of these values - confirm.
 *
 * @param size number of values to generate
 * @return a freshly allocated array of length {@code size}
 */
private static double[] randomDoubles(int size) {
    final double[] result = new double[size];
    int idx = 0;
    while (idx < result.length) {
        // Draw in the same order as a single left-to-right expression would.
        final double magnitude = randomDouble();
        final double sign = randomBoolean() ? 1.0d : -1.0d;
        final int scale = randomInt(100);
        result[idx] = magnitude * sign * scale + 1.e-5d;
        idx++;
    }
    return result;
}
private static class ScoreFunctionStub extends ScoreFunction {
private float score;
private double score;
ScoreFunctionStub(float score) {
ScoreFunctionStub(double score) {
super(CombineFunction.REPLACE);
this.score = score;
}
@ -411,7 +421,7 @@ public class FunctionScoreTests extends ESTestCase {
@Override
public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException {
throw new UnsupportedOperationException(UNSUPPORTED);
return Explanation.match((float) score, "a random score for testing");
}
};
}
@ -431,7 +441,7 @@ public class FunctionScoreTests extends ESTestCase {
public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException {
int numFunctions = randomIntBetween(1, 3);
float[] weights = randomFloats(numFunctions);
float[] scores = randomFloats(numFunctions);
double[] scores = randomDoubles(numFunctions);
ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions];
for (int i = 0; i < numFunctions; i++) {
scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]);
@ -453,7 +463,9 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) {
score *= weights[i] * scores[i];
}
assertThat(scoreWithWeight / score, closeTo(1, 1.e-5d));
assertThat(scoreWithWeight / (float) score, is(1f));
float explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.SUM
@ -467,7 +479,9 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) {
sum += weights[i] * scores[i];
}
assertThat(scoreWithWeight / sum, closeTo(1, 1.e-5d));
assertThat(scoreWithWeight / (float) sum, is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.AVG
@ -483,7 +497,9 @@ public class FunctionScoreTests extends ESTestCase {
norm += weights[i];
sum += weights[i] * scores[i];
}
assertThat(scoreWithWeight * norm / sum, closeTo(1, 1.e-5d));
assertThat(scoreWithWeight / (float) (sum / norm), is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MIN
@ -497,7 +513,9 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) {
min = Math.min(min, weights[i] * scores[i]);
}
assertThat(scoreWithWeight / min, closeTo(1, 1.e-5d));
assertThat(scoreWithWeight / (float) min, is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MAX
@ -511,7 +529,9 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) {
max = Math.max(max, weights[i] * scores[i]);
}
assertThat(scoreWithWeight / max, closeTo(1, 1.e-5d));
assertThat(scoreWithWeight / (float) max, is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
}
@Test