mirror of https://github.com/apache/lucene.git
LUCENE-8097: Implement maxScore() on disjunctions.
This commit is contained in:
parent
01023a95c8
commit
e3f90385b4
|
@ -62,6 +62,9 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
|
||||||
*/
|
*/
|
||||||
public DisjunctionMaxQuery(Collection<Query> disjuncts, float tieBreakerMultiplier) {
|
public DisjunctionMaxQuery(Collection<Query> disjuncts, float tieBreakerMultiplier) {
|
||||||
Objects.requireNonNull(disjuncts, "Collection of Querys must not be null");
|
Objects.requireNonNull(disjuncts, "Collection of Querys must not be null");
|
||||||
|
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
|
||||||
|
throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]");
|
||||||
|
}
|
||||||
this.tieBreakerMultiplier = tieBreakerMultiplier;
|
this.tieBreakerMultiplier = tieBreakerMultiplier;
|
||||||
this.disjuncts = disjuncts.toArray(new Query[disjuncts.size()]);
|
this.disjuncts = disjuncts.toArray(new Query[disjuncts.size()]);
|
||||||
}
|
}
|
||||||
|
@ -156,20 +159,25 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
|
||||||
@Override
|
@Override
|
||||||
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
||||||
boolean match = false;
|
boolean match = false;
|
||||||
float max = Float.NEGATIVE_INFINITY;
|
float max = 0;
|
||||||
double sum = 0;
|
double otherSum = 0;
|
||||||
List<Explanation> subs = new ArrayList<>();
|
List<Explanation> subs = new ArrayList<>();
|
||||||
for (Weight wt : weights) {
|
for (Weight wt : weights) {
|
||||||
Explanation e = wt.explain(context, doc);
|
Explanation e = wt.explain(context, doc);
|
||||||
if (e.isMatch()) {
|
if (e.isMatch()) {
|
||||||
match = true;
|
match = true;
|
||||||
subs.add(e);
|
subs.add(e);
|
||||||
sum += e.getValue();
|
float score = e.getValue();
|
||||||
max = Math.max(max, e.getValue());
|
if (score >= max) {
|
||||||
|
otherSum += max;
|
||||||
|
max = score;
|
||||||
|
} else {
|
||||||
|
otherSum += score;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (match) {
|
if (match) {
|
||||||
final float score = (float) (max + (sum - max) * tieBreakerMultiplier);
|
final float score = (float) (max + otherSum * tieBreakerMultiplier);
|
||||||
final String desc = tieBreakerMultiplier == 0.0f ? "max of:" : "max plus " + tieBreakerMultiplier + " times others of:";
|
final String desc = tieBreakerMultiplier == 0.0f ? "max of:" : "max plus " + tieBreakerMultiplier + " times others of:";
|
||||||
return Explanation.match(score, desc, subs);
|
return Explanation.match(score, desc, subs);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -19,6 +19,8 @@ package org.apache.lucene.search;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.MathUtil;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The Scorer for DisjunctionMaxQuery. The union of all documents generated by the subquery scorers
|
* The Scorer for DisjunctionMaxQuery. The union of all documents generated by the subquery scorers
|
||||||
* is generated in document number order. The score for each document is the maximum of the scores computed
|
* is generated in document number order. The score for each document is the maximum of the scores computed
|
||||||
|
@ -28,6 +30,7 @@ import java.util.List;
|
||||||
final class DisjunctionMaxScorer extends DisjunctionScorer {
|
final class DisjunctionMaxScorer extends DisjunctionScorer {
|
||||||
/* Multiplier applied to non-maximum-scoring subqueries for a document as they are summed into the result. */
|
/* Multiplier applied to non-maximum-scoring subqueries for a document as they are summed into the result. */
|
||||||
private final float tieBreakerMultiplier;
|
private final float tieBreakerMultiplier;
|
||||||
|
private final float maxScore;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new instance of DisjunctionMaxScorer
|
* Creates a new instance of DisjunctionMaxScorer
|
||||||
|
@ -43,25 +46,52 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
|
||||||
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, boolean needsScores) {
|
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, boolean needsScores) {
|
||||||
super(weight, subScorers, needsScores);
|
super(weight, subScorers, needsScores);
|
||||||
this.tieBreakerMultiplier = tieBreakerMultiplier;
|
this.tieBreakerMultiplier = tieBreakerMultiplier;
|
||||||
|
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
|
||||||
|
throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]");
|
||||||
|
}
|
||||||
|
|
||||||
|
float scoreMax = 0;
|
||||||
|
double otherScoreSum = 0;
|
||||||
|
for (Scorer scorer : subScorers) {
|
||||||
|
float subScore = scorer.maxScore();
|
||||||
|
if (subScore >= scoreMax) {
|
||||||
|
otherScoreSum += scoreMax;
|
||||||
|
scoreMax = subScore;
|
||||||
|
} else {
|
||||||
|
otherScoreSum += subScore;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tieBreakerMultiplier == 0) {
|
||||||
|
this.maxScore = scoreMax;
|
||||||
|
} else {
|
||||||
|
// The error of sums depends on the order in which values are summed up. In
|
||||||
|
// order to avoid this issue, we compute an upper bound of the value that
|
||||||
|
// the sum may take. If the max relative error is b, then it means that two
|
||||||
|
// sums are always within 2*b of each other.
|
||||||
|
otherScoreSum *= (1 + 2 * MathUtil.sumRelativeErrorBound(subScorers.size() - 1));
|
||||||
|
this.maxScore = (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected float score(DisiWrapper topList) throws IOException {
|
protected float score(DisiWrapper topList) throws IOException {
|
||||||
double scoreSum = 0;
|
float scoreMax = 0;
|
||||||
float scoreMax = Float.NEGATIVE_INFINITY;
|
double otherScoreSum = 0;
|
||||||
for (DisiWrapper w = topList; w != null; w = w.next) {
|
for (DisiWrapper w = topList; w != null; w = w.next) {
|
||||||
final float subScore = w.scorer.score();
|
float subScore = w.scorer.score();
|
||||||
scoreSum += subScore;
|
if (subScore >= scoreMax) {
|
||||||
if (subScore > scoreMax) {
|
otherScoreSum += scoreMax;
|
||||||
scoreMax = subScore;
|
scoreMax = subScore;
|
||||||
|
} else {
|
||||||
|
otherScoreSum += subScore;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return (float) (scoreMax + (scoreSum - scoreMax) * tieBreakerMultiplier);
|
return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float maxScore() {
|
public float maxScore() {
|
||||||
// TODO: implement but be careful about floating-point errors.
|
return maxScore;
|
||||||
return Float.POSITIVE_INFINITY;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,21 +20,36 @@ package org.apache.lucene.search;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.MathUtil;
|
||||||
|
|
||||||
/** A Scorer for OR like queries, counterpart of <code>ConjunctionScorer</code>.
|
/** A Scorer for OR like queries, counterpart of <code>ConjunctionScorer</code>.
|
||||||
*/
|
*/
|
||||||
final class DisjunctionSumScorer extends DisjunctionScorer {
|
final class DisjunctionSumScorer extends DisjunctionScorer {
|
||||||
|
|
||||||
|
private final float maxScore;
|
||||||
|
|
||||||
/** Construct a <code>DisjunctionSumScorer</code>.
|
/** Construct a <code>DisjunctionSumScorer</code>.
|
||||||
* @param weight The weight to be used.
|
* @param weight The weight to be used.
|
||||||
* @param subScorers Array of at least two subscorers.
|
* @param subScorers Array of at least two subscorers.
|
||||||
*/
|
*/
|
||||||
DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) {
|
DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) {
|
||||||
super(weight, subScorers, needsScores);
|
super(weight, subScorers, needsScores);
|
||||||
|
double maxScore = 0;
|
||||||
|
for (Scorer scorer : subScorers) {
|
||||||
|
maxScore += scorer.maxScore();
|
||||||
|
}
|
||||||
|
// The error of sums depends on the order in which values are summed up. In
|
||||||
|
// order to avoid this issue, we compute an upper bound of the value that
|
||||||
|
// the sum may take. If the max relative error is b, then it means that two
|
||||||
|
// sums are always within 2*b of each other.
|
||||||
|
double maxScoreRelativeErrorBound = MathUtil.sumRelativeErrorBound(subScorers.size());
|
||||||
|
this.maxScore = (float) ((1.0 + 2 * maxScoreRelativeErrorBound) * maxScore);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected float score(DisiWrapper topList) throws IOException {
|
protected float score(DisiWrapper topList) throws IOException {
|
||||||
double score = 0;
|
double score = 0;
|
||||||
|
|
||||||
for (DisiWrapper w = topList; w != null; w = w.next) {
|
for (DisiWrapper w = topList; w != null; w = w.next) {
|
||||||
score += w.scorer.score();
|
score += w.scorer.score();
|
||||||
}
|
}
|
||||||
|
@ -43,8 +58,7 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float maxScore() {
|
public float maxScore() {
|
||||||
// TODO: implement it but be careful with floating-point errors
|
return maxScore;
|
||||||
return Float.POSITIVE_INFINITY;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,8 @@ import java.util.Collection;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.OptionalInt;
|
import java.util.OptionalInt;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.MathUtil;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This implements the WAND (Weak AND) algorithm for dynamic pruning
|
* This implements the WAND (Weak AND) algorithm for dynamic pruning
|
||||||
* described in "Efficient Query Evaluation using a Two-Level Retrieval
|
* described in "Efficient Query Evaluation using a Two-Level Retrieval
|
||||||
|
@ -120,6 +122,7 @@ final class WANDScorer extends Scorer {
|
||||||
int tailSize;
|
int tailSize;
|
||||||
|
|
||||||
final long cost;
|
final long cost;
|
||||||
|
final float maxScore;
|
||||||
|
|
||||||
WANDScorer(Weight weight, Collection<Scorer> scorers) {
|
WANDScorer(Weight weight, Collection<Scorer> scorers) {
|
||||||
super(weight);
|
super(weight);
|
||||||
|
@ -142,10 +145,12 @@ final class WANDScorer extends Scorer {
|
||||||
// Use a scaling factor of 0 if all max scores are either 0 or +Infty
|
// Use a scaling factor of 0 if all max scores are either 0 or +Infty
|
||||||
this.scalingFactor = scalingFactor.orElse(0);
|
this.scalingFactor = scalingFactor.orElse(0);
|
||||||
|
|
||||||
|
double maxScoreSum = 0;
|
||||||
for (Scorer scorer : scorers) {
|
for (Scorer scorer : scorers) {
|
||||||
DisiWrapper w = new DisiWrapper(scorer);
|
DisiWrapper w = new DisiWrapper(scorer);
|
||||||
float maxScore = scorer.maxScore();
|
float maxScore = scorer.maxScore();
|
||||||
w.maxScore = scaleMaxScore(maxScore, this.scalingFactor);
|
w.maxScore = scaleMaxScore(maxScore, this.scalingFactor);
|
||||||
|
maxScoreSum += maxScore;
|
||||||
addLead(w);
|
addLead(w);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,6 +159,12 @@ final class WANDScorer extends Scorer {
|
||||||
cost += w.cost;
|
cost += w.cost;
|
||||||
}
|
}
|
||||||
this.cost = cost;
|
this.cost = cost;
|
||||||
|
// The error of sums depends on the order in which values are summed up. In
|
||||||
|
// order to avoid this issue, we compute an upper bound of the value that
|
||||||
|
// the sum may take. If the max relative error is b, then it means that two
|
||||||
|
// sums are always within 2*b of each other.
|
||||||
|
double maxScoreRelativeErrorBound = MathUtil.sumRelativeErrorBound(scorers.size());
|
||||||
|
this.maxScore = (float) ((1.0 + 2 * maxScoreRelativeErrorBound) * maxScoreSum);
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns a boolean so that it can be called from assert
|
// returns a boolean so that it can be called from assert
|
||||||
|
@ -375,8 +386,7 @@ final class WANDScorer extends Scorer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float maxScore() {
|
public float maxScore() {
|
||||||
// TODO: implement but be careful about floating-point errors.
|
return maxScore;
|
||||||
return Float.POSITIVE_INFINITY;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -149,5 +149,20 @@ public final class MathUtil {
|
||||||
return mult * Math.log((1.0d + a) / (1.0d - a));
|
return mult * Math.log((1.0d + a) / (1.0d - a));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return a relative error bound for a sum of {@code numValues} positive doubles,
|
||||||
|
* computed using recursive summation, i.e. sum = x1 + ... + xn.
|
||||||
|
* NOTE: This only works if all values are POSITIVE so that Σ |xi| == |Σ xi|.
|
||||||
|
* This uses formula 3.5 from Higham, Nicholas J. (1993),
|
||||||
|
* "The accuracy of floating point summation", SIAM Journal on Scientific Computing.
|
||||||
|
*/
|
||||||
|
public static double sumRelativeErrorBound(int numValues) {
|
||||||
|
if (numValues <= 1) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
// u = unit roundoff in the paper; we use 2^-52 (machine epsilon for doubles), which is twice the round-to-nearest unit roundoff of 2^-53, so the bound is conservative
|
||||||
|
double u = Math.scalb(1.0, -52);
|
||||||
|
return (numValues - 1) * u;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
|
|
||||||
<DisjunctionMaxQuery>
|
<DisjunctionMaxQuery>
|
||||||
<TermQuery fieldName="a">merger</TermQuery>
|
<TermQuery fieldName="a">merger</TermQuery>
|
||||||
<DisjunctionMaxQuery tieBreaker="1.2">
|
<DisjunctionMaxQuery tieBreaker="0.3">
|
||||||
<TermQuery fieldName="b">verger</TermQuery>
|
<TermQuery fieldName="b">verger</TermQuery>
|
||||||
</DisjunctionMaxQuery>
|
</DisjunctionMaxQuery>
|
||||||
</DisjunctionMaxQuery>
|
</DisjunctionMaxQuery>
|
||||||
|
|
|
@ -102,7 +102,7 @@ public class TestCoreParser extends LuceneTestCase {
|
||||||
assertEquals(0.0f, d.getTieBreakerMultiplier(), 0.0001f);
|
assertEquals(0.0f, d.getTieBreakerMultiplier(), 0.0001f);
|
||||||
assertEquals(2, d.getDisjuncts().size());
|
assertEquals(2, d.getDisjuncts().size());
|
||||||
DisjunctionMaxQuery ndq = (DisjunctionMaxQuery) d.getDisjuncts().get(1);
|
DisjunctionMaxQuery ndq = (DisjunctionMaxQuery) d.getDisjuncts().get(1);
|
||||||
assertEquals(1.2f, ndq.getTieBreakerMultiplier(), 0.0001f);
|
assertEquals(0.3f, ndq.getTieBreakerMultiplier(), 0.0001f);
|
||||||
assertEquals(1, ndq.getDisjuncts().size());
|
assertEquals(1, ndq.getDisjuncts().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue