use same score computation for actual scoring and explanation

FiltersFunctionScoreQuery sums up scores and weights and scores as double but when
we explain we cannot get the double scores from the explanation of score
functions. as a result we cannot compute the exact score from the explanations
of the functions alone.
this commit makes the explanation more accurate but also causes the score to be
computed one additional time.
This commit is contained in:
Britta Weber 2015-10-13 13:05:14 +02:00
parent 17ce5d5242
commit 318dfba464
2 changed files with 28 additions and 53 deletions

View File

@ -207,19 +207,11 @@ public class FiltersFunctionScoreQuery extends Query {
} }
// First: Gather explanations for all filters // First: Gather explanations for all filters
List<Explanation> filterExplanations = new ArrayList<>(); List<Explanation> filterExplanations = new ArrayList<>();
float weightSum = 0;
for (int i = 0; i < filterFunctions.length; ++i) { for (int i = 0; i < filterFunctions.length; ++i) {
FilterFunction filterFunction = filterFunctions[i];
if (filterFunction.function instanceof WeightFactorFunction) {
weightSum += ((WeightFactorFunction) filterFunction.function).getWeight();
} else {
weightSum++;
}
Bits docSet = Lucene.asSequentialAccessBits(context.reader().maxDoc(), Bits docSet = Lucene.asSequentialAccessBits(context.reader().maxDoc(),
filterWeights[i].scorer(context)); filterWeights[i].scorer(context));
if (docSet.get(doc)) { if (docSet.get(doc)) {
FilterFunction filterFunction = filterFunctions[i];
Explanation functionExplanation = filterFunction.function.getLeafScoreFunction(context).explainScore(doc, subQueryExpl); Explanation functionExplanation = filterFunction.function.getLeafScoreFunction(context).explainScore(doc, subQueryExpl);
double factor = functionExplanation.getValue(); double factor = functionExplanation.getValue();
float sc = CombineFunction.toFloat(factor); float sc = CombineFunction.toFloat(factor);
@ -232,44 +224,11 @@ public class FiltersFunctionScoreQuery extends Query {
return subQueryExpl; return subQueryExpl;
} }
// Second: Compute the factor that would have been computed by the FiltersFunctionFactorScorer scorer = (FiltersFunctionFactorScorer)scorer(context);
// filters scorer.advance(doc);
double factor = 1.0; double score = scorer.computeScore(doc, subQueryExpl.getValue());
switch (scoreMode) {
case FIRST:
factor = filterExplanations.get(0).getValue();
break;
case MAX:
factor = Double.NEGATIVE_INFINITY;
for (Explanation filterExplanation : filterExplanations) {
factor = Math.max(filterExplanation.getValue(), factor);
}
break;
case MIN:
factor = Double.POSITIVE_INFINITY;
for (Explanation filterExplanation : filterExplanations) {
factor = Math.min(filterExplanation.getValue(), factor);
}
break;
case MULTIPLY:
for (Explanation filterExplanation : filterExplanations) {
factor *= filterExplanation.getValue();
}
break;
default:
double totalFactor = 0.0f;
for (Explanation filterExplanation : filterExplanations) {
totalFactor += filterExplanation.getValue();
}
if (weightSum != 0) {
factor = totalFactor;
if (scoreMode == ScoreMode.AVG) {
factor /= weightSum;
}
}
}
Explanation factorExplanation = Explanation.match( Explanation factorExplanation = Explanation.match(
CombineFunction.toFloat(factor), CombineFunction.toFloat(score),
"function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]", "function score, score mode [" + scoreMode.toString().toLowerCase(Locale.ROOT) + "]",
filterExplanations); filterExplanations);
return combineFunction.explain(subQueryExpl, factorExplanation, maxBoost); return combineFunction.explain(subQueryExpl, factorExplanation, maxBoost);
@ -296,11 +255,16 @@ public class FiltersFunctionScoreQuery extends Query {
@Override @Override
public float innerScore() throws IOException { public float innerScore() throws IOException {
int docId = scorer.docID(); int docId = scorer.docID();
double factor = 1.0f;
// Even if the weight is created with needsScores=false, it might // Even if the weight is created with needsScores=false, it might
// be costly to call score(), so we explicitly check if scores // be costly to call score(), so we explicitly check if scores
// are needed // are needed
float subQueryScore = needsScores ? scorer.score() : 0f; float subQueryScore = needsScores ? scorer.score() : 0f;
double factor = computeScore(docId, subQueryScore);
return scoreCombiner.combine(subQueryScore, factor, maxBoost);
}
protected double computeScore(int docId, float subQueryScore) {
double factor = 1d;
switch(scoreMode) { switch(scoreMode) {
case FIRST: case FIRST:
for (int i = 0; i < filterFunctions.length; i++) { for (int i = 0; i < filterFunctions.length; i++) {
@ -360,7 +324,7 @@ public class FiltersFunctionScoreQuery extends Query {
} }
break; break;
} }
return scoreCombiner.combine(subQueryScore, factor, maxBoost); return factor;
} }
} }

View File

@ -22,7 +22,6 @@ package org.elasticsearch.index.query.functionscore;
import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField; import org.apache.lucene.document.TextField;
import org.apache.lucene.index.*; import org.apache.lucene.index.*;
import org.apache.lucene.search.*; import org.apache.lucene.search.*;
@ -47,7 +46,6 @@ import java.io.IOException;
import java.util.Collection; import java.util.Collection;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.core.Is.is; import static org.hamcrest.core.Is.is;
import static org.hamcrest.core.IsEqual.equalTo; import static org.hamcrest.core.IsEqual.equalTo;
@ -374,9 +372,12 @@ public class FunctionScoreTests extends ESTestCase {
public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException { public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException {
FiltersFunctionScoreQuery filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions); FiltersFunctionScoreQuery filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions);
return getExplanation(searcher, filtersFunctionScoreQuery).getDetails()[1];
}
protected Explanation getExplanation(IndexSearcher searcher, FiltersFunctionScoreQuery filtersFunctionScoreQuery) throws IOException {
Weight weight = searcher.createNormalizedWeight(filtersFunctionScoreQuery, true); Weight weight = searcher.createNormalizedWeight(filtersFunctionScoreQuery, true);
Explanation explanation = weight.explain(searcher.getIndexReader().leaves().get(0), 0); return weight.explain(searcher.getIndexReader().leaves().get(0), 0);
return explanation.getDetails()[1];
} }
public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) { public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) {
@ -430,7 +431,7 @@ public class FunctionScoreTests extends ESTestCase {
@Override @Override
public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException { public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException {
throw new UnsupportedOperationException(UNSUPPORTED); return Explanation.match((float) score, "a random score for testing");
} }
}; };
} }
@ -473,6 +474,8 @@ public class FunctionScoreTests extends ESTestCase {
score *= weights[i] * scores[i]; score *= weights[i] * scores[i];
} }
assertThat(scoreWithWeight / (float) score, is(1f)); assertThat(scoreWithWeight / (float) score, is(1f));
float explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.SUM FiltersFunctionScoreQuery.ScoreMode.SUM
@ -487,6 +490,8 @@ public class FunctionScoreTests extends ESTestCase {
sum += weights[i] * scores[i]; sum += weights[i] * scores[i];
} }
assertThat(scoreWithWeight / (float) sum, is(1f)); assertThat(scoreWithWeight / (float) sum, is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.AVG FiltersFunctionScoreQuery.ScoreMode.AVG
@ -503,6 +508,8 @@ public class FunctionScoreTests extends ESTestCase {
sum += weights[i] * scores[i]; sum += weights[i] * scores[i];
} }
assertThat(scoreWithWeight / (float) (sum / norm), is(1f)); assertThat(scoreWithWeight / (float) (sum / norm), is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MIN FiltersFunctionScoreQuery.ScoreMode.MIN
@ -517,6 +524,8 @@ public class FunctionScoreTests extends ESTestCase {
min = Math.min(min, weights[i] * scores[i]); min = Math.min(min, weights[i] * scores[i]);
} }
assertThat(scoreWithWeight / (float) min, is(1f)); assertThat(scoreWithWeight / (float) min, is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MAX FiltersFunctionScoreQuery.ScoreMode.MAX
@ -531,6 +540,8 @@ public class FunctionScoreTests extends ESTestCase {
max = Math.max(max, weights[i] * scores[i]); max = Math.max(max, weights[i] * scores[i]);
} }
assertThat(scoreWithWeight / (float) max, is(1f)); assertThat(scoreWithWeight / (float) max, is(1f));
explainedScore = getExplanation(searcher, filtersFunctionScoreQueryWithWeights).getValue();
assertThat(explainedScore / scoreWithWeight, is(1f));
} }
@Test @Test