LUCENE-10553: Fix WANDScorer's handling of 0 and +Infty. (#860)

The computation of the scaling factor has special cases for these two values,
but the current logic is backwards.
This commit is contained in:
Adrien Grand 2022-05-05 10:24:28 +01:00 committed by GitHub
parent a89c57f35f
commit 26301898b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 99 additions and 21 deletions

View File

@ -60,17 +60,17 @@ final class WANDScorer extends Scorer {
* {@code [2^23, 2^24[}. Special cases: * {@code [2^23, 2^24[}. Special cases:
* *
* <pre> * <pre>
* scalingFactor(0) = scalingFactor(MIN_VALUE) - 1 * scalingFactor(0) = scalingFactor(MIN_VALUE) + 1
* scalingFactor(+Infty) = scalingFactor(MAX_VALUE) + 1 * scalingFactor(+Infty) = scalingFactor(MAX_VALUE) - 1
* </pre> * </pre>
*/ */
static int scalingFactor(float f) { static int scalingFactor(float f) {
if (f < 0) { if (f < 0) {
throw new IllegalArgumentException("Scores must be positive or null"); throw new IllegalArgumentException("Scores must be positive or null");
} else if (f == 0) { } else if (f == 0) {
return scalingFactor(Float.MIN_VALUE) - 1; return scalingFactor(Float.MIN_VALUE) + 1;
} else if (Float.isInfinite(f)) { } else if (Float.isInfinite(f)) {
return scalingFactor(Float.MAX_VALUE) + 1; return scalingFactor(Float.MAX_VALUE) - 1;
} else { } else {
double d = f; double d = f;
// Since doubles have more amplitude than floats for the // Since doubles have more amplitude than floats for the
@ -86,7 +86,6 @@ final class WANDScorer extends Scorer {
* sure we do not miss any matches. * sure we do not miss any matches.
*/ */
static long scaleMaxScore(float maxScore, int scalingFactor) { static long scaleMaxScore(float maxScore, int scalingFactor) {
assert scalingFactor(maxScore) >= scalingFactor;
assert Float.isNaN(maxScore) == false; assert Float.isNaN(maxScore) == false;
assert maxScore >= 0; assert maxScore >= 0;
@ -95,7 +94,8 @@ final class WANDScorer extends Scorer {
final double scaled = Math.scalb((double) maxScore, scalingFactor); final double scaled = Math.scalb((double) maxScore, scalingFactor);
if (scaled > MAX_SCALED_SCORE) { if (scaled > MAX_SCALED_SCORE) {
// This happens if one scorer returns +Infty as a max score // This happens if one scorer returns +Infty as a max score, or if the scorer returns greater
// max scores locally than globally - which shouldn't happen with well-behaved scorers
return MAX_SCALED_SCORE; return MAX_SCALED_SCORE;
} }

View File

@ -45,10 +45,17 @@ public class TestWANDScorer extends LuceneTestCase {
doTestScalingFactor(Math.nextUp(Float.MIN_VALUE)); doTestScalingFactor(Math.nextUp(Float.MIN_VALUE));
doTestScalingFactor(Float.MAX_VALUE); doTestScalingFactor(Float.MAX_VALUE);
doTestScalingFactor(Math.nextDown(Float.MAX_VALUE)); doTestScalingFactor(Math.nextDown(Float.MAX_VALUE));
assertEquals(WANDScorer.scalingFactor(Float.MIN_VALUE) - 1, WANDScorer.scalingFactor(0)); assertEquals(WANDScorer.scalingFactor(Float.MIN_VALUE) + 1, WANDScorer.scalingFactor(0));
assertEquals( assertEquals(
WANDScorer.scalingFactor(Float.MAX_VALUE) + 1, WANDScorer.scalingFactor(Float.MAX_VALUE) - 1,
WANDScorer.scalingFactor(Float.POSITIVE_INFINITY)); WANDScorer.scalingFactor(Float.POSITIVE_INFINITY));
// Greater scores produce lower scaling factors
assertTrue(WANDScorer.scalingFactor(1f) > WANDScorer.scalingFactor(10f));
assertTrue(
WANDScorer.scalingFactor(Float.MAX_VALUE)
> WANDScorer.scalingFactor(Float.POSITIVE_INFINITY));
assertTrue(WANDScorer.scalingFactor(0f) > WANDScorer.scalingFactor(Float.MIN_VALUE));
} }
private void doTestScalingFactor(float f) { private void doTestScalingFactor(float f) {
@ -720,7 +727,65 @@ public class TestWANDScorer extends LuceneTestCase {
dir.close(); dir.close();
} }
/** Degenerate case: all clauses produce a score of 0. */
public void testRandomWithZeroScores() throws IOException {
Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());
int numDocs = atLeast(1000);
for (int i = 0; i < numDocs; ++i) {
Document doc = new Document();
int numValues = random().nextInt(1 << random().nextInt(5));
int start = random().nextInt(10);
for (int j = 0; j < numValues; ++j) {
doc.add(new StringField("foo", Integer.toString(start + j), Store.NO));
}
w.addDocument(doc);
}
IndexReader reader = DirectoryReader.open(w);
w.close();
IndexSearcher searcher = newSearcher(reader);
for (int iter = 0; iter < 100; ++iter) {
int start = random().nextInt(10);
int numClauses = random().nextInt(1 << random().nextInt(5));
BooleanQuery.Builder builder = new BooleanQuery.Builder();
for (int i = 0; i < numClauses; ++i) {
builder.add(
maybeWrap(
new BoostQuery(
new ConstantScoreQuery(
new TermQuery(new Term("foo", Integer.toString(start + i)))),
0f)),
Occur.SHOULD);
}
Query query = builder.build();
CheckHits.checkTopScores(random(), query, searcher);
int filterTerm = random().nextInt(30);
Query filteredQuery =
new BooleanQuery.Builder()
.add(query, Occur.MUST)
.add(new TermQuery(new Term("foo", Integer.toString(filterTerm))), Occur.FILTER)
.build();
CheckHits.checkTopScores(random(), filteredQuery, searcher);
}
reader.close();
dir.close();
}
/** Test the case when some clauses produce infinite max scores. */
public void testRandomWithInfiniteMaxScore() throws IOException { public void testRandomWithInfiniteMaxScore() throws IOException {
doTestRandomSpecialMaxScore(Float.POSITIVE_INFINITY);
}
/** Test the case when some clauses produce finite max scores, but their sum overflows. */
public void testRandomWithMaxScoreOverflow() throws IOException {
doTestRandomSpecialMaxScore(Float.MAX_VALUE);
}
private void doTestRandomSpecialMaxScore(float maxScore) throws IOException {
Directory dir = newDirectory(); Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig()); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());
int numDocs = atLeast(1000); int numDocs = atLeast(1000);
@ -745,7 +810,8 @@ public class TestWANDScorer extends LuceneTestCase {
Query query = new TermQuery(new Term("foo", Integer.toString(start + i))); Query query = new TermQuery(new Term("foo", Integer.toString(start + i)));
if (random().nextBoolean()) { if (random().nextBoolean()) {
query = query =
new InfiniteMaxScoreWrapperQuery(query, numDocs / TestUtil.nextInt(random(), 1, 100)); new MaxScoreWrapperQuery(
query, numDocs / TestUtil.nextInt(random(), 1, 100), maxScore);
} }
builder.add(query, Occur.SHOULD); builder.add(query, Occur.SHOULD);
} }
@ -766,14 +832,16 @@ public class TestWANDScorer extends LuceneTestCase {
dir.close(); dir.close();
} }
private static class InfiniteMaxScoreWrapperScorer extends FilterScorer { private static class MaxScoreWrapperScorer extends FilterScorer {
private final int maxRange; private final int maxRange;
private final float maxScore;
private int lastShallowTarget = -1; private int lastShallowTarget = -1;
InfiniteMaxScoreWrapperScorer(Scorer scorer, int maxRange) { MaxScoreWrapperScorer(Scorer scorer, int maxRange, float maxScore) {
super(scorer); super(scorer);
this.maxRange = maxRange; this.maxRange = maxRange;
this.maxScore = maxScore;
} }
@Override @Override
@ -785,24 +853,26 @@ public class TestWANDScorer extends LuceneTestCase {
@Override @Override
public float getMaxScore(int upTo) throws IOException { public float getMaxScore(int upTo) throws IOException {
if (upTo - Math.max(docID(), lastShallowTarget) >= maxRange) { if (upTo - Math.max(docID(), lastShallowTarget) >= maxRange) {
return Float.POSITIVE_INFINITY; return maxScore;
} }
return in.getMaxScore(upTo); return in.getMaxScore(upTo);
} }
} }
private static class InfiniteMaxScoreWrapperQuery extends Query { private static class MaxScoreWrapperQuery extends Query {
private final Query query; private final Query query;
private final int maxRange; private final int maxRange;
private final float maxScore;
/** /**
* If asked for the maximum score over a range of doc IDs that is greater than or equal to * If asked for the maximum score over a range of doc IDs that is greater than or equal to
* maxRange, this query will return a maximum score of +Infty * maxRange, this query will return the provided maxScore.
*/ */
InfiniteMaxScoreWrapperQuery(Query query, int maxRange) { MaxScoreWrapperQuery(Query query, int maxRange, float maxScore) {
this.query = query; this.query = query;
this.maxRange = maxRange; this.maxRange = maxRange;
this.maxScore = maxScore;
} }
@Override @Override
@ -812,19 +882,27 @@ public class TestWANDScorer extends LuceneTestCase {
@Override @Override
public boolean equals(Object obj) { public boolean equals(Object obj) {
return sameClassAs(obj) && query.equals(((InfiniteMaxScoreWrapperQuery) obj).query); if (sameClassAs(obj) == false) {
return false;
}
MaxScoreWrapperQuery that = (MaxScoreWrapperQuery) obj;
return query.equals(that.query) && maxRange == that.maxRange && maxScore == that.maxScore;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return 31 * classHash() + query.hashCode(); int hash = classHash();
hash = 31 * hash + query.hashCode();
hash = 31 * hash + Integer.hashCode(maxRange);
hash = 31 * hash + Float.hashCode(maxScore);
return hash;
} }
@Override @Override
public Query rewrite(IndexReader reader) throws IOException { public Query rewrite(IndexReader reader) throws IOException {
Query rewritten = query.rewrite(reader); Query rewritten = query.rewrite(reader);
if (rewritten != query) { if (rewritten != query) {
return new InfiniteMaxScoreWrapperQuery(rewritten, maxRange); return new MaxScoreWrapperQuery(rewritten, maxRange, maxScore);
} }
return super.rewrite(reader); return super.rewrite(reader);
} }
@ -842,7 +920,7 @@ public class TestWANDScorer extends LuceneTestCase {
if (scorer == null) { if (scorer == null) {
return null; return null;
} else { } else {
return new InfiniteMaxScoreWrapperScorer(scorer, maxRange); return new MaxScoreWrapperScorer(scorer, maxRange, maxScore);
} }
} }
@ -856,7 +934,7 @@ public class TestWANDScorer extends LuceneTestCase {
@Override @Override
public Scorer get(long leadCost) throws IOException { public Scorer get(long leadCost) throws IOException {
return new InfiniteMaxScoreWrapperScorer(supplier.get(leadCost), maxRange); return new MaxScoreWrapperScorer(supplier.get(leadCost), maxRange, maxScore);
} }
@Override @Override