diff --git a/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java index 0d92b324e4b..e68b8b42371 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java @@ -532,8 +532,9 @@ public class ToParentBlockJoinQuery extends Query { return scoringCompleteCheck(max, max); } - BatchAwareLeafCollector wrappedCollector = wrapCollector(collector); - childBulkScorer.score(wrappedCollector, acceptDocs, prevParent + 1, lastParent + 1); + BatchAwareLeafCollector wrappedCollector = wrapCollector(collector, acceptDocs); + // We don't propagate the acceptDocs since only parents are checked for deletion in the wrapped collector + childBulkScorer.score(wrappedCollector, null, prevParent + 1, lastParent + 1); wrappedCollector.endBatch(); return scoringCompleteCheck(lastParent + 1, max); @@ -550,7 +551,7 @@ public class ToParentBlockJoinQuery extends Query { return childBulkScorer.cost(); } - private BatchAwareLeafCollector wrapCollector(LeafCollector collector) { + private BatchAwareLeafCollector wrapCollector(LeafCollector collector, Bits acceptDocs) { return new BatchAwareLeafCollector(collector) { private final Score currentParentScore = new Score(scoreMode); private int currentParent = -1; @@ -581,7 +582,7 @@ public class ToParentBlockJoinQuery extends Query { public void collect(int doc) throws IOException { if (doc > currentParent) { // Emit the current parent and setup scoring for the next parent - if (currentParent >= 0) { + if (currentParent >= 0 && (acceptDocs == null || acceptDocs.get(currentParent))) { in.collect(currentParent); } @@ -602,7 +603,7 @@ public class ToParentBlockJoinQuery extends Query { @Override public void endBatch() throws IOException { - if (currentParent >= 0) { + if (currentParent >= 0 && (acceptDocs == null || acceptDocs.get(currentParent))) { in.collect(currentParent); } } diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoinBulkScorer.java b/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoinBulkScorer.java index b9580331347..9b1a9ee2785 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoinBulkScorer.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoinBulkScorer.java @@ -49,6 +49,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.FixedBitSet; public class TestBlockJoinBulkScorer extends LuceneTestCase { private static final String TYPE_FIELD_NAME = "type"; @@ -256,21 +257,36 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase { } private static void assertScores( + int maxDoc, BulkScorer bulkScorer, org.apache.lucene.search.ScoreMode scoreMode, Float minScore, Map expectedScores) throws IOException { - assertScores(bulkScorer, scoreMode, minScore, List.of(expectedScores)); + assertScores(maxDoc, bulkScorer, scoreMode, minScore, List.of(expectedScores)); } private static void assertScores( + int maxDoc, BulkScorer bulkScorer, org.apache.lucene.search.ScoreMode scoreMode, Float minScore, List> expectedScoresList) throws IOException { Map actualScores = new HashMap<>(); + FixedBitSet acceptDocs = new FixedBitSet(maxDoc); + List> expectedScoresListPruned = new ArrayList<>(); + for (var map : expectedScoresList) { + Map newMap = new HashMap<>(); + for (var entry : map.entrySet()) { + if (usually(random())) { + acceptDocs.set(entry.getKey()); + newMap.put(entry.getKey(), entry.getValue()); + } + } + expectedScoresListPruned.add(newMap); + } + bulkScorer.score( new LeafCollector() { private Scorable scorer; @@ -286,18 +302,19 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase { @Override public void collect(int doc) throws IOException { + assertTrue(acceptDocs.get(doc)); assertNotNull(scorer); actualScores.put(doc, scoreMode.needsScores() ? scorer.score() : 0); } }, - null, + acceptDocs, 0, NO_MORE_DOCS); - if (expectedScoresList.size() == 1) { - assertEquals(expectedScoresList.getFirst(), actualScores); + if (expectedScoresListPruned.size() == 1) { + assertEquals(expectedScoresListPruned.getFirst(), actualScores); } else { - assertEqualsToOneOf(expectedScoresList, actualScores); + assertEqualsToOneOf(expectedScoresListPruned, actualScores); } } @@ -356,7 +373,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase { continue; } - assertScores(ss.bulkScorer(), searchScoreMode, null, expectedScores); + assertScores(reader.maxDoc(), ss.bulkScorer(), searchScoreMode, null, expectedScores); } } } @@ -395,7 +412,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase { ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ss.setTopLevelScoringClause(); - assertScores(ss.bulkScorer(), scoreMode, null, expectedScores); + assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, null, expectedScores); } { @@ -418,7 +435,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase { ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ss.setTopLevelScoringClause(); - assertScores(ss.bulkScorer(), scoreMode, 6.0f, List.of(expectedScores1, expectedScores2)); + assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, 6.0f, List.of(expectedScores1, expectedScores2)); } { @@ -426,7 +443,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase { ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ss.setTopLevelScoringClause(); - assertScores(ss.bulkScorer(), scoreMode, 11.0f, expectedScores); + assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, 11.0f, expectedScores); } } } @@ -465,7 +482,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase { ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ss.setTopLevelScoringClause(); - assertScores(ss.bulkScorer(), scoreMode, null, expectedScores); + assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, null, expectedScores); } { @@ -479,7 +496,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase { ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ss.setTopLevelScoringClause(); - assertScores(ss.bulkScorer(), scoreMode, 0.0f, expectedScores); + assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, 0.0f, expectedScores); } { @@ -487,7 +504,7 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase { ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0)); ss.setTopLevelScoringClause(); - assertScores(ss.bulkScorer(), scoreMode, Math.nextUp(0f), expectedScores); + assertScores(reader.maxDoc(), ss.bulkScorer(), scoreMode, Math.nextUp(0f), expectedScores); } } }