diff --git a/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java b/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java index 780e854033a..30bdabbfd44 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java +++ b/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java @@ -31,6 +31,7 @@ import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.CollectionUtil; /** A conjunction of DocIdSetIterators. + * Requires that all of its sub-iterators must be on the same document all the time. * This iterates over the doc ids that are present in each given DocIdSetIterator. *
Public only for use in {@link org.apache.lucene.search.spans}. * @lucene.internal @@ -140,6 +141,15 @@ public final class ConjunctionDISI extends DocIdSetIterator { private static DocIdSetIterator createConjunction( List allIterators, List twoPhaseIterators) { + + // check that all sub-iterators are on the same doc ID + int curDoc = allIterators.size() > 0 ? allIterators.get(0).docID() : twoPhaseIterators.get(0).approximation.docID(); + boolean iteratorsOnTheSameDoc = allIterators.stream().allMatch(it -> it.docID() == curDoc); + iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && twoPhaseIterators.stream().allMatch(it -> it.approximation().docID() == curDoc); + if (iteratorsOnTheSameDoc == false) { + throw new IllegalArgumentException("Sub-iterators of ConjunctionDISI are not on the same document!"); + } + long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong(); List bitSetIterators = new ArrayList<>(); List iterators = new ArrayList<>(); @@ -177,6 +187,7 @@ public final class ConjunctionDISI extends DocIdSetIterator { private ConjunctionDISI(List iterators) { assert iterators.size() >= 2; + // Sort the array the first time to allow the least frequent DocsEnum to // lead the matching. CollectionUtil.timSort(iterators, new Comparator() { @@ -227,6 +238,7 @@ public final class ConjunctionDISI extends DocIdSetIterator { @Override public int advance(int target) throws IOException { + assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not one the same document!"; return doNext(lead1.advance(target)); } @@ -237,6 +249,7 @@ public final class ConjunctionDISI extends DocIdSetIterator { @Override public int nextDoc() throws IOException { + assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not on the same document!"; return doNext(lead1.nextDoc()); } @@ -245,6 +258,16 @@ public final class ConjunctionDISI extends DocIdSetIterator { return lead1.cost(); // overestimate } + // Returns {@code true} if all sub-iterators are on the same doc ID, {@code false} otherwise + private boolean assertItersOnSameDoc() { + int curDoc = lead1.docID(); + boolean iteratorsOnTheSameDoc = (lead2.docID() == curDoc); + for (int i = 0; (i < others.length && iteratorsOnTheSameDoc); i++) { + iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && (others[i].docID() == curDoc); + } + return iteratorsOnTheSameDoc; + } + /** Conjunction between a {@link DocIdSetIterator} and one or more {@link BitSetIterator}s. */ private static class BitSetConjunctionDISI extends DocIdSetIterator { @@ -256,6 +279,7 @@ public final class ConjunctionDISI extends DocIdSetIterator { BitSetConjunctionDISI(DocIdSetIterator lead, Collection bitSetIterators) { this.lead = lead; assert bitSetIterators.size() > 0; + this.bitSetIterators = bitSetIterators.toArray(new BitSetIterator[0]); // Put the least costly iterators first so that we exit as soon as possible ArrayUtil.timSort(this.bitSetIterators, (a, b) -> Long.compare(a.cost(), b.cost())); @@ -276,11 +300,13 @@ public final class ConjunctionDISI extends DocIdSetIterator { @Override public int nextDoc() throws IOException { + assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not on the same document!"; return doNext(lead.nextDoc()); } @Override public int advance(int target) throws IOException { + assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not on the same document!"; return doNext(lead.advance(target)); } @@ -306,6 +332,16 @@ public final class ConjunctionDISI extends DocIdSetIterator { return lead.cost(); } + // Returns {@code true} if all sub-iterators are on the same doc ID, {@code false} otherwise + private boolean assertItersOnSameDoc() { + int curDoc = lead.docID(); + boolean iteratorsOnTheSameDoc = true; + for (int i = 0; (i < bitSetIterators.length && iteratorsOnTheSameDoc); i++) { + iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && (bitSetIterators[i].docID() == curDoc); + } + return iteratorsOnTheSameDoc; + } + } /** diff --git a/lucene/core/src/test/org/apache/lucene/search/TestConjunctionDISI.java b/lucene/core/src/test/org/apache/lucene/search/TestConjunctionDISI.java index e729ed61872..43cfbe68fe8 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestConjunctionDISI.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestConjunctionDISI.java @@ -41,7 +41,7 @@ public class TestConjunctionDISI extends LuceneTestCase { return new TwoPhaseIterator(approximation) { @Override - public boolean matches() throws IOException { + public boolean matches() { return confirmed.get(approximation.docID()); } @@ -391,4 +391,21 @@ public class TestConjunctionDISI extends LuceneTestCase { public void testCollapseSubConjunctionScorers() throws IOException { testCollapseSubConjunctions(true); } + + public void testIllegalAdvancementOfSubIteratorsTripsAssertion() throws IOException { + assumeTrue("Assertions must be enabled for this test!", LuceneTestCase.assertsAreEnabled); + int maxDoc = 100; + final int numIterators = TestUtil.nextInt(random(), 2, 5); + FixedBitSet set = randomSet(maxDoc); + + DocIdSetIterator[] iterators = new DocIdSetIterator[numIterators]; + for (int i = 0; i < iterators.length; ++i) { + iterators[i] = new BitDocIdSet(set).iterator(); + } + final DocIdSetIterator conjunction = ConjunctionDISI.intersectIterators(Arrays.asList(iterators)); + int idx = TestUtil.nextInt(random() , 0, iterators.length-1); + iterators[idx].nextDoc(); // illegally advancing one of the sub-iterators outside of the conjunction iterator + AssertionError ex = expectThrows(AssertionError.class, () -> conjunction.nextDoc()); + assertEquals("Sub-iterators of ConjunctionDISI are not on the same document!", ex.getMessage()); + } }