LUCENE-9541 ConjunctionDISI sub-iterators check (#1937)

* LUCENE-9541 ConjunctionDISI sub-iterators check

Ensure sub-iterators of a conjunction iterator are on the same doc.
This commit is contained in:
Mayya Sharipova 2020-10-05 09:38:17 -04:00
parent 874c446ab9
commit 6b8288445f
2 changed files with 54 additions and 1 deletions

View File

@ -31,6 +31,7 @@ import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.CollectionUtil; import org.apache.lucene.util.CollectionUtil;
/** A conjunction of DocIdSetIterators. /** A conjunction of DocIdSetIterators.
* Requires that all of its sub-iterators must be on the same document all the time.
* This iterates over the doc ids that are present in each given DocIdSetIterator. * This iterates over the doc ids that are present in each given DocIdSetIterator.
* <br>Public only for use in {@link org.apache.lucene.search.spans}. * <br>Public only for use in {@link org.apache.lucene.search.spans}.
* @lucene.internal * @lucene.internal
@ -140,6 +141,15 @@ public final class ConjunctionDISI extends DocIdSetIterator {
private static DocIdSetIterator createConjunction( private static DocIdSetIterator createConjunction(
List<DocIdSetIterator> allIterators, List<DocIdSetIterator> allIterators,
List<TwoPhaseIterator> twoPhaseIterators) { List<TwoPhaseIterator> twoPhaseIterators) {
// check that all sub-iterators are on the same doc ID
int curDoc = allIterators.size() > 0 ? allIterators.get(0).docID() : twoPhaseIterators.get(0).approximation.docID();
boolean iteratorsOnTheSameDoc = allIterators.stream().allMatch(it -> it.docID() == curDoc);
iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && twoPhaseIterators.stream().allMatch(it -> it.approximation().docID() == curDoc);
if (iteratorsOnTheSameDoc == false) {
throw new IllegalArgumentException("Sub-iterators of ConjunctionDISI are not on the same document!");
}
long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong(); long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong();
List<BitSetIterator> bitSetIterators = new ArrayList<>(); List<BitSetIterator> bitSetIterators = new ArrayList<>();
List<DocIdSetIterator> iterators = new ArrayList<>(); List<DocIdSetIterator> iterators = new ArrayList<>();
@ -177,6 +187,7 @@ public final class ConjunctionDISI extends DocIdSetIterator {
private ConjunctionDISI(List<? extends DocIdSetIterator> iterators) { private ConjunctionDISI(List<? extends DocIdSetIterator> iterators) {
assert iterators.size() >= 2; assert iterators.size() >= 2;
// Sort the array the first time to allow the least frequent DocsEnum to // Sort the array the first time to allow the least frequent DocsEnum to
// lead the matching. // lead the matching.
CollectionUtil.timSort(iterators, new Comparator<DocIdSetIterator>() { CollectionUtil.timSort(iterators, new Comparator<DocIdSetIterator>() {
@ -227,6 +238,7 @@ public final class ConjunctionDISI extends DocIdSetIterator {
@Override @Override
public int advance(int target) throws IOException { public int advance(int target) throws IOException {
assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not one the same document!";
return doNext(lead1.advance(target)); return doNext(lead1.advance(target));
} }
@ -237,6 +249,7 @@ public final class ConjunctionDISI extends DocIdSetIterator {
@Override @Override
public int nextDoc() throws IOException { public int nextDoc() throws IOException {
assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not on the same document!";
return doNext(lead1.nextDoc()); return doNext(lead1.nextDoc());
} }
@ -245,6 +258,16 @@ public final class ConjunctionDISI extends DocIdSetIterator {
return lead1.cost(); // overestimate return lead1.cost(); // overestimate
} }
// Returns {@code true} if all sub-iterators are on the same doc ID, {@code false} otherwise
private boolean assertItersOnSameDoc() {
int curDoc = lead1.docID();
boolean iteratorsOnTheSameDoc = (lead2.docID() == curDoc);
for (int i = 0; (i < others.length && iteratorsOnTheSameDoc); i++) {
iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && (others[i].docID() == curDoc);
}
return iteratorsOnTheSameDoc;
}
/** Conjunction between a {@link DocIdSetIterator} and one or more {@link BitSetIterator}s. */ /** Conjunction between a {@link DocIdSetIterator} and one or more {@link BitSetIterator}s. */
private static class BitSetConjunctionDISI extends DocIdSetIterator { private static class BitSetConjunctionDISI extends DocIdSetIterator {
@ -256,6 +279,7 @@ public final class ConjunctionDISI extends DocIdSetIterator {
BitSetConjunctionDISI(DocIdSetIterator lead, Collection<BitSetIterator> bitSetIterators) { BitSetConjunctionDISI(DocIdSetIterator lead, Collection<BitSetIterator> bitSetIterators) {
this.lead = lead; this.lead = lead;
assert bitSetIterators.size() > 0; assert bitSetIterators.size() > 0;
this.bitSetIterators = bitSetIterators.toArray(new BitSetIterator[0]); this.bitSetIterators = bitSetIterators.toArray(new BitSetIterator[0]);
// Put the least costly iterators first so that we exit as soon as possible // Put the least costly iterators first so that we exit as soon as possible
ArrayUtil.timSort(this.bitSetIterators, (a, b) -> Long.compare(a.cost(), b.cost())); ArrayUtil.timSort(this.bitSetIterators, (a, b) -> Long.compare(a.cost(), b.cost()));
@ -276,11 +300,13 @@ public final class ConjunctionDISI extends DocIdSetIterator {
@Override @Override
public int nextDoc() throws IOException { public int nextDoc() throws IOException {
assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not on the same document!";
return doNext(lead.nextDoc()); return doNext(lead.nextDoc());
} }
@Override @Override
public int advance(int target) throws IOException { public int advance(int target) throws IOException {
assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not on the same document!";
return doNext(lead.advance(target)); return doNext(lead.advance(target));
} }
@ -306,6 +332,16 @@ public final class ConjunctionDISI extends DocIdSetIterator {
return lead.cost(); return lead.cost();
} }
// Returns {@code true} if all sub-iterators are on the same doc ID, {@code false} otherwise
private boolean assertItersOnSameDoc() {
int curDoc = lead.docID();
boolean iteratorsOnTheSameDoc = true;
for (int i = 0; (i < bitSetIterators.length && iteratorsOnTheSameDoc); i++) {
iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && (bitSetIterators[i].docID() == curDoc);
}
return iteratorsOnTheSameDoc;
}
} }
/** /**

View File

@ -41,7 +41,7 @@ public class TestConjunctionDISI extends LuceneTestCase {
return new TwoPhaseIterator(approximation) { return new TwoPhaseIterator(approximation) {
@Override @Override
public boolean matches() throws IOException { public boolean matches() {
return confirmed.get(approximation.docID()); return confirmed.get(approximation.docID());
} }
@ -391,4 +391,21 @@ public class TestConjunctionDISI extends LuceneTestCase {
public void testCollapseSubConjunctionScorers() throws IOException { public void testCollapseSubConjunctionScorers() throws IOException {
testCollapseSubConjunctions(true); testCollapseSubConjunctions(true);
} }
public void testIllegalAdvancementOfSubIteratorsTripsAssertion() throws IOException {
assumeTrue("Assertions must be enabled for this test!", LuceneTestCase.assertsAreEnabled);
int maxDoc = 100;
final int numIterators = TestUtil.nextInt(random(), 2, 5);
FixedBitSet set = randomSet(maxDoc);
DocIdSetIterator[] iterators = new DocIdSetIterator[numIterators];
for (int i = 0; i < iterators.length; ++i) {
iterators[i] = new BitDocIdSet(set).iterator();
}
final DocIdSetIterator conjunction = ConjunctionDISI.intersectIterators(Arrays.asList(iterators));
int idx = TestUtil.nextInt(random() , 0, iterators.length-1);
iterators[idx].nextDoc(); // illegally advancing one of the sub-iterators outside of the conjunction iterator
AssertionError ex = expectThrows(AssertionError.class, () -> conjunction.nextDoc());
assertEquals("Sub-iterators of ConjunctionDISI are not on the same document!", ex.getMessage());
}
} }