Merge pull request #14085 from brwe/fix-score-accuracy

fix numerical issue in function score query
This commit is contained in:
Britta Weber 2015-10-13 15:21:04 +02:00
commit 1634c8364e
2 changed files with 53 additions and 68 deletions

View File

@@ -207,19 +207,11 @@ public class FiltersFunctionScoreQuery extends Query {
} }
// First: Gather explanations for all filters // First: Gather explanations for all filters
List<Explanation> filterExplanations = new ArrayList<>(); List<Explanation> filterExplanations = new ArrayList<>();
float weightSum = 0;
for (int i = 0; i < filterFunctions.length; ++i) { for (int i = 0; i < filterFunctions.length; ++i) {
FilterFunction filterFunction = filterFunctions[i];
if (filterFunction.function instanceof WeightFactorFunction) {
weightSum += ((WeightFactorFunction) filterFunction.function).getWeight();
} else {
weightSum++;
}
Bits docSet = Lucene.asSequentialAccessBits(context.reader().maxDoc(), Bits docSet = Lucene.asSequentialAccessBits(context.reader().maxDoc(),
filterWeights[i].scorer(context)); filterWeights[i].scorer(context));
if (docSet.get(doc)) { if (docSet.get(doc)) {
FilterFunction filterFunction = filterFunctions[i];
Explanation functionExplanation = filterFunction.function.getLeafScoreFunction(context).explainScore(doc, subQueryExpl); Explanation functionExplanation = filterFunction.function.getLeafScoreFunction(context).explainScore(doc, subQueryExpl);
double factor = functionExplanation.getValue(); double factor = functionExplanation.getValue();
float sc = CombineFunction.toFloat(factor); float sc = CombineFunction.toFloat(factor);
@@ -232,44 +224,12 @@ public class FiltersFunctionScoreQuery extends Query {
return subQueryExpl; return subQueryExpl;
} }
// Second: Compute the factor that would have been computed by the FiltersFunctionFactorScorer scorer = (FiltersFunctionFactorScorer)scorer(context);
// filters int actualDoc = scorer.advance(doc);
double factor = 1.0; assert (actualDoc == doc);
switch (scoreMode) { double score = scorer.computeScore(doc, subQueryExpl.getValue());
case FIRST:
factor = filterExplanations.get(0).getValue();
break;
case MAX:
factor = Double.NEGATIVE_INFINITY;
for (Explanation filterExplanation : filterExplanations) {
factor = Math.max(filterExplanation.getValue(), factor);
}
break;
case MIN:
factor = Double.POSITIVE_INFINITY;
for (Explanation filterExplanation : filterExplanations) {
factor = Math.min(filterExplanation.getValue(), factor);
}
break;
case MULTIPLY:
for (Explanation filterExplanation : filterExplanations) {
factor *= filterExplanation.getValue();
}
break;
default:
double totalFactor = 0.0f;
for (Explanation filterExplanation : filterExplanations) {
totalFactor += filterExplanation.getValue();
}
if (weightSum != 0) {
factor = totalFactor;
if (scoreMode == ScoreMode.AVG) {
factor /= weightSum;
}
}
}
Explanation factorExplanation = Explanation.match( Explanation factorExplanation = Explanation.match(
CombineFunction.toFloat(factor), CombineFunction.toFloat(score),
"function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]", "function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]",
filterExplanations); filterExplanations);
return combineFunction.explain(subQueryExpl, factorExplanation, maxBoost); return combineFunction.explain(subQueryExpl, factorExplanation, maxBoost);
@@ -296,11 +256,16 @@ public class FiltersFunctionScoreQuery extends Query {
@Override @Override
public float innerScore() throws IOException { public float innerScore() throws IOException {
int docId = scorer.docID(); int docId = scorer.docID();
double factor = 1.0f;
// Even if the weight is created with needsScores=false, it might // Even if the weight is created with needsScores=false, it might
// be costly to call score(), so we explicitly check if scores // be costly to call score(), so we explicitly check if scores
// are needed // are needed
float subQueryScore = needsScores ? scorer.score() : 0f; float subQueryScore = needsScores ? scorer.score() : 0f;
double factor = computeScore(docId, subQueryScore);
return scoreCombiner.combine(subQueryScore, factor, maxBoost);
}
protected double computeScore(int docId, float subQueryScore) {
double factor = 1d;
switch(scoreMode) { switch(scoreMode) {
case FIRST: case FIRST:
for (int i = 0; i < filterFunctions.length; i++) { for (int i = 0; i < filterFunctions.length; i++) {
@@ -341,14 +306,14 @@ public class FiltersFunctionScoreQuery extends Query {
break; break;
default: // Avg / Total default: // Avg / Total
double totalFactor = 0.0f; double totalFactor = 0.0f;
float weightSum = 0; double weightSum = 0;
for (int i = 0; i < filterFunctions.length; i++) { for (int i = 0; i < filterFunctions.length; i++) {
if (docSets[i].get(docId)) { if (docSets[i].get(docId)) {
totalFactor += functions[i].score(docId, subQueryScore); totalFactor += functions[i].score(docId, subQueryScore);
if (filterFunctions[i].function instanceof WeightFactorFunction) { if (filterFunctions[i].function instanceof WeightFactorFunction) {
weightSum += ((WeightFactorFunction) filterFunctions[i].function).getWeight(); weightSum += ((WeightFactorFunction) filterFunctions[i].function).getWeight();
} else { } else {
weightSum++; weightSum += 1.0;
} }
} }
} }
@@ -360,7 +325,7 @@ public class FiltersFunctionScoreQuery extends Query {
} }
break; break;
} }
return scoreCombiner.combine(subQueryScore, factor, maxBoost); return factor;
} }
} }

View File

@@ -22,7 +22,6 @@ package org.elasticsearch.index.query.functionscore;
import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField; import org.apache.lucene.document.TextField;
import org.apache.lucene.index.*; import org.apache.lucene.index.*;
import org.apache.lucene.search.*; import org.apache.lucene.search.*;
@@ -47,7 +46,7 @@ import java.io.IOException;
import java.util.Collection; import java.util.Collection;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.core.Is.is;
import static org.hamcrest.core.IsEqual.equalTo; import static org.hamcrest.core.IsEqual.equalTo;
public class FunctionScoreTests extends ESTestCase { public class FunctionScoreTests extends ESTestCase {
@@ -363,9 +362,12 @@ public class FunctionScoreTests extends ESTestCase {
public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException { public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException {
FiltersFunctionScoreQuery filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions); FiltersFunctionScoreQuery filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions);
return getExplanation(searcher, filtersFunctionScoreQuery).getDetails()[1];
}
protected Explanation getExplanation(IndexSearcher searcher, FiltersFunctionScoreQuery filtersFunctionScoreQuery) throws IOException {
Weight weight = searcher.createNormalizedWeight(filtersFunctionScoreQuery, true); Weight weight = searcher.createNormalizedWeight(filtersFunctionScoreQuery, true);
Explanation explanation = weight.explain(searcher.getIndexReader().leaves().get(0), 0); return weight.explain(searcher.getIndexReader().leaves().get(0), 0);
return explanation.getDetails()[1];
} }
public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) { public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) {
@@ -386,17 +388,25 @@ public class FunctionScoreTests extends ESTestCase {
} }
private static float[] randomFloats(int size) { private static float[] randomFloats(int size) {
float[] weights = new float[size]; float[] values = new float[size];
for (int i = 0; i < weights.length; i++) { for (int i = 0; i < values.length; i++) {
weights[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f; values[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
} }
return weights; return values;
}
private static double[] randomDoubles(int size) {
double[] values = new double[size];
for (int i = 0; i < values.length; i++) {
values[i] = randomDouble() * (randomBoolean() ? 1.0d : -1.0d) * randomInt(100) + 1.e-5d;
}
return values;
} }
private static class ScoreFunctionStub extends ScoreFunction { private static class ScoreFunctionStub extends ScoreFunction {
private float score; private double score;
ScoreFunctionStub(float score) { ScoreFunctionStub(double score) {
super(CombineFunction.REPLACE); super(CombineFunction.REPLACE);
this.score = score; this.score = score;
} }
@@ -411,7 +421,7 @@ public class FunctionScoreTests extends ESTestCase {
@Override @Override
public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException { public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException {
throw new UnsupportedOperationException(UNSUPPORTED); return Explanation.match((float) score, "a random score for testing");
} }
}; };
} }
@@ -431,7 +441,7 @@ public class FunctionScoreTests extends ESTestCase {
public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException { public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException {
int numFunctions = randomIntBetween(1, 3); int numFunctions = randomIntBetween(1, 3);
float[] weights = randomFloats(numFunctions); float[] weights = randomFloats(numFunctions);
float[] scores = randomFloats(numFunctions); double[] scores = randomDoubles(numFunctions);
ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions]; ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions];
for (int i = 0; i < numFunctions; i++) { for (int i = 0; i < numFunctions; i++) {
scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]); scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]);
@@ -453,7 +463,9 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) { for (int i = 0; i < weights.length; i++) {
score *= weights[i] * scores[i]; score *= weights[i] * scores[i];
} }
assertThat(scoreWithWeight / score, closeTo(1, 1.e-5d)); assertThat(scoreWithWeight / (float) score, is(1f));
float explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.SUM FiltersFunctionScoreQuery.ScoreMode.SUM
@@ -467,7 +479,9 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) { for (int i = 0; i < weights.length; i++) {
sum += weights[i] * scores[i]; sum += weights[i] * scores[i];
} }
assertThat(scoreWithWeight / sum, closeTo(1, 1.e-5d)); assertThat(scoreWithWeight / (float) sum, is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.AVG FiltersFunctionScoreQuery.ScoreMode.AVG
@@ -483,7 +497,9 @@ public class FunctionScoreTests extends ESTestCase {
norm += weights[i]; norm += weights[i];
sum += weights[i] * scores[i]; sum += weights[i] * scores[i];
} }
assertThat(scoreWithWeight * norm / sum, closeTo(1, 1.e-5d)); assertThat(scoreWithWeight / (float) (sum / norm), is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MIN FiltersFunctionScoreQuery.ScoreMode.MIN
@@ -497,7 +513,9 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) { for (int i = 0; i < weights.length; i++) {
min = Math.min(min, weights[i] * scores[i]); min = Math.min(min, weights[i] * scores[i]);
} }
assertThat(scoreWithWeight / min, closeTo(1, 1.e-5d)); assertThat(scoreWithWeight / (float) min, is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MAX FiltersFunctionScoreQuery.ScoreMode.MAX
@@ -511,7 +529,9 @@ public class FunctionScoreTests extends ESTestCase {
for (int i = 0; i < weights.length; i++) { for (int i = 0; i < weights.length; i++) {
max = Math.max(max, weights[i] * scores[i]); max = Math.max(max, weights[i] * scores[i]);
} }
assertThat(scoreWithWeight / max, closeTo(1, 1.e-5d)); assertThat(scoreWithWeight / (float) max, is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
} }
@Test @Test