LUCENE-8097: Implement maxScore() on disjunctions.

This commit is contained in:
Adrien Grand 2017-12-26 14:19:26 +01:00
parent 01023a95c8
commit e3f90385b4
7 changed files with 98 additions and 21 deletions

View File

@ -62,6 +62,9 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
*/ */
public DisjunctionMaxQuery(Collection<Query> disjuncts, float tieBreakerMultiplier) { public DisjunctionMaxQuery(Collection<Query> disjuncts, float tieBreakerMultiplier) {
Objects.requireNonNull(disjuncts, "Collection of Querys must not be null"); Objects.requireNonNull(disjuncts, "Collection of Querys must not be null");
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]");
}
this.tieBreakerMultiplier = tieBreakerMultiplier; this.tieBreakerMultiplier = tieBreakerMultiplier;
this.disjuncts = disjuncts.toArray(new Query[disjuncts.size()]); this.disjuncts = disjuncts.toArray(new Query[disjuncts.size()]);
} }
@ -156,20 +159,25 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
@Override @Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException { public Explanation explain(LeafReaderContext context, int doc) throws IOException {
boolean match = false; boolean match = false;
float max = Float.NEGATIVE_INFINITY; float max = 0;
double sum = 0; double otherSum = 0;
List<Explanation> subs = new ArrayList<>(); List<Explanation> subs = new ArrayList<>();
for (Weight wt : weights) { for (Weight wt : weights) {
Explanation e = wt.explain(context, doc); Explanation e = wt.explain(context, doc);
if (e.isMatch()) { if (e.isMatch()) {
match = true; match = true;
subs.add(e); subs.add(e);
sum += e.getValue(); float score = e.getValue();
max = Math.max(max, e.getValue()); if (score >= max) {
otherSum += max;
max = score;
} else {
otherSum += score;
}
} }
} }
if (match) { if (match) {
final float score = (float) (max + (sum - max) * tieBreakerMultiplier); final float score = (float) (max + otherSum * tieBreakerMultiplier);
final String desc = tieBreakerMultiplier == 0.0f ? "max of:" : "max plus " + tieBreakerMultiplier + " times others of:"; final String desc = tieBreakerMultiplier == 0.0f ? "max of:" : "max plus " + tieBreakerMultiplier + " times others of:";
return Explanation.match(score, desc, subs); return Explanation.match(score, desc, subs);
} else { } else {

View File

@ -19,6 +19,8 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import org.apache.lucene.util.MathUtil;
/** /**
* The Scorer for DisjunctionMaxQuery. The union of all documents generated by the subquery scorers * The Scorer for DisjunctionMaxQuery. The union of all documents generated by the subquery scorers
* is generated in document number order. The score for each document is the maximum of the scores computed * is generated in document number order. The score for each document is the maximum of the scores computed
@ -28,6 +30,7 @@ import java.util.List;
final class DisjunctionMaxScorer extends DisjunctionScorer { final class DisjunctionMaxScorer extends DisjunctionScorer {
/* Multiplier applied to non-maximum-scoring subqueries for a document as they are summed into the result. */ /* Multiplier applied to non-maximum-scoring subqueries for a document as they are summed into the result. */
private final float tieBreakerMultiplier; private final float tieBreakerMultiplier;
private final float maxScore;
/** /**
* Creates a new instance of DisjunctionMaxScorer * Creates a new instance of DisjunctionMaxScorer
@ -43,25 +46,52 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, boolean needsScores) { DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, boolean needsScores) {
super(weight, subScorers, needsScores); super(weight, subScorers, needsScores);
this.tieBreakerMultiplier = tieBreakerMultiplier; this.tieBreakerMultiplier = tieBreakerMultiplier;
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]");
}
float scoreMax = 0;
double otherScoreSum = 0;
for (Scorer scorer : subScorers) {
float subScore = scorer.maxScore();
if (subScore >= scoreMax) {
otherScoreSum += scoreMax;
scoreMax = subScore;
} else {
otherScoreSum += subScore;
}
}
if (tieBreakerMultiplier == 0) {
this.maxScore = scoreMax;
} else {
// The error of sums depends on the order in which values are summed up. In
// order to avoid this issue, we compute an upper bound of the value that
// the sum may take. If the max relative error is b, then it means that two
// sums are always within 2*b of each other.
otherScoreSum *= (1 + 2 * MathUtil.sumRelativeErrorBound(subScorers.size() - 1));
this.maxScore = (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
}
} }
@Override @Override
protected float score(DisiWrapper topList) throws IOException { protected float score(DisiWrapper topList) throws IOException {
double scoreSum = 0; float scoreMax = 0;
float scoreMax = Float.NEGATIVE_INFINITY; double otherScoreSum = 0;
for (DisiWrapper w = topList; w != null; w = w.next) { for (DisiWrapper w = topList; w != null; w = w.next) {
final float subScore = w.scorer.score(); float subScore = w.scorer.score();
scoreSum += subScore; if (subScore >= scoreMax) {
if (subScore > scoreMax) { otherScoreSum += scoreMax;
scoreMax = subScore; scoreMax = subScore;
} else {
otherScoreSum += subScore;
} }
} }
return (float) (scoreMax + (scoreSum - scoreMax) * tieBreakerMultiplier); return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
} }
@Override @Override
public float maxScore() { public float maxScore() {
// TODO: implement but be careful about floating-point errors. return maxScore;
return Float.POSITIVE_INFINITY;
} }
} }

View File

@ -20,21 +20,36 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import org.apache.lucene.util.MathUtil;
/** A Scorer for OR like queries, counterpart of <code>ConjunctionScorer</code>. /** A Scorer for OR like queries, counterpart of <code>ConjunctionScorer</code>.
*/ */
final class DisjunctionSumScorer extends DisjunctionScorer { final class DisjunctionSumScorer extends DisjunctionScorer {
private final float maxScore;
/** Construct a <code>DisjunctionScorer</code>. /** Construct a <code>DisjunctionScorer</code>.
* @param weight The weight to be used. * @param weight The weight to be used.
* @param subScorers Array of at least two subscorers. * @param subScorers Array of at least two subscorers.
*/ */
DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) { DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) {
super(weight, subScorers, needsScores); super(weight, subScorers, needsScores);
double maxScore = 0;
for (Scorer scorer : subScorers) {
maxScore += scorer.maxScore();
}
// The error of sums depends on the order in which values are summed up. In
// order to avoid this issue, we compute an upper bound of the value that
// the sum may take. If the max relative error is b, then it means that two
// sums are always within 2*b of each other.
double maxScoreRelativeErrorBound = MathUtil.sumRelativeErrorBound(subScorers.size());
this.maxScore = (float) ((1.0 + 2 * maxScoreRelativeErrorBound) * maxScore);
} }
@Override @Override
protected float score(DisiWrapper topList) throws IOException { protected float score(DisiWrapper topList) throws IOException {
double score = 0; double score = 0;
for (DisiWrapper w = topList; w != null; w = w.next) { for (DisiWrapper w = topList; w != null; w = w.next) {
score += w.scorer.score(); score += w.scorer.score();
} }
@ -43,8 +58,7 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
@Override @Override
public float maxScore() { public float maxScore() {
// TODO: implement it but be careful with floating-point errors return maxScore;
return Float.POSITIVE_INFINITY;
} }
} }

View File

@ -26,6 +26,8 @@ import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.OptionalInt; import java.util.OptionalInt;
import org.apache.lucene.util.MathUtil;
/** /**
* This implements the WAND (Weak AND) algorithm for dynamic pruning * This implements the WAND (Weak AND) algorithm for dynamic pruning
* described in "Efficient Query Evaluation using a Two-Level Retrieval * described in "Efficient Query Evaluation using a Two-Level Retrieval
@ -120,6 +122,7 @@ final class WANDScorer extends Scorer {
int tailSize; int tailSize;
final long cost; final long cost;
final float maxScore;
WANDScorer(Weight weight, Collection<Scorer> scorers) { WANDScorer(Weight weight, Collection<Scorer> scorers) {
super(weight); super(weight);
@ -142,10 +145,12 @@ final class WANDScorer extends Scorer {
// Use a scaling factor of 0 if all max scores are either 0 or +Infty // Use a scaling factor of 0 if all max scores are either 0 or +Infty
this.scalingFactor = scalingFactor.orElse(0); this.scalingFactor = scalingFactor.orElse(0);
double maxScoreSum = 0;
for (Scorer scorer : scorers) { for (Scorer scorer : scorers) {
DisiWrapper w = new DisiWrapper(scorer); DisiWrapper w = new DisiWrapper(scorer);
float maxScore = scorer.maxScore(); float maxScore = scorer.maxScore();
w.maxScore = scaleMaxScore(maxScore, this.scalingFactor); w.maxScore = scaleMaxScore(maxScore, this.scalingFactor);
maxScoreSum += maxScore;
addLead(w); addLead(w);
} }
@ -154,6 +159,12 @@ final class WANDScorer extends Scorer {
cost += w.cost; cost += w.cost;
} }
this.cost = cost; this.cost = cost;
// The error of sums depends on the order in which values are summed up. In
// order to avoid this issue, we compute an upper bound of the value that
// the sum may take. If the max relative error is b, then it means that two
// sums are always within 2*b of each other.
double maxScoreRelativeErrorBound = MathUtil.sumRelativeErrorBound(scorers.size());
this.maxScore = (float) ((1.0 + 2 * maxScoreRelativeErrorBound) * maxScoreSum);
} }
// returns a boolean so that it can be called from assert // returns a boolean so that it can be called from assert
@ -375,8 +386,7 @@ final class WANDScorer extends Scorer {
@Override @Override
public float maxScore() { public float maxScore() {
// TODO: implement but be careful about floating-point errors. return maxScore;
return Float.POSITIVE_INFINITY;
} }
@Override @Override

View File

@ -149,5 +149,20 @@ public final class MathUtil {
return mult * Math.log((1.0d + a) / (1.0d - a)); return mult * Math.log((1.0d + a) / (1.0d - a));
} }
/**
 * Returns a bound on the relative error of adding up {@code numValues} positive doubles
 * one after the other (recursive summation, i.e. sum = x1 + ... + xn).
 * <p>
 * NOTE: the bound is only meaningful when every value is POSITIVE, because it relies on
 * Σ |xi| == |Σ xi|.
 * <p>
 * Based on formula 3.5 from Higham, Nicholas J. (1993),
 * "The accuracy of floating point summation", SIAM Journal on Scientific Computing.
 *
 * @param numValues the number of values being summed
 * @return {@code 0} when {@code numValues <= 1}, otherwise {@code (numValues - 1) * 2^-52}
 */
public static double sumRelativeErrorBound(int numValues) {
  // "u" in the paper (unit roundoff, a.k.a. machine precision / machine epsilon); here 2^-52.
  final double unitRoundoff = 0x1.0p-52;
  // With at most one value there is no addition to accumulate error from.
  return numValues <= 1 ? 0 : (numValues - 1) * unitRoundoff;
}
} }

View File

@ -18,7 +18,7 @@
<DisjunctionMaxQuery> <DisjunctionMaxQuery>
<TermQuery fieldName="a">merger</TermQuery> <TermQuery fieldName="a">merger</TermQuery>
<DisjunctionMaxQuery tieBreaker="1.2"> <DisjunctionMaxQuery tieBreaker="0.3">
<TermQuery fieldName="b">verger</TermQuery> <TermQuery fieldName="b">verger</TermQuery>
</DisjunctionMaxQuery> </DisjunctionMaxQuery>
</DisjunctionMaxQuery> </DisjunctionMaxQuery>

View File

@ -102,7 +102,7 @@ public class TestCoreParser extends LuceneTestCase {
assertEquals(0.0f, d.getTieBreakerMultiplier(), 0.0001f); assertEquals(0.0f, d.getTieBreakerMultiplier(), 0.0001f);
assertEquals(2, d.getDisjuncts().size()); assertEquals(2, d.getDisjuncts().size());
DisjunctionMaxQuery ndq = (DisjunctionMaxQuery) d.getDisjuncts().get(1); DisjunctionMaxQuery ndq = (DisjunctionMaxQuery) d.getDisjuncts().get(1);
assertEquals(1.2f, ndq.getTieBreakerMultiplier(), 0.0001f); assertEquals(0.3f, ndq.getTieBreakerMultiplier(), 0.0001f);
assertEquals(1, ndq.getDisjuncts().size()); assertEquals(1, ndq.getDisjuncts().size());
} }