From e3f90385b4928e3639ee09a907df323e452c74de Mon Sep 17 00:00:00 2001 From: Adrien Grand Date: Tue, 26 Dec 2017 14:19:26 +0100 Subject: [PATCH] LUCENE-8097: Implement maxScore() on disjunctions. --- .../lucene/search/DisjunctionMaxQuery.java | 18 ++++++-- .../lucene/search/DisjunctionMaxScorer.java | 46 +++++++++++++++---- .../lucene/search/DisjunctionSumScorer.java | 20 ++++++-- .../org/apache/lucene/search/WANDScorer.java | 14 +++++- .../java/org/apache/lucene/util/MathUtil.java | 15 ++++++ .../queryparser/xml/DisjunctionMaxQuery.xml | 4 +- .../queryparser/xml/TestCoreParser.java | 2 +- 7 files changed, 98 insertions(+), 21 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java index 97c02a64c2e..3285bafccd1 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java @@ -62,6 +62,9 @@ public final class DisjunctionMaxQuery extends Query implements Iterable */ public DisjunctionMaxQuery(Collection disjuncts, float tieBreakerMultiplier) { Objects.requireNonNull(disjuncts, "Collection of Querys must not be null"); + if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) { + throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]"); + } this.tieBreakerMultiplier = tieBreakerMultiplier; this.disjuncts = disjuncts.toArray(new Query[disjuncts.size()]); } @@ -156,20 +159,25 @@ public final class DisjunctionMaxQuery extends Query implements Iterable @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { boolean match = false; - float max = Float.NEGATIVE_INFINITY; - double sum = 0; + float max = 0; + double otherSum = 0; List subs = new ArrayList<>(); for (Weight wt : weights) { Explanation e = wt.explain(context, doc); if (e.isMatch()) { match = true; subs.add(e); - sum += e.getValue(); - max = Math.max(max, e.getValue()); + float score = e.getValue(); + if (score >= max) { + otherSum += max; + max = score; + } else { + otherSum += score; + } } } if (match) { - final float score = (float) (max + (sum - max) * tieBreakerMultiplier); + final float score = (float) (max + otherSum * tieBreakerMultiplier); final String desc = tieBreakerMultiplier == 0.0f ? "max of:" : "max plus " + tieBreakerMultiplier + " times others of:"; return Explanation.match(score, desc, subs); } else { diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxScorer.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxScorer.java index 084de668f81..c5c3640e147 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxScorer.java @@ -19,6 +19,8 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.List; +import org.apache.lucene.util.MathUtil; + /** * The Scorer for DisjunctionMaxQuery. The union of all documents generated by the the subquery scorers * is generated in document number order. The score for each document is the maximum of the scores computed @@ -28,6 +30,7 @@ import java.util.List; final class DisjunctionMaxScorer extends DisjunctionScorer { /* Multiplier applied to non-maximum-scoring subqueries for a document as they are summed into the result. */ private final float tieBreakerMultiplier; + private final float maxScore; /** * Creates a new instance of DisjunctionMaxScorer @@ -43,25 +46,52 @@ final class DisjunctionMaxScorer extends DisjunctionScorer { DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List subScorers, boolean needsScores) { super(weight, subScorers, needsScores); this.tieBreakerMultiplier = tieBreakerMultiplier; + if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) { + throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]"); + } + + float scoreMax = 0; + double otherScoreSum = 0; + for (Scorer scorer : subScorers) { + float subScore = scorer.maxScore(); + if (subScore >= scoreMax) { + otherScoreSum += scoreMax; + scoreMax = subScore; + } else { + otherScoreSum += subScore; + } + } + + if (tieBreakerMultiplier == 0) { + this.maxScore = scoreMax; + } else { + // The error of sums depends on the order in which values are summed up. In + // order to avoid this issue, we compute an upper bound of the value that + // the sum may take. If the max relative error is b, then it means that two + // sums are always within 2*b of each other. + otherScoreSum *= (1 + 2 * MathUtil.sumRelativeErrorBound(subScorers.size() - 1)); + this.maxScore = (float) (scoreMax + otherScoreSum * tieBreakerMultiplier); + } } @Override protected float score(DisiWrapper topList) throws IOException { - double scoreSum = 0; - float scoreMax = Float.NEGATIVE_INFINITY; + float scoreMax = 0; + double otherScoreSum = 0; for (DisiWrapper w = topList; w != null; w = w.next) { - final float subScore = w.scorer.score(); - scoreSum += subScore; - if (subScore > scoreMax) { + float subScore = w.scorer.score(); + if (subScore >= scoreMax) { + otherScoreSum += scoreMax; scoreMax = subScore; + } else { + otherScoreSum += subScore; } } - return (float) (scoreMax + (scoreSum - scoreMax) * tieBreakerMultiplier); + return (float) (scoreMax + otherScoreSum * tieBreakerMultiplier); } @Override public float maxScore() { - // TODO: implement but be careful about floating-point errors. - return Float.POSITIVE_INFINITY; + return maxScore; } } diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionSumScorer.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionSumScorer.java index 729a2986812..7e22991dacb 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionSumScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionSumScorer.java @@ -20,21 +20,36 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.List; +import org.apache.lucene.util.MathUtil; + /** A Scorer for OR like queries, counterpart of ConjunctionScorer. */ final class DisjunctionSumScorer extends DisjunctionScorer { - + + private final float maxScore; + /** Construct a DisjunctionScorer. * @param weight The weight to be used. * @param subScorers Array of at least two subscorers. */ DisjunctionSumScorer(Weight weight, List subScorers, boolean needsScores) { super(weight, subScorers, needsScores); + double maxScore = 0; + for (Scorer scorer : subScorers) { + maxScore += scorer.maxScore(); + } + // The error of sums depends on the order in which values are summed up. In + // order to avoid this issue, we compute an upper bound of the value that + // the sum may take. If the max relative error is b, then it means that two + // sums are always within 2*b of each other. + double maxScoreRelativeErrorBound = MathUtil.sumRelativeErrorBound(subScorers.size()); + this.maxScore = (float) ((1.0 + 2 * maxScoreRelativeErrorBound) * maxScore); } @Override protected float score(DisiWrapper topList) throws IOException { double score = 0; + for (DisiWrapper w = topList; w != null; w = w.next) { score += w.scorer.score(); } @@ -43,8 +58,7 @@ final class DisjunctionSumScorer extends DisjunctionScorer { @Override public float maxScore() { - // TODO: implement it but be careful with floating-point errors - return Float.POSITIVE_INFINITY; + return maxScore; } } diff --git a/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java b/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java index 2f3b600081c..f5f647e3617 100644 --- a/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java @@ -26,6 +26,8 @@ import java.util.Collection; import java.util.List; import java.util.OptionalInt; +import org.apache.lucene.util.MathUtil; + /** * This implements the WAND (Weak AND) algorithm for dynamic pruning * described in "Efficient Query Evaluation using a Two-Level Retrieval @@ -120,6 +122,7 @@ final class WANDScorer extends Scorer { int tailSize; final long cost; + final float maxScore; WANDScorer(Weight weight, Collection scorers) { super(weight); @@ -142,10 +145,12 @@ final class WANDScorer extends Scorer { // Use a scaling factor of 0 if all max scores are either 0 or +Infty this.scalingFactor = scalingFactor.orElse(0); + double maxScoreSum = 0; for (Scorer scorer : scorers) { DisiWrapper w = new DisiWrapper(scorer); float maxScore = scorer.maxScore(); w.maxScore = scaleMaxScore(maxScore, this.scalingFactor); + maxScoreSum += maxScore; addLead(w); } @@ -154,6 +159,12 @@ final class WANDScorer extends Scorer { cost += w.cost; } this.cost = cost; + // The error of sums depends on the order in which values are summed up. In + // order to avoid this issue, we compute an upper bound of the value that + // the sum may take. If the max relative error is b, then it means that two + // sums are always within 2*b of each other. + double maxScoreRelativeErrorBound = MathUtil.sumRelativeErrorBound(scorers.size()); + this.maxScore = (float) ((1.0 + 2 * maxScoreRelativeErrorBound) * maxScoreSum); } // returns a boolean so that it can be called from assert @@ -375,8 +386,7 @@ final class WANDScorer extends Scorer { @Override public float maxScore() { - // TODO: implement but be careful about floating-point errors. - return Float.POSITIVE_INFINITY; + return maxScore; } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/MathUtil.java b/lucene/core/src/java/org/apache/lucene/util/MathUtil.java index 09437fe1509..7430c5d8a86 100644 --- a/lucene/core/src/java/org/apache/lucene/util/MathUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/MathUtil.java @@ -149,5 +149,20 @@ public final class MathUtil { return mult * Math.log((1.0d + a) / (1.0d - a)); } + /** + * Return a relative error bound for a sum of {@code numValues} positive doubles, + * computed using recursive summation, ie. sum = x1 + ... + xn. + * NOTE: This only works if all values are POSITIVE so that Σ |xi| == |Σ xi|. + * This uses formula 3.5 from Higham, Nicholas J. (1993), + * "The accuracy of floating point summation", SIAM Journal on Scientific Computing. + */ + public static double sumRelativeErrorBound(int numValues) { + if (numValues <= 1) { + return 0; + } + // u = unit roundoff in the paper, also called machine precision or machine epsilon + double u = Math.scalb(1.0, -52); + return (numValues - 1) * u; + } } diff --git a/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/DisjunctionMaxQuery.xml b/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/DisjunctionMaxQuery.xml index ebf1400d6f6..0c94b004a66 100644 --- a/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/DisjunctionMaxQuery.xml +++ b/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/DisjunctionMaxQuery.xml @@ -18,7 +18,7 @@ merger - + verger - \ No newline at end of file + diff --git a/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java b/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java index d97e2f64955..b9e44c1ba37 100644 --- a/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java +++ b/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java @@ -102,7 +102,7 @@ public class TestCoreParser extends LuceneTestCase { assertEquals(0.0f, d.getTieBreakerMultiplier(), 0.0001f); assertEquals(2, d.getDisjuncts().size()); DisjunctionMaxQuery ndq = (DisjunctionMaxQuery) d.getDisjuncts().get(1); - assertEquals(1.2f, ndq.getTieBreakerMultiplier(), 0.0001f); + assertEquals(0.3f, ndq.getTieBreakerMultiplier(), 0.0001f); assertEquals(1, ndq.getDisjuncts().size()); }