BlockJoinBulkScorer must check for parent deletions (not children)

This change ensures that BlockJoinBulkScorer verifies deletions at the parent level.
This commit is contained in:
Jim Ferenczi 2024-12-15 09:06:53 +00:00
parent a8d8d6b3d9
commit 4e40487cb8
2 changed files with 35 additions and 17 deletions

View File

@ -532,8 +532,9 @@ public class ToParentBlockJoinQuery extends Query {
return scoringCompleteCheck(max, max);
}
BatchAwareLeafCollector wrappedCollector = wrapCollector(collector);
childBulkScorer.score(wrappedCollector, acceptDocs, prevParent + 1, lastParent + 1);
BatchAwareLeafCollector wrappedCollector = wrapCollector(collector, acceptDocs);
// We don't propagate the acceptDocs since only parents are checked for deletion in the wrapped collector
childBulkScorer.score(wrappedCollector, null, prevParent + 1, lastParent + 1);
wrappedCollector.endBatch();
return scoringCompleteCheck(lastParent + 1, max);
@ -550,7 +551,7 @@ public class ToParentBlockJoinQuery extends Query {
return childBulkScorer.cost();
}
private BatchAwareLeafCollector wrapCollector(LeafCollector collector) {
private BatchAwareLeafCollector wrapCollector(LeafCollector collector, Bits acceptDocs) {
return new BatchAwareLeafCollector(collector) {
private final Score currentParentScore = new Score(scoreMode);
private int currentParent = -1;
@ -581,7 +582,7 @@ public class ToParentBlockJoinQuery extends Query {
public void collect(int doc) throws IOException {
if (doc > currentParent) {
// Emit the current parent and setup scoring for the next parent
if (currentParent >= 0) {
if (currentParent >= 0 && (acceptDocs == null || acceptDocs.get(currentParent))) {
in.collect(currentParent);
}
@ -602,7 +603,7 @@ public class ToParentBlockJoinQuery extends Query {
@Override
public void endBatch() throws IOException {
if (currentParent >= 0) {
if (currentParent >= 0 && (acceptDocs == null || acceptDocs.get(currentParent))) {
in.collect(currentParent);
}
}

View File

@ -49,6 +49,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.FixedBitSet;
public class TestBlockJoinBulkScorer extends LuceneTestCase {
private static final String TYPE_FIELD_NAME = "type";
@ -256,21 +257,36 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
}
private static void assertScores(
int maxDoc,
BulkScorer bulkScorer,
org.apache.lucene.search.ScoreMode scoreMode,
Float minScore,
Map<Integer, Float> expectedScores)
throws IOException {
assertScores(bulkScorer, scoreMode, minScore, List.of(expectedScores));
assertScores(maxDoc, bulkScorer, scoreMode, minScore, List.of(expectedScores));
}
private static void assertScores(
int maxDoc,
BulkScorer bulkScorer,
org.apache.lucene.search.ScoreMode scoreMode,
Float minScore,
List<Map<Integer, Float>> expectedScoresList)
throws IOException {
Map<Integer, Float> actualScores = new HashMap<>();
FixedBitSet acceptDocs = new FixedBitSet(maxDoc);
List<Map<Integer, Float>> expectedScoresListPruned = new ArrayList<>();
for (var map : expectedScoresList) {
Map<Integer, Float> newMap = new HashMap<>();
for (var entry : map.entrySet()) {
if (usually(random())) {
acceptDocs.set(entry.getKey());
newMap.put(entry.getKey(), entry.getValue());
}
}
expectedScoresListPruned.add(newMap);
}
bulkScorer.score(
new LeafCollector() {
private Scorable scorer;
@ -286,18 +302,19 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
@Override
public void collect(int doc) throws IOException {
assertTrue(acceptDocs.get(doc));
assertNotNull(scorer);
actualScores.put(doc, scoreMode.needsScores() ? scorer.score() : 0);
}
},
null,
acceptDocs,
0,
NO_MORE_DOCS);
if (expectedScoresList.size() == 1) {
assertEquals(expectedScoresList.getFirst(), actualScores);
if (expectedScoresListPruned.size() == 1) {
assertEquals(expectedScoresListPruned.getFirst(), actualScores);
} else {
assertEqualsToOneOf(expectedScoresList, actualScores);
assertEqualsToOneOf(expectedScoresListPruned, actualScores);
}
}
@ -356,7 +373,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
continue;
}
assertScores(ss.bulkScorer(), searchScoreMode, null, expectedScores);
assertScores(reader.maxDoc(), ss.bulkScorer(), searchScoreMode, null, expectedScores);
}
}
}
@ -395,7 +412,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, null, expectedScores);
assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, null, expectedScores);
}
{
@ -418,7 +435,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, 6.0f, List.of(expectedScores1, expectedScores2));
assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, 6.0f, List.of(expectedScores1, expectedScores2));
}
{
@ -426,7 +443,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, 11.0f, expectedScores);
assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, 11.0f, expectedScores);
}
}
}
@ -465,7 +482,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, null, expectedScores);
assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, null, expectedScores);
}
{
@ -479,7 +496,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, 0.0f, expectedScores);
assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, 0.0f, expectedScores);
}
{
@ -487,7 +504,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, Math.nextUp(0f), expectedScores);
assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, Math.nextUp(0f), expectedScores);
}
}
}