Merge pull request #14085 from brwe/fix-score-accuracy
fix numerical issue in function score query
This commit is contained in:
commit
1634c8364e
|
@ -207,19 +207,11 @@ public class FiltersFunctionScoreQuery extends Query {
|
|||
}
|
||||
// First: Gather explanations for all filters
|
||||
List<Explanation> filterExplanations = new ArrayList<>();
|
||||
float weightSum = 0;
|
||||
for (int i = 0; i < filterFunctions.length; ++i) {
|
||||
FilterFunction filterFunction = filterFunctions[i];
|
||||
|
||||
if (filterFunction.function instanceof WeightFactorFunction) {
|
||||
weightSum += ((WeightFactorFunction) filterFunction.function).getWeight();
|
||||
} else {
|
||||
weightSum++;
|
||||
}
|
||||
|
||||
Bits docSet = Lucene.asSequentialAccessBits(context.reader().maxDoc(),
|
||||
filterWeights[i].scorer(context));
|
||||
if (docSet.get(doc)) {
|
||||
FilterFunction filterFunction = filterFunctions[i];
|
||||
Explanation functionExplanation = filterFunction.function.getLeafScoreFunction(context).explainScore(doc, subQueryExpl);
|
||||
double factor = functionExplanation.getValue();
|
||||
float sc = CombineFunction.toFloat(factor);
|
||||
|
@ -232,44 +224,12 @@ public class FiltersFunctionScoreQuery extends Query {
|
|||
return subQueryExpl;
|
||||
}
|
||||
|
||||
// Second: Compute the factor that would have been computed by the
|
||||
// filters
|
||||
double factor = 1.0;
|
||||
switch (scoreMode) {
|
||||
case FIRST:
|
||||
factor = filterExplanations.get(0).getValue();
|
||||
break;
|
||||
case MAX:
|
||||
factor = Double.NEGATIVE_INFINITY;
|
||||
for (Explanation filterExplanation : filterExplanations) {
|
||||
factor = Math.max(filterExplanation.getValue(), factor);
|
||||
}
|
||||
break;
|
||||
case MIN:
|
||||
factor = Double.POSITIVE_INFINITY;
|
||||
for (Explanation filterExplanation : filterExplanations) {
|
||||
factor = Math.min(filterExplanation.getValue(), factor);
|
||||
}
|
||||
break;
|
||||
case MULTIPLY:
|
||||
for (Explanation filterExplanation : filterExplanations) {
|
||||
factor *= filterExplanation.getValue();
|
||||
}
|
||||
break;
|
||||
default:
|
||||
double totalFactor = 0.0f;
|
||||
for (Explanation filterExplanation : filterExplanations) {
|
||||
totalFactor += filterExplanation.getValue();
|
||||
}
|
||||
if (weightSum != 0) {
|
||||
factor = totalFactor;
|
||||
if (scoreMode == ScoreMode.AVG) {
|
||||
factor /= weightSum;
|
||||
}
|
||||
}
|
||||
}
|
||||
FiltersFunctionFactorScorer scorer = (FiltersFunctionFactorScorer)scorer(context);
|
||||
int actualDoc = scorer.advance(doc);
|
||||
assert (actualDoc == doc);
|
||||
double score = scorer.computeScore(doc, subQueryExpl.getValue());
|
||||
Explanation factorExplanation = Explanation.match(
|
||||
CombineFunction.toFloat(factor),
|
||||
CombineFunction.toFloat(score),
|
||||
"function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]",
|
||||
filterExplanations);
|
||||
return combineFunction.explain(subQueryExpl, factorExplanation, maxBoost);
|
||||
|
@ -296,11 +256,16 @@ public class FiltersFunctionScoreQuery extends Query {
|
|||
@Override
|
||||
public float innerScore() throws IOException {
|
||||
int docId = scorer.docID();
|
||||
double factor = 1.0f;
|
||||
// Even if the weight is created with needsScores=false, it might
|
||||
// be costly to call score(), so we explicitly check if scores
|
||||
// are needed
|
||||
float subQueryScore = needsScores ? scorer.score() : 0f;
|
||||
double factor = computeScore(docId, subQueryScore);
|
||||
return scoreCombiner.combine(subQueryScore, factor, maxBoost);
|
||||
}
|
||||
|
||||
protected double computeScore(int docId, float subQueryScore) {
|
||||
double factor = 1d;
|
||||
switch(scoreMode) {
|
||||
case FIRST:
|
||||
for (int i = 0; i < filterFunctions.length; i++) {
|
||||
|
@ -341,14 +306,14 @@ public class FiltersFunctionScoreQuery extends Query {
|
|||
break;
|
||||
default: // Avg / Total
|
||||
double totalFactor = 0.0f;
|
||||
float weightSum = 0;
|
||||
double weightSum = 0;
|
||||
for (int i = 0; i < filterFunctions.length; i++) {
|
||||
if (docSets[i].get(docId)) {
|
||||
totalFactor += functions[i].score(docId, subQueryScore);
|
||||
if (filterFunctions[i].function instanceof WeightFactorFunction) {
|
||||
weightSum+= ((WeightFactorFunction)filterFunctions[i].function).getWeight();
|
||||
weightSum += ((WeightFactorFunction) filterFunctions[i].function).getWeight();
|
||||
} else {
|
||||
weightSum++;
|
||||
weightSum += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -360,7 +325,7 @@ public class FiltersFunctionScoreQuery extends Query {
|
|||
}
|
||||
break;
|
||||
}
|
||||
return scoreCombiner.combine(subQueryScore, factor, maxBoost);
|
||||
return factor;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ package org.elasticsearch.index.query.functionscore;
|
|||
import org.apache.lucene.analysis.standard.StandardAnalyzer;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.FieldType;
|
||||
import org.apache.lucene.document.TextField;
|
||||
import org.apache.lucene.index.*;
|
||||
import org.apache.lucene.search.*;
|
||||
|
@ -47,7 +46,7 @@ import java.io.IOException;
|
|||
import java.util.Collection;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.core.Is.is;
|
||||
import static org.hamcrest.core.IsEqual.equalTo;
|
||||
|
||||
public class FunctionScoreTests extends ESTestCase {
|
||||
|
@ -363,9 +362,12 @@ public class FunctionScoreTests extends ESTestCase {
|
|||
|
||||
public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException {
|
||||
FiltersFunctionScoreQuery filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions);
|
||||
return getExplanation(searcher, filtersFunctionScoreQuery).getDetails()[1];
|
||||
}
|
||||
|
||||
protected Explanation getExplanation(IndexSearcher searcher, FiltersFunctionScoreQuery filtersFunctionScoreQuery) throws IOException {
|
||||
Weight weight = searcher.createNormalizedWeight(filtersFunctionScoreQuery, true);
|
||||
Explanation explanation = weight.explain(searcher.getIndexReader().leaves().get(0), 0);
|
||||
return explanation.getDetails()[1];
|
||||
return weight.explain(searcher.getIndexReader().leaves().get(0), 0);
|
||||
}
|
||||
|
||||
public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) {
|
||||
|
@ -386,17 +388,25 @@ public class FunctionScoreTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private static float[] randomFloats(int size) {
|
||||
float[] weights = new float[size];
|
||||
for (int i = 0; i < weights.length; i++) {
|
||||
weights[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
|
||||
float[] values = new float[size];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
values[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
|
||||
}
|
||||
return weights;
|
||||
return values;
|
||||
}
|
||||
|
||||
private static double[] randomDoubles(int size) {
|
||||
double[] values = new double[size];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
values[i] = randomDouble() * (randomBoolean() ? 1.0d : -1.0d) * randomInt(100) + 1.e-5d;
|
||||
}
|
||||
return values;
|
||||
}
|
||||
|
||||
private static class ScoreFunctionStub extends ScoreFunction {
|
||||
private float score;
|
||||
private double score;
|
||||
|
||||
ScoreFunctionStub(float score) {
|
||||
ScoreFunctionStub(double score) {
|
||||
super(CombineFunction.REPLACE);
|
||||
this.score = score;
|
||||
}
|
||||
|
@ -411,7 +421,7 @@ public class FunctionScoreTests extends ESTestCase {
|
|||
|
||||
@Override
|
||||
public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException {
|
||||
throw new UnsupportedOperationException(UNSUPPORTED);
|
||||
return Explanation.match((float) score, "a random score for testing");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -431,7 +441,7 @@ public class FunctionScoreTests extends ESTestCase {
|
|||
public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException {
|
||||
int numFunctions = randomIntBetween(1, 3);
|
||||
float[] weights = randomFloats(numFunctions);
|
||||
float[] scores = randomFloats(numFunctions);
|
||||
double[] scores = randomDoubles(numFunctions);
|
||||
ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions];
|
||||
for (int i = 0; i < numFunctions; i++) {
|
||||
scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]);
|
||||
|
@ -453,7 +463,9 @@ public class FunctionScoreTests extends ESTestCase {
|
|||
for (int i = 0; i < weights.length; i++) {
|
||||
score *= weights[i] * scores[i];
|
||||
}
|
||||
assertThat(scoreWithWeight / score, closeTo(1, 1.e-5d));
|
||||
assertThat(scoreWithWeight / (float) score, is(1f));
|
||||
float explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
|
||||
assertThat(explainedScore / scoreWithWeight, is(1f));
|
||||
|
||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||
FiltersFunctionScoreQuery.ScoreMode.SUM
|
||||
|
@ -467,7 +479,9 @@ public class FunctionScoreTests extends ESTestCase {
|
|||
for (int i = 0; i < weights.length; i++) {
|
||||
sum += weights[i] * scores[i];
|
||||
}
|
||||
assertThat(scoreWithWeight / sum, closeTo(1, 1.e-5d));
|
||||
assertThat(scoreWithWeight / (float) sum, is(1f));
|
||||
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
|
||||
assertThat(explainedScore / scoreWithWeight, is(1f));
|
||||
|
||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||
FiltersFunctionScoreQuery.ScoreMode.AVG
|
||||
|
@ -483,7 +497,9 @@ public class FunctionScoreTests extends ESTestCase {
|
|||
norm += weights[i];
|
||||
sum += weights[i] * scores[i];
|
||||
}
|
||||
assertThat(scoreWithWeight * norm / sum, closeTo(1, 1.e-5d));
|
||||
assertThat(scoreWithWeight / (float) (sum / norm), is(1f));
|
||||
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
|
||||
assertThat(explainedScore / scoreWithWeight, is(1f));
|
||||
|
||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||
FiltersFunctionScoreQuery.ScoreMode.MIN
|
||||
|
@ -497,7 +513,9 @@ public class FunctionScoreTests extends ESTestCase {
|
|||
for (int i = 0; i < weights.length; i++) {
|
||||
min = Math.min(min, weights[i] * scores[i]);
|
||||
}
|
||||
assertThat(scoreWithWeight / min, closeTo(1, 1.e-5d));
|
||||
assertThat(scoreWithWeight / (float) min, is(1f));
|
||||
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
|
||||
assertThat(explainedScore / scoreWithWeight, is(1f));
|
||||
|
||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||
FiltersFunctionScoreQuery.ScoreMode.MAX
|
||||
|
@ -511,7 +529,9 @@ public class FunctionScoreTests extends ESTestCase {
|
|||
for (int i = 0; i < weights.length; i++) {
|
||||
max = Math.max(max, weights[i] * scores[i]);
|
||||
}
|
||||
assertThat(scoreWithWeight / max, closeTo(1, 1.e-5d));
|
||||
assertThat(scoreWithWeight / (float) max, is(1f));
|
||||
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
|
||||
assertThat(explainedScore / scoreWithWeight, is(1f));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue