BlockJoinBulkScorer must check for parent deletions (not children)

This change ensures that BlockJoinBulkScorer verifies deletions at the parent level.
This commit is contained in:
Jim Ferenczi 2024-12-15 09:06:53 +00:00
parent a8d8d6b3d9
commit 4e40487cb8
2 changed files with 35 additions and 17 deletions

View File

@ -532,8 +532,9 @@ public class ToParentBlockJoinQuery extends Query {
return scoringCompleteCheck(max, max); return scoringCompleteCheck(max, max);
} }
BatchAwareLeafCollector wrappedCollector = wrapCollector(collector); BatchAwareLeafCollector wrappedCollector = wrapCollector(collector, acceptDocs);
childBulkScorer.score(wrappedCollector, acceptDocs, prevParent + 1, lastParent + 1); // We don't propagate the acceptDocs since only parents are checked for deletion in the wrapped collector
childBulkScorer.score(wrappedCollector, null, prevParent + 1, lastParent + 1);
wrappedCollector.endBatch(); wrappedCollector.endBatch();
return scoringCompleteCheck(lastParent + 1, max); return scoringCompleteCheck(lastParent + 1, max);
@ -550,7 +551,7 @@ public class ToParentBlockJoinQuery extends Query {
return childBulkScorer.cost(); return childBulkScorer.cost();
} }
private BatchAwareLeafCollector wrapCollector(LeafCollector collector) { private BatchAwareLeafCollector wrapCollector(LeafCollector collector, Bits acceptDocs) {
return new BatchAwareLeafCollector(collector) { return new BatchAwareLeafCollector(collector) {
private final Score currentParentScore = new Score(scoreMode); private final Score currentParentScore = new Score(scoreMode);
private int currentParent = -1; private int currentParent = -1;
@ -581,7 +582,7 @@ public class ToParentBlockJoinQuery extends Query {
public void collect(int doc) throws IOException { public void collect(int doc) throws IOException {
if (doc > currentParent) { if (doc > currentParent) {
// Emit the current parent and setup scoring for the next parent // Emit the current parent and setup scoring for the next parent
if (currentParent >= 0) { if (currentParent >= 0 && (acceptDocs == null || acceptDocs.get(currentParent))) {
in.collect(currentParent); in.collect(currentParent);
} }
@ -602,7 +603,7 @@ public class ToParentBlockJoinQuery extends Query {
@Override @Override
public void endBatch() throws IOException { public void endBatch() throws IOException {
if (currentParent >= 0) { if (currentParent >= 0 && (acceptDocs == null || acceptDocs.get(currentParent))) {
in.collect(currentParent); in.collect(currentParent);
} }
} }

View File

@ -49,6 +49,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.FixedBitSet;
public class TestBlockJoinBulkScorer extends LuceneTestCase { public class TestBlockJoinBulkScorer extends LuceneTestCase {
private static final String TYPE_FIELD_NAME = "type"; private static final String TYPE_FIELD_NAME = "type";
@ -256,21 +257,36 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
} }
private static void assertScores( private static void assertScores(
int maxDoc,
BulkScorer bulkScorer, BulkScorer bulkScorer,
org.apache.lucene.search.ScoreMode scoreMode, org.apache.lucene.search.ScoreMode scoreMode,
Float minScore, Float minScore,
Map<Integer, Float> expectedScores) Map<Integer, Float> expectedScores)
throws IOException { throws IOException {
assertScores(bulkScorer, scoreMode, minScore, List.of(expectedScores)); assertScores(maxDoc, bulkScorer, scoreMode, minScore, List.of(expectedScores));
} }
private static void assertScores( private static void assertScores(
int maxDoc,
BulkScorer bulkScorer, BulkScorer bulkScorer,
org.apache.lucene.search.ScoreMode scoreMode, org.apache.lucene.search.ScoreMode scoreMode,
Float minScore, Float minScore,
List<Map<Integer, Float>> expectedScoresList) List<Map<Integer, Float>> expectedScoresList)
throws IOException { throws IOException {
Map<Integer, Float> actualScores = new HashMap<>(); Map<Integer, Float> actualScores = new HashMap<>();
FixedBitSet acceptDocs = new FixedBitSet(maxDoc);
List<Map<Integer, Float>> expectedScoresListPruned = new ArrayList<>();
for (var map : expectedScoresList) {
Map<Integer, Float> newMap = new HashMap<>();
for (var entry : map.entrySet()) {
if (usually(random())) {
acceptDocs.set(entry.getKey());
newMap.put(entry.getKey(), entry.getValue());
}
}
expectedScoresListPruned.add(newMap);
}
bulkScorer.score( bulkScorer.score(
new LeafCollector() { new LeafCollector() {
private Scorable scorer; private Scorable scorer;
@ -286,18 +302,19 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
@Override @Override
public void collect(int doc) throws IOException { public void collect(int doc) throws IOException {
assertTrue(acceptDocs.get(doc));
assertNotNull(scorer); assertNotNull(scorer);
actualScores.put(doc, scoreMode.needsScores() ? scorer.score() : 0); actualScores.put(doc, scoreMode.needsScores() ? scorer.score() : 0);
} }
}, },
null, acceptDocs,
0, 0,
NO_MORE_DOCS); NO_MORE_DOCS);
if (expectedScoresList.size() == 1) { if (expectedScoresListPruned.size() == 1) {
assertEquals(expectedScoresList.getFirst(), actualScores); assertEquals(expectedScoresListPruned.getFirst(), actualScores);
} else { } else {
assertEqualsToOneOf(expectedScoresList, actualScores); assertEqualsToOneOf(expectedScoresListPruned, actualScores);
} }
} }
@ -356,7 +373,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
continue; continue;
} }
assertScores(ss.bulkScorer(), searchScoreMode, null, expectedScores); assertScores(reader.maxDoc(), ss.bulkScorer(), searchScoreMode, null, expectedScores);
} }
} }
} }
@ -395,7 +412,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause(); ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, null, expectedScores); assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, null, expectedScores);
} }
{ {
@ -418,7 +435,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause(); ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, 6.0f, List.of(expectedScores1, expectedScores2)); assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, 6.0f, List.of(expectedScores1, expectedScores2));
} }
{ {
@ -426,7 +443,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause(); ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, 11.0f, expectedScores); assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, 11.0f, expectedScores);
} }
} }
} }
@ -465,7 +482,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause(); ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, null, expectedScores); assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, null, expectedScores);
} }
{ {
@ -479,7 +496,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause(); ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, 0.0f, expectedScores); assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, 0.0f, expectedScores);
} }
{ {
@ -487,7 +504,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause(); ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, Math.nextUp(0f), expectedScores); assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, Math.nextUp(0f), expectedScores);
} }
} }
} }