mirror of https://github.com/apache/lucene.git
LUCENE-8097: Implement maxScore() on disjunctions.
This commit is contained in:
parent
01023a95c8
commit
e3f90385b4
|
@ -62,6 +62,9 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
|
|||
*/
|
||||
public DisjunctionMaxQuery(Collection<Query> disjuncts, float tieBreakerMultiplier) {
|
||||
Objects.requireNonNull(disjuncts, "Collection of Querys must not be null");
|
||||
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
|
||||
throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]");
|
||||
}
|
||||
this.tieBreakerMultiplier = tieBreakerMultiplier;
|
||||
this.disjuncts = disjuncts.toArray(new Query[disjuncts.size()]);
|
||||
}
|
||||
|
@ -156,20 +159,25 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
|
|||
@Override
|
||||
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
||||
boolean match = false;
|
||||
float max = Float.NEGATIVE_INFINITY;
|
||||
double sum = 0;
|
||||
float max = 0;
|
||||
double otherSum = 0;
|
||||
List<Explanation> subs = new ArrayList<>();
|
||||
for (Weight wt : weights) {
|
||||
Explanation e = wt.explain(context, doc);
|
||||
if (e.isMatch()) {
|
||||
match = true;
|
||||
subs.add(e);
|
||||
sum += e.getValue();
|
||||
max = Math.max(max, e.getValue());
|
||||
float score = e.getValue();
|
||||
if (score >= max) {
|
||||
otherSum += max;
|
||||
max = score;
|
||||
} else {
|
||||
otherSum += score;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
final float score = (float) (max + (sum - max) * tieBreakerMultiplier);
|
||||
final float score = (float) (max + otherSum * tieBreakerMultiplier);
|
||||
final String desc = tieBreakerMultiplier == 0.0f ? "max of:" : "max plus " + tieBreakerMultiplier + " times others of:";
|
||||
return Explanation.match(score, desc, subs);
|
||||
} else {
|
||||
|
|
|
@ -19,6 +19,8 @@ package org.apache.lucene.search;
|
|||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.util.MathUtil;
|
||||
|
||||
/**
|
||||
* The Scorer for DisjunctionMaxQuery. The union of all documents generated by the the subquery scorers
|
||||
* is generated in document number order. The score for each document is the maximum of the scores computed
|
||||
|
@ -28,6 +30,7 @@ import java.util.List;
|
|||
final class DisjunctionMaxScorer extends DisjunctionScorer {
|
||||
/* Multiplier applied to non-maximum-scoring subqueries for a document as they are summed into the result. */
|
||||
private final float tieBreakerMultiplier;
|
||||
private final float maxScore;
|
||||
|
||||
/**
|
||||
* Creates a new instance of DisjunctionMaxScorer
|
||||
|
@ -43,25 +46,52 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
|
|||
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, boolean needsScores) {
|
||||
super(weight, subScorers, needsScores);
|
||||
this.tieBreakerMultiplier = tieBreakerMultiplier;
|
||||
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
|
||||
throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]");
|
||||
}
|
||||
|
||||
float scoreMax = 0;
|
||||
double otherScoreSum = 0;
|
||||
for (Scorer scorer : subScorers) {
|
||||
float subScore = scorer.maxScore();
|
||||
if (subScore >= scoreMax) {
|
||||
otherScoreSum += scoreMax;
|
||||
scoreMax = subScore;
|
||||
} else {
|
||||
otherScoreSum += subScore;
|
||||
}
|
||||
}
|
||||
|
||||
if (tieBreakerMultiplier == 0) {
|
||||
this.maxScore = scoreMax;
|
||||
} else {
|
||||
// The error of sums depends on the order in which values are summed up. In
|
||||
// order to avoid this issue, we compute an upper bound of the value that
|
||||
// the sum may take. If the max relative error is b, then it means that two
|
||||
// sums are always within 2*b of each other.
|
||||
otherScoreSum *= (1 + 2 * MathUtil.sumRelativeErrorBound(subScorers.size() - 1));
|
||||
this.maxScore = (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float score(DisiWrapper topList) throws IOException {
|
||||
double scoreSum = 0;
|
||||
float scoreMax = Float.NEGATIVE_INFINITY;
|
||||
float scoreMax = 0;
|
||||
double otherScoreSum = 0;
|
||||
for (DisiWrapper w = topList; w != null; w = w.next) {
|
||||
final float subScore = w.scorer.score();
|
||||
scoreSum += subScore;
|
||||
if (subScore > scoreMax) {
|
||||
float subScore = w.scorer.score();
|
||||
if (subScore >= scoreMax) {
|
||||
otherScoreSum += scoreMax;
|
||||
scoreMax = subScore;
|
||||
} else {
|
||||
otherScoreSum += subScore;
|
||||
}
|
||||
}
|
||||
return (float) (scoreMax + (scoreSum - scoreMax) * tieBreakerMultiplier);
|
||||
return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxScore() {
|
||||
// TODO: implement but be careful about floating-point errors.
|
||||
return Float.POSITIVE_INFINITY;
|
||||
return maxScore;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,21 +20,36 @@ package org.apache.lucene.search;
|
|||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.util.MathUtil;
|
||||
|
||||
/** A Scorer for OR like queries, counterpart of <code>ConjunctionScorer</code>.
|
||||
*/
|
||||
final class DisjunctionSumScorer extends DisjunctionScorer {
|
||||
|
||||
|
||||
private final float maxScore;
|
||||
|
||||
/** Construct a <code>DisjunctionScorer</code>.
|
||||
* @param weight The weight to be used.
|
||||
* @param subScorers Array of at least two subscorers.
|
||||
*/
|
||||
DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) {
|
||||
super(weight, subScorers, needsScores);
|
||||
double maxScore = 0;
|
||||
for (Scorer scorer : subScorers) {
|
||||
maxScore += scorer.maxScore();
|
||||
}
|
||||
// The error of sums depends on the order in which values are summed up. In
|
||||
// order to avoid this issue, we compute an upper bound of the value that
|
||||
// the sum may take. If the max relative error is b, then it means that two
|
||||
// sums are always within 2*b of each other.
|
||||
double maxScoreRelativeErrorBound = MathUtil.sumRelativeErrorBound(subScorers.size());
|
||||
this.maxScore = (float) ((1.0 + 2 * maxScoreRelativeErrorBound) * maxScore);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float score(DisiWrapper topList) throws IOException {
|
||||
double score = 0;
|
||||
|
||||
for (DisiWrapper w = topList; w != null; w = w.next) {
|
||||
score += w.scorer.score();
|
||||
}
|
||||
|
@ -43,8 +58,7 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
|
|||
|
||||
@Override
|
||||
public float maxScore() {
|
||||
// TODO: implement it but be careful with floating-point errors
|
||||
return Float.POSITIVE_INFINITY;
|
||||
return maxScore;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -26,6 +26,8 @@ import java.util.Collection;
|
|||
import java.util.List;
|
||||
import java.util.OptionalInt;
|
||||
|
||||
import org.apache.lucene.util.MathUtil;
|
||||
|
||||
/**
|
||||
* This implements the WAND (Weak AND) algorithm for dynamic pruning
|
||||
* described in "Efficient Query Evaluation using a Two-Level Retrieval
|
||||
|
@ -120,6 +122,7 @@ final class WANDScorer extends Scorer {
|
|||
int tailSize;
|
||||
|
||||
final long cost;
|
||||
final float maxScore;
|
||||
|
||||
WANDScorer(Weight weight, Collection<Scorer> scorers) {
|
||||
super(weight);
|
||||
|
@ -142,10 +145,12 @@ final class WANDScorer extends Scorer {
|
|||
// Use a scaling factor of 0 if all max scores are either 0 or +Infty
|
||||
this.scalingFactor = scalingFactor.orElse(0);
|
||||
|
||||
double maxScoreSum = 0;
|
||||
for (Scorer scorer : scorers) {
|
||||
DisiWrapper w = new DisiWrapper(scorer);
|
||||
float maxScore = scorer.maxScore();
|
||||
w.maxScore = scaleMaxScore(maxScore, this.scalingFactor);
|
||||
maxScoreSum += maxScore;
|
||||
addLead(w);
|
||||
}
|
||||
|
||||
|
@ -154,6 +159,12 @@ final class WANDScorer extends Scorer {
|
|||
cost += w.cost;
|
||||
}
|
||||
this.cost = cost;
|
||||
// The error of sums depends on the order in which values are summed up. In
|
||||
// order to avoid this issue, we compute an upper bound of the value that
|
||||
// the sum may take. If the max relative error is b, then it means that two
|
||||
// sums are always within 2*b of each other.
|
||||
double maxScoreRelativeErrorBound = MathUtil.sumRelativeErrorBound(scorers.size());
|
||||
this.maxScore = (float) ((1.0 + 2 * maxScoreRelativeErrorBound) * maxScoreSum);
|
||||
}
|
||||
|
||||
// returns a boolean so that it can be called from assert
|
||||
|
@ -375,8 +386,7 @@ final class WANDScorer extends Scorer {
|
|||
|
||||
@Override
|
||||
public float maxScore() {
|
||||
// TODO: implement but be careful about floating-point errors.
|
||||
return Float.POSITIVE_INFINITY;
|
||||
return maxScore;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -149,5 +149,20 @@ public final class MathUtil {
|
|||
return mult * Math.log((1.0d + a) / (1.0d - a));
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a relative error bound for a sum of {@code numValues} positive doubles,
|
||||
* computed using recursive summation, ie. sum = x1 + ... + xn.
|
||||
* NOTE: This only works if all values are POSITIVE so that Σ |xi| == |Σ xi|.
|
||||
* This uses formula 3.5 from Higham, Nicholas J. (1993),
|
||||
* "The accuracy of floating point summation", SIAM Journal on Scientific Computing.
|
||||
*/
|
||||
public static double sumRelativeErrorBound(int numValues) {
|
||||
if (numValues <= 1) {
|
||||
return 0;
|
||||
}
|
||||
// u = unit roundoff in the paper, also called machine precision or machine epsilon
|
||||
double u = Math.scalb(1.0, -52);
|
||||
return (numValues - 1) * u;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
<DisjunctionMaxQuery>
|
||||
<TermQuery fieldName="a">merger</TermQuery>
|
||||
<DisjunctionMaxQuery tieBreaker="1.2">
|
||||
<DisjunctionMaxQuery tieBreaker="0.3">
|
||||
<TermQuery fieldName="b">verger</TermQuery>
|
||||
</DisjunctionMaxQuery>
|
||||
</DisjunctionMaxQuery>
|
||||
</DisjunctionMaxQuery>
|
||||
|
|
|
@ -102,7 +102,7 @@ public class TestCoreParser extends LuceneTestCase {
|
|||
assertEquals(0.0f, d.getTieBreakerMultiplier(), 0.0001f);
|
||||
assertEquals(2, d.getDisjuncts().size());
|
||||
DisjunctionMaxQuery ndq = (DisjunctionMaxQuery) d.getDisjuncts().get(1);
|
||||
assertEquals(1.2f, ndq.getTieBreakerMultiplier(), 0.0001f);
|
||||
assertEquals(0.3f, ndq.getTieBreakerMultiplier(), 0.0001f);
|
||||
assertEquals(1, ndq.getDisjuncts().size());
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue