Merge pull request #14085 from brwe/fix-score-accuracy
fix numerical issue in function score query
This commit is contained in:
commit
1634c8364e
|
@ -207,19 +207,11 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
}
|
}
|
||||||
// First: Gather explanations for all filters
|
// First: Gather explanations for all filters
|
||||||
List<Explanation> filterExplanations = new ArrayList<>();
|
List<Explanation> filterExplanations = new ArrayList<>();
|
||||||
float weightSum = 0;
|
|
||||||
for (int i = 0; i < filterFunctions.length; ++i) {
|
for (int i = 0; i < filterFunctions.length; ++i) {
|
||||||
FilterFunction filterFunction = filterFunctions[i];
|
|
||||||
|
|
||||||
if (filterFunction.function instanceof WeightFactorFunction) {
|
|
||||||
weightSum += ((WeightFactorFunction) filterFunction.function).getWeight();
|
|
||||||
} else {
|
|
||||||
weightSum++;
|
|
||||||
}
|
|
||||||
|
|
||||||
Bits docSet = Lucene.asSequentialAccessBits(context.reader().maxDoc(),
|
Bits docSet = Lucene.asSequentialAccessBits(context.reader().maxDoc(),
|
||||||
filterWeights[i].scorer(context));
|
filterWeights[i].scorer(context));
|
||||||
if (docSet.get(doc)) {
|
if (docSet.get(doc)) {
|
||||||
|
FilterFunction filterFunction = filterFunctions[i];
|
||||||
Explanation functionExplanation = filterFunction.function.getLeafScoreFunction(context).explainScore(doc, subQueryExpl);
|
Explanation functionExplanation = filterFunction.function.getLeafScoreFunction(context).explainScore(doc, subQueryExpl);
|
||||||
double factor = functionExplanation.getValue();
|
double factor = functionExplanation.getValue();
|
||||||
float sc = CombineFunction.toFloat(factor);
|
float sc = CombineFunction.toFloat(factor);
|
||||||
|
@ -232,44 +224,12 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
return subQueryExpl;
|
return subQueryExpl;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second: Compute the factor that would have been computed by the
|
FiltersFunctionFactorScorer scorer = (FiltersFunctionFactorScorer)scorer(context);
|
||||||
// filters
|
int actualDoc = scorer.advance(doc);
|
||||||
double factor = 1.0;
|
assert (actualDoc == doc);
|
||||||
switch (scoreMode) {
|
double score = scorer.computeScore(doc, subQueryExpl.getValue());
|
||||||
case FIRST:
|
|
||||||
factor = filterExplanations.get(0).getValue();
|
|
||||||
break;
|
|
||||||
case MAX:
|
|
||||||
factor = Double.NEGATIVE_INFINITY;
|
|
||||||
for (Explanation filterExplanation : filterExplanations) {
|
|
||||||
factor = Math.max(filterExplanation.getValue(), factor);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case MIN:
|
|
||||||
factor = Double.POSITIVE_INFINITY;
|
|
||||||
for (Explanation filterExplanation : filterExplanations) {
|
|
||||||
factor = Math.min(filterExplanation.getValue(), factor);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case MULTIPLY:
|
|
||||||
for (Explanation filterExplanation : filterExplanations) {
|
|
||||||
factor *= filterExplanation.getValue();
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
double totalFactor = 0.0f;
|
|
||||||
for (Explanation filterExplanation : filterExplanations) {
|
|
||||||
totalFactor += filterExplanation.getValue();
|
|
||||||
}
|
|
||||||
if (weightSum != 0) {
|
|
||||||
factor = totalFactor;
|
|
||||||
if (scoreMode == ScoreMode.AVG) {
|
|
||||||
factor /= weightSum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Explanation factorExplanation = Explanation.match(
|
Explanation factorExplanation = Explanation.match(
|
||||||
CombineFunction.toFloat(factor),
|
CombineFunction.toFloat(score),
|
||||||
"function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]",
|
"function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]",
|
||||||
filterExplanations);
|
filterExplanations);
|
||||||
return combineFunction.explain(subQueryExpl, factorExplanation, maxBoost);
|
return combineFunction.explain(subQueryExpl, factorExplanation, maxBoost);
|
||||||
|
@ -296,11 +256,16 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
@Override
|
@Override
|
||||||
public float innerScore() throws IOException {
|
public float innerScore() throws IOException {
|
||||||
int docId = scorer.docID();
|
int docId = scorer.docID();
|
||||||
double factor = 1.0f;
|
|
||||||
// Even if the weight is created with needsScores=false, it might
|
// Even if the weight is created with needsScores=false, it might
|
||||||
// be costly to call score(), so we explicitly check if scores
|
// be costly to call score(), so we explicitly check if scores
|
||||||
// are needed
|
// are needed
|
||||||
float subQueryScore = needsScores ? scorer.score() : 0f;
|
float subQueryScore = needsScores ? scorer.score() : 0f;
|
||||||
|
double factor = computeScore(docId, subQueryScore);
|
||||||
|
return scoreCombiner.combine(subQueryScore, factor, maxBoost);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected double computeScore(int docId, float subQueryScore) {
|
||||||
|
double factor = 1d;
|
||||||
switch(scoreMode) {
|
switch(scoreMode) {
|
||||||
case FIRST:
|
case FIRST:
|
||||||
for (int i = 0; i < filterFunctions.length; i++) {
|
for (int i = 0; i < filterFunctions.length; i++) {
|
||||||
|
@ -341,14 +306,14 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
break;
|
break;
|
||||||
default: // Avg / Total
|
default: // Avg / Total
|
||||||
double totalFactor = 0.0f;
|
double totalFactor = 0.0f;
|
||||||
float weightSum = 0;
|
double weightSum = 0;
|
||||||
for (int i = 0; i < filterFunctions.length; i++) {
|
for (int i = 0; i < filterFunctions.length; i++) {
|
||||||
if (docSets[i].get(docId)) {
|
if (docSets[i].get(docId)) {
|
||||||
totalFactor += functions[i].score(docId, subQueryScore);
|
totalFactor += functions[i].score(docId, subQueryScore);
|
||||||
if (filterFunctions[i].function instanceof WeightFactorFunction) {
|
if (filterFunctions[i].function instanceof WeightFactorFunction) {
|
||||||
weightSum+= ((WeightFactorFunction)filterFunctions[i].function).getWeight();
|
weightSum += ((WeightFactorFunction) filterFunctions[i].function).getWeight();
|
||||||
} else {
|
} else {
|
||||||
weightSum++;
|
weightSum += 1.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -360,7 +325,7 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
return scoreCombiner.combine(subQueryScore, factor, maxBoost);
|
return factor;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@ package org.elasticsearch.index.query.functionscore;
|
||||||
import org.apache.lucene.analysis.standard.StandardAnalyzer;
|
import org.apache.lucene.analysis.standard.StandardAnalyzer;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.FieldType;
|
|
||||||
import org.apache.lucene.document.TextField;
|
import org.apache.lucene.document.TextField;
|
||||||
import org.apache.lucene.index.*;
|
import org.apache.lucene.index.*;
|
||||||
import org.apache.lucene.search.*;
|
import org.apache.lucene.search.*;
|
||||||
|
@ -47,7 +46,7 @@ import java.io.IOException;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.closeTo;
|
import static org.hamcrest.core.Is.is;
|
||||||
import static org.hamcrest.core.IsEqual.equalTo;
|
import static org.hamcrest.core.IsEqual.equalTo;
|
||||||
|
|
||||||
public class FunctionScoreTests extends ESTestCase {
|
public class FunctionScoreTests extends ESTestCase {
|
||||||
|
@ -363,9 +362,12 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
|
|
||||||
public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException {
|
public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException {
|
||||||
FiltersFunctionScoreQuery filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions);
|
FiltersFunctionScoreQuery filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions);
|
||||||
|
return getExplanation(searcher, filtersFunctionScoreQuery).getDetails()[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Explanation getExplanation(IndexSearcher searcher, FiltersFunctionScoreQuery filtersFunctionScoreQuery) throws IOException {
|
||||||
Weight weight = searcher.createNormalizedWeight(filtersFunctionScoreQuery, true);
|
Weight weight = searcher.createNormalizedWeight(filtersFunctionScoreQuery, true);
|
||||||
Explanation explanation = weight.explain(searcher.getIndexReader().leaves().get(0), 0);
|
return weight.explain(searcher.getIndexReader().leaves().get(0), 0);
|
||||||
return explanation.getDetails()[1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) {
|
public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) {
|
||||||
|
@ -386,17 +388,25 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static float[] randomFloats(int size) {
|
private static float[] randomFloats(int size) {
|
||||||
float[] weights = new float[size];
|
float[] values = new float[size];
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for (int i = 0; i < values.length; i++) {
|
||||||
weights[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
|
values[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
|
||||||
}
|
}
|
||||||
return weights;
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static double[] randomDoubles(int size) {
|
||||||
|
double[] values = new double[size];
|
||||||
|
for (int i = 0; i < values.length; i++) {
|
||||||
|
values[i] = randomDouble() * (randomBoolean() ? 1.0d : -1.0d) * randomInt(100) + 1.e-5d;
|
||||||
|
}
|
||||||
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class ScoreFunctionStub extends ScoreFunction {
|
private static class ScoreFunctionStub extends ScoreFunction {
|
||||||
private float score;
|
private double score;
|
||||||
|
|
||||||
ScoreFunctionStub(float score) {
|
ScoreFunctionStub(double score) {
|
||||||
super(CombineFunction.REPLACE);
|
super(CombineFunction.REPLACE);
|
||||||
this.score = score;
|
this.score = score;
|
||||||
}
|
}
|
||||||
|
@ -411,7 +421,7 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException {
|
public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException {
|
||||||
throw new UnsupportedOperationException(UNSUPPORTED);
|
return Explanation.match((float) score, "a random score for testing");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -431,7 +441,7 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException {
|
public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException {
|
||||||
int numFunctions = randomIntBetween(1, 3);
|
int numFunctions = randomIntBetween(1, 3);
|
||||||
float[] weights = randomFloats(numFunctions);
|
float[] weights = randomFloats(numFunctions);
|
||||||
float[] scores = randomFloats(numFunctions);
|
double[] scores = randomDoubles(numFunctions);
|
||||||
ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions];
|
ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions];
|
||||||
for (int i = 0; i < numFunctions; i++) {
|
for (int i = 0; i < numFunctions; i++) {
|
||||||
scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]);
|
scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]);
|
||||||
|
@ -453,7 +463,9 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for (int i = 0; i < weights.length; i++) {
|
||||||
score *= weights[i] * scores[i];
|
score *= weights[i] * scores[i];
|
||||||
}
|
}
|
||||||
assertThat(scoreWithWeight / score, closeTo(1, 1.e-5d));
|
assertThat(scoreWithWeight / (float) score, is(1f));
|
||||||
|
float explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
|
||||||
|
assertThat(explainedScore / scoreWithWeight, is(1f));
|
||||||
|
|
||||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||||
FiltersFunctionScoreQuery.ScoreMode.SUM
|
FiltersFunctionScoreQuery.ScoreMode.SUM
|
||||||
|
@ -467,7 +479,9 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for (int i = 0; i < weights.length; i++) {
|
||||||
sum += weights[i] * scores[i];
|
sum += weights[i] * scores[i];
|
||||||
}
|
}
|
||||||
assertThat(scoreWithWeight / sum, closeTo(1, 1.e-5d));
|
assertThat(scoreWithWeight / (float) sum, is(1f));
|
||||||
|
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
|
||||||
|
assertThat(explainedScore / scoreWithWeight, is(1f));
|
||||||
|
|
||||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||||
FiltersFunctionScoreQuery.ScoreMode.AVG
|
FiltersFunctionScoreQuery.ScoreMode.AVG
|
||||||
|
@ -483,7 +497,9 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
norm += weights[i];
|
norm += weights[i];
|
||||||
sum += weights[i] * scores[i];
|
sum += weights[i] * scores[i];
|
||||||
}
|
}
|
||||||
assertThat(scoreWithWeight * norm / sum, closeTo(1, 1.e-5d));
|
assertThat(scoreWithWeight / (float) (sum / norm), is(1f));
|
||||||
|
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
|
||||||
|
assertThat(explainedScore / scoreWithWeight, is(1f));
|
||||||
|
|
||||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||||
FiltersFunctionScoreQuery.ScoreMode.MIN
|
FiltersFunctionScoreQuery.ScoreMode.MIN
|
||||||
|
@ -497,7 +513,9 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for (int i = 0; i < weights.length; i++) {
|
||||||
min = Math.min(min, weights[i] * scores[i]);
|
min = Math.min(min, weights[i] * scores[i]);
|
||||||
}
|
}
|
||||||
assertThat(scoreWithWeight / min, closeTo(1, 1.e-5d));
|
assertThat(scoreWithWeight / (float) min, is(1f));
|
||||||
|
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
|
||||||
|
assertThat(explainedScore / scoreWithWeight, is(1f));
|
||||||
|
|
||||||
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
|
||||||
FiltersFunctionScoreQuery.ScoreMode.MAX
|
FiltersFunctionScoreQuery.ScoreMode.MAX
|
||||||
|
@ -511,7 +529,9 @@ public class FunctionScoreTests extends ESTestCase {
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for (int i = 0; i < weights.length; i++) {
|
||||||
max = Math.max(max, weights[i] * scores[i]);
|
max = Math.max(max, weights[i] * scores[i]);
|
||||||
}
|
}
|
||||||
assertThat(scoreWithWeight / max, closeTo(1, 1.e-5d));
|
assertThat(scoreWithWeight / (float) max, is(1f));
|
||||||
|
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
|
||||||
|
assertThat(explainedScore / scoreWithWeight, is(1f));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
Loading…
Reference in New Issue