Disjunction as CompetitiveIterator for numeric dynamic pruning (#13221)

// nightly-benchmarks-results-changed //
This commit is contained in:
gf2121 2024-05-20 15:00:09 +08:00 committed by GitHub
parent 7db9c8c9bd
commit 1ee4f8a111
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 392 additions and 107 deletions

View File

@ -18,6 +18,9 @@
package org.apache.lucene.search.comparators;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.Consumer;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
@ -29,7 +32,12 @@ import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.IntArrayDocIdSet;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.LSBRadixSorter;
import org.apache.lucene.util.PriorityQueue;
import org.apache.lucene.util.packed.PackedInts;
/**
* Abstract numeric comparator for comparing numeric values. This comparator provides a skipping
@ -42,9 +50,6 @@ import org.apache.lucene.util.DocIdSetBuilder;
*/
public abstract class NumericComparator<T extends Number> extends FieldComparator<T> {
// MIN_SKIP_INTERVAL and MAX_SKIP_INTERVAL both should be powers of 2
private static final int MIN_SKIP_INTERVAL = 32;
private static final int MAX_SKIP_INTERVAL = 8192;
protected final T missingValue;
private final long missingValueAsLong;
protected final String field;
@ -92,10 +97,10 @@ public abstract class NumericComparator<T extends Number> extends FieldComparato
/** Leaf comparator for {@link NumericComparator} that provides skipping functionality */
public abstract class NumericLeafComparator implements LeafFieldComparator {
private static final long MAX_DISJUNCTION_CLAUSE = 128;
private final LeafReaderContext context;
protected final NumericDocValues docValues;
private final PointValues pointValues;
private final PointValues.PointTree pointTree;
// if skipping functionality should be enabled on this segment
private final boolean enableSkipping;
private final int maxDoc;
@ -105,14 +110,11 @@ public abstract class NumericComparator<T extends Number> extends FieldComparato
private long minValueAsLong = Long.MIN_VALUE;
private long maxValueAsLong = Long.MAX_VALUE;
private Long thresholdAsLong;
private DocIdSetIterator competitiveIterator;
private long iteratorCost = -1;
private long leadCost = -1;
private int maxDocVisited = -1;
private int updateCounter = 0;
private int currentSkipInterval = MIN_SKIP_INTERVAL;
// helps to be conservative about increasing the sampling interval
private int tryUpdateFailCount = 0;
public NumericLeafComparator(LeafReaderContext context) throws IOException {
this.context = context;
@ -139,7 +141,6 @@ public abstract class NumericComparator<T extends Number> extends FieldComparato
+ " expected "
+ bytesCount);
}
this.pointTree = pointValues.getPointTree();
this.enableSkipping = true; // skipping is enabled when points are available
this.maxDoc = context.reader().maxDoc();
this.competitiveIterator = DocIdSetIterator.all(maxDoc);
@ -147,7 +148,6 @@ public abstract class NumericComparator<T extends Number> extends FieldComparato
encodeTop();
}
} else {
this.pointTree = null;
this.enableSkipping = false;
this.maxDoc = 0;
}
@ -183,12 +183,12 @@ public abstract class NumericComparator<T extends Number> extends FieldComparato
@Override
public void setScorer(Scorable scorer) throws IOException {
if (iteratorCost == -1) {
if (leadCost == -1) {
if (scorer instanceof Scorer) {
iteratorCost =
leadCost =
((Scorer) scorer).iterator().cost(); // starting iterator cost is the scorer's cost
} else {
iteratorCost = maxDoc;
leadCost = maxDoc;
}
updateCompetitiveIterator(); // update an iterator when we have a new segment
}
@ -207,102 +207,91 @@ public abstract class NumericComparator<T extends Number> extends FieldComparato
|| hitsThresholdReached == false
|| (leafTopSet == false && queueFull == false)) return;
// if some documents have missing points, check that missing values prohibits optimization
if ((pointValues.getDocCount() < maxDoc) && isMissingValueCompetitive()) {
boolean dense = pointValues.getDocCount() == maxDoc;
if (dense == false && isMissingValueCompetitive()) {
return; // we can't filter out documents, as documents with missing values are competitive
}
updateCounter++;
// Start sampling if we get called too much
if (updateCounter > 256
&& (updateCounter & (currentSkipInterval - 1)) != currentSkipInterval - 1) {
return;
}
if (queueFull) {
encodeBottom();
}
DocIdSetBuilder result = new DocIdSetBuilder(maxDoc);
PointValues.IntersectVisitor visitor =
new PointValues.IntersectVisitor() {
DocIdSetBuilder.BulkAdder adder;
@Override
public void grow(int count) {
adder = result.grow(count);
}
@Override
public void visit(int docID) {
if (docID <= maxDocVisited) {
return; // Already visited or skipped
}
adder.add(docID);
}
@Override
public void visit(int docID, byte[] packedValue) {
if (docID <= maxDocVisited) {
return; // already visited or skipped
}
long l = sortableBytesToLong(packedValue);
if (l >= minValueAsLong && l <= maxValueAsLong) {
adder.add(docID); // doc is competitive
}
}
@Override
public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
long min = sortableBytesToLong(minPackedValue);
long max = sortableBytesToLong(maxPackedValue);
if (min > maxValueAsLong || max < minValueAsLong) {
// 1. cmp ==0 and pruning==Pruning.GREATER_THAN_OR_EQUAL_TO : if the sort is
// ascending then maxValueAsLong is bottom's next less value, so it is competitive
// 2. cmp ==0 and pruning==Pruning.GREATER_THAN: maxValueAsLong equals to
// bottom, but there are multiple comparators, so it could be competitive
return PointValues.Relation.CELL_OUTSIDE_QUERY;
}
if (min < minValueAsLong || max > maxValueAsLong) {
return PointValues.Relation.CELL_CROSSES_QUERY;
}
return PointValues.Relation.CELL_INSIDE_QUERY;
}
};
final long threshold = iteratorCost >>> 3;
if (PointValues.isEstimatedPointCountGreaterThanOrEqualTo(visitor, pointTree, threshold)) {
// the new range is not selective enough to be worth materializing, it doesn't reduce number
// of docs at least 8x
updateSkipInterval(false);
if (pointValues.getDocCount() < iteratorCost) {
// Use the set of doc with values to help drive iteration
competitiveIterator = getNumericDocValues(context, field);
iteratorCost = pointValues.getDocCount();
if (competitiveIterator instanceof CompetitiveIterator iter) {
if (queueFull) {
encodeBottom();
}
// CompetitiveIterator already built, try to reduce clause.
tryReduceDisjunctionClause(iter);
return;
}
pointValues.intersect(visitor);
competitiveIterator = result.build().iterator();
iteratorCost = competitiveIterator.cost();
updateSkipInterval(true);
if (thresholdAsLong == null) {
if (dense == false) {
competitiveIterator = getNumericDocValues(context, field);
leadCost = Math.min(leadCost, competitiveIterator.cost());
}
long threshold = Math.min(leadCost >> 3, maxDoc >> 5);
thresholdAsLong = intersectThresholdValue(threshold);
}
if ((reverse == false && bottomAsComparableLong() <= thresholdAsLong)
|| (reverse && bottomAsComparableLong() >= thresholdAsLong)) {
if (queueFull) {
encodeBottom();
}
DisjunctionBuildVisitor visitor = new DisjunctionBuildVisitor();
competitiveIterator = visitor.generateCompetitiveIterator();
}
}
private void updateSkipInterval(boolean success) {
if (updateCounter > 256) {
if (success) {
currentSkipInterval = Math.max(currentSkipInterval / 2, MIN_SKIP_INTERVAL);
tryUpdateFailCount = 0;
} else {
if (tryUpdateFailCount >= 3) {
currentSkipInterval = Math.min(currentSkipInterval * 2, MAX_SKIP_INTERVAL);
tryUpdateFailCount = 0;
} else {
tryUpdateFailCount++;
}
}
/**
 * Drops clauses from the head of {@code iter.disis} while their most competitive value lies
 * outside {@code [minValueAsLong, maxValueAsLong]}, then rebuilds the priority queue from the
 * surviving clauses if anything was removed.
 *
 * @param iter the already-built competitive iterator whose clauses should be trimmed
 */
private void tryReduceDisjunctionClause(CompetitiveIterator iter) {
  final Deque<DisiAndMostCompetitiveValue> clauses = iter.disis;
  final int sizeBefore = clauses.size();
  while (clauses.isEmpty() == false) {
    final long headValue = clauses.getFirst().mostCompetitiveValue;
    if (headValue >= minValueAsLong && headValue <= maxValueAsLong) {
      // the head clause can still hold competitive docs, so stop trimming here
      break;
    }
    clauses.removeFirst();
  }
  if (clauses.size() == sizeBefore) {
    return; // nothing was dropped, the queue is still valid
  }
  iter.disjunction.clear();
  iter.disjunction.addAll(clauses);
}
/**
 * Finds the value that lies {@code threshold} points away from an anchor position in value
 * order. The anchor is the estimated position of the top value when {@code leafTopSet} is true;
 * otherwise it is the start of the value range (ascending sort) or its end (descending sort).
 * The offset is added for ascending sorts and subtracted for descending sorts.
 *
 * @param threshold number of points to move away from the anchor
 * @return the value at the resulting position, clamped to the segment's min/max point value
 * @throws IOException if reading the points fails
 */
private long intersectThresholdValue(long threshold) throws IOException {
  final long anchorPos;
  if (leafTopSet) {
    // estimated number of points whose value is <= top
    anchorPos =
        pointValues.estimatePointCount(
            new RangeVisitor(Long.MIN_VALUE, topAsComparableLong(), -1));
  } else if (reverse) {
    anchorPos = pointValues.size();
  } else {
    anchorPos = 0;
  }
  final long targetPos = reverse ? anchorPos - threshold : anchorPos + threshold;

  if (targetPos <= 0) {
    return sortableBytesToLong(pointValues.getMinPackedValue());
  }
  if (targetPos >= pointValues.size()) {
    return sortableBytesToLong(pointValues.getMaxPackedValue());
  }
  return intersectValueByPos(pointValues.getPointTree(), targetPos);
}
/**
 * Returns a value for the point at the given left-to-right position, descending the tree one
 * level at a time. When descent stops at a node that has no children, the node's max value is
 * returned for ascending sorts and its min value for descending sorts.
 *
 * @param pointTree tree cursor to navigate; it is moved by this call
 * @param pos 1-based position counted from the current node, must be positive
 * @throws IOException if navigating the tree fails
 */
private long intersectValueByPos(PointValues.PointTree pointTree, long pos) throws IOException {
  long remaining = pos;
  while (true) {
    assert remaining > 0 : remaining;
    // skip whole siblings that end before the wanted position
    while (pointTree.size() < remaining) {
      remaining -= pointTree.size();
      pointTree.moveToSibling();
    }
    if (pointTree.size() == remaining) {
      return sortableBytesToLong(pointTree.getMaxPackedValue());
    }
    if (remaining == 0) {
      return sortableBytesToLong(pointTree.getMinPackedValue());
    }
    if (pointTree.moveToChild() == false) {
      // no children to descend into: fall back to the node's bound on the side chosen by sort order
      return reverse == false
          ? sortableBytesToLong(pointTree.getMaxPackedValue())
          : sortableBytesToLong(pointTree.getMinPackedValue());
    }
  }
}
@ -405,5 +394,276 @@ public abstract class NumericComparator<T extends Number> extends FieldComparato
/**
 * Returns the comparator's bottom value as a long that is compared directly against values
 * decoded with {@code sortableBytesToLong}. NOTE(review): presumably the current bottom of the
 * priority queue; confirm against the concrete comparators.
 */
protected abstract long bottomAsComparableLong();
/**
 * Returns the comparator's top value as a long that is used as an inclusive bound against values
 * decoded with {@code sortableBytesToLong}. NOTE(review): presumably the search-after value;
 * confirm against the concrete comparators.
 */
protected abstract long topAsComparableLong();
/**
 * Visitor that walks the point tree leaf by leaf, buffers the matching doc IDs into blocks of at
 * least {@code minBlockLength} docs, and turns every block into a sorted int-array iterator
 * tagged with the block's most competitive value. The blocks are finally combined into a
 * {@link CompetitiveIterator}.
 */
class DisjunctionBuildVisitor extends RangeVisitor {
// finished blocks, ordered by their most competitive value
final Deque<DisiAndMostCompetitiveValue> disis = new ArrayDeque<>();
// most competitive entry stored last.
final Consumer<DisiAndMostCompetitiveValue> adder =
reverse == false ? disis::addFirst : disis::addLast;
// a block is flushed once it holds at least this many docs
final int minBlockLength = minBlockLength();
final LSBRadixSorter sorter = new LSBRadixSorter();
// buffer for the block currently being filled; replaced by a fresh array after every flush
int[] docs = IntsRef.EMPTY_INTS;
// number of docs currently buffered in docs
int index = 0;
// largest doc ID buffered so far in the current block
int blockMaxDoc = -1;
// false as soon as a doc smaller than blockMaxDoc is appended, i.e. the block needs sorting
boolean docsInOrder = true;
// value bounds of the leaves merged into the current block; min > max means "no leaf yet"
long blockMinValue = Long.MAX_VALUE;
long blockMaxValue = Long.MIN_VALUE;
/** Matches the comparator's current value range and ignores docs up to maxDocVisited. */
private DisjunctionBuildVisitor() {
super(minValueAsLong, maxValueAsLong, maxDocVisited);
}
/** Makes room for {@code count} more docs plus one extra slot for the end-of-block sentinel. */
@Override
public void grow(int count) {
docs = ArrayUtil.grow(docs, index + count + 1);
}
/** Appends a matching doc to the current block and records whether docs are still ascending. */
@Override
protected void consumeDoc(int doc) {
docs[index++] = doc;
if (doc >= blockMaxDoc) {
blockMaxDoc = doc;
} else {
docsInOrder = false;
}
}
/**
 * Recursively visits every node that is not outside the value range. On nodes without children,
 * values are checked one by one when the cell crosses the range, and only doc IDs are read when
 * the cell is fully inside it; the node's value bounds are then merged into the current block.
 */
void intersectLeaves(PointValues.PointTree pointTree) throws IOException {
PointValues.Relation r =
compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
switch (r) {
case CELL_INSIDE_QUERY, CELL_CROSSES_QUERY -> {
if (pointTree.moveToChild()) {
do {
intersectLeaves(pointTree);
} while (pointTree.moveToSibling());
// restore the cursor so the caller can continue with its own siblings
pointTree.moveToParent();
} else {
if (r == PointValues.Relation.CELL_CROSSES_QUERY) {
pointTree.visitDocValues(this);
} else {
pointTree.visitDocIDs(this);
}
updateMinMax(
sortableBytesToLong(pointTree.getMinPackedValue()),
sortableBytesToLong(pointTree.getMaxPackedValue()));
}
}
case CELL_OUTSIDE_QUERY -> {}
default -> throw new IllegalStateException("unreachable code");
}
}
/**
 * Widens the current block's value bounds with a visited leaf's bounds and flushes the block
 * once it holds at least {@code minBlockLength} docs.
 */
void updateMinMax(long leafMinValue, long leafMaxValue) throws IOException {
this.blockMinValue = Math.min(blockMinValue, leafMinValue);
this.blockMaxValue = Math.max(blockMaxValue, leafMaxValue);
if (index >= minBlockLength) {
update();
// reset to the "no leaf yet" state for the next block
this.blockMinValue = Long.MAX_VALUE;
this.blockMaxValue = Long.MIN_VALUE;
}
}
/**
 * Flushes the current block: sorts the buffered docs if needed, terminates them with
 * NO_MORE_DOCS, wraps them in an iterator and stores it with the block's most competitive value.
 * Does nothing when no leaf has contributed to the block.
 */
void update() throws IOException {
if (blockMinValue > blockMaxValue) {
return;
}
// clamp the block bound to the value range the visitor was built with
long mostCompetitiveValue =
reverse == false
? Math.max(blockMinValue, minValueAsLong)
: Math.min(blockMaxValue, maxValueAsLong);
if (docsInOrder == false) {
sorter.sort(PackedInts.bitsRequired(blockMaxDoc), docs, index);
}
// the slot reserved by grow(): IntArrayDocIdSet requires docs[len] == NO_MORE_DOCS
docs[index] = DocIdSetIterator.NO_MORE_DOCS;
DocIdSetIterator iter = new IntArrayDocIdSet(docs, index).iterator();
adder.accept(new DisiAndMostCompetitiveValue(iter, mostCompetitiveValue));
// the array is now owned by the doc id set, start the next block with a fresh buffer
docs = IntsRef.EMPTY_INTS;
index = 0;
blockMaxDoc = -1;
docsInOrder = true;
}
/**
 * Walks the whole point tree, flushes the trailing block and combines all blocks into a
 * {@link CompetitiveIterator}; returns an empty iterator when no block was produced.
 */
DocIdSetIterator generateCompetitiveIterator() throws IOException {
intersectLeaves(pointValues.getPointTree());
// flush whatever is left in the buffer after the last leaf
update();
if (disis.isEmpty()) {
return DocIdSetIterator.empty();
}
assert assertMostCompetitiveValuesSorted(disis);
// queue ordered by each clause's current doc ID, smallest on top
PriorityQueue<DisiAndMostCompetitiveValue> disjunction =
new PriorityQueue<>(disis.size()) {
@Override
protected boolean lessThan(
DisiAndMostCompetitiveValue a, DisiAndMostCompetitiveValue b) {
return a.disi.docID() < b.disi.docID();
}
};
disjunction.addAll(disis);
return new CompetitiveIterator(maxDoc, disis, disjunction);
}
/**
 * Used only inside an assert. When reverse is false, smaller values are more competitive, so
 * the most competitive values must appear in descending order; when reverse is true they must
 * appear in ascending order. Always returns true; violations trip the inner asserts.
 */
private boolean assertMostCompetitiveValuesSorted(Deque<DisiAndMostCompetitiveValue> deque) {
long lastValue = reverse == false ? Long.MAX_VALUE : Long.MIN_VALUE;
for (DisiAndMostCompetitiveValue value : deque) {
if (reverse == false) {
assert value.mostCompetitiveValue <= lastValue
: deque.stream().map(d -> d.mostCompetitiveValue).toList().toString();
} else {
assert value.mostCompetitiveValue >= lastValue
: deque.stream().map(d -> d.mostCompetitiveValue).toList().toString();
}
lastValue = value.mostCompetitiveValue;
}
return true;
}
/**
 * Computes the minimum number of docs per block from the estimated number of points in the
 * current value range: the target clause count is {@code cost / 512 + 1}, capped at
 * MAX_DISJUNCTION_CLAUSE, and the block length is the cost divided by that count.
 */
private int minBlockLength() {
// bottom value can be much more competitive than thresholdAsLong, recompute the cost.
long cost =
pointValues.estimatePointCount(new RangeVisitor(minValueAsLong, maxValueAsLong, -1));
long disjunctionClause = Math.min(MAX_DISJUNCTION_CLAUSE, cost / 512 + 1);
return Math.toIntExact(cost / disjunctionClause);
}
}
}
/**
 * Point visitor that accepts docs whose value lies in {@code [minInclusive, maxInclusive]} and
 * whose doc ID is strictly greater than {@code docLowerBound}. Accepted docs are handed to
 * {@link #consumeDoc(int)}, which subclasses override to collect them; the base implementation
 * throws, so a plain instance is only suitable where docs are never delivered (NOTE(review): the
 * {@code estimatePointCount} callers in this file presumably rely on {@code compare} only).
 */
private class RangeVisitor implements PointValues.IntersectVisitor {
  private final long minInclusive;
  private final long maxInclusive;
  private final int docLowerBound;

  /**
   * @param minInclusive smallest accepted value
   * @param maxInclusive largest accepted value
   * @param docLowerBound docs with an ID less than or equal to this are ignored; pass -1 to
   *     accept every doc
   */
  private RangeVisitor(long minInclusive, long maxInclusive, int docLowerBound) {
    this.minInclusive = minInclusive;
    this.maxInclusive = maxInclusive;
    this.docLowerBound = docLowerBound;
  }

  @Override
  public void visit(int docID) throws IOException {
    if (docID <= docLowerBound) {
      return; // already visited or skipped
    }
    consumeDoc(docID);
  }

  @Override
  public void visit(int docID, byte[] packedValue) throws IOException {
    if (docID <= docLowerBound) {
      return; // already visited or skipped
    }
    long l = sortableBytesToLong(packedValue);
    if (l >= minInclusive && l <= maxInclusive) {
      consumeDoc(docID);
    }
  }

  @Override
  public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
    long l = sortableBytesToLong(packedValue);
    if (l >= minInclusive && l <= maxInclusive) {
      consumeDocsAboveLowerBound(iterator);
    }
  }

  @Override
  public void visit(DocIdSetIterator iterator) throws IOException {
    consumeDocsAboveLowerBound(iterator);
  }

  /**
   * Consumes every doc of the iterator that is strictly greater than {@code docLowerBound}.
   * The advance target is {@code docLowerBound + 1} because advance returns the first doc
   * greater than or equal to its target, and the single-doc visit methods reject
   * {@code docLowerBound} itself.
   */
  private void consumeDocsAboveLowerBound(DocIdSetIterator iterator) throws IOException {
    int doc = docLowerBound >= 0 ? iterator.advance(docLowerBound + 1) : iterator.nextDoc();
    while (doc != DocIdSetIterator.NO_MORE_DOCS) {
      consumeDoc(doc);
      doc = iterator.nextDoc();
    }
  }

  @Override
  public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
    long min = sortableBytesToLong(minPackedValue);
    long max = sortableBytesToLong(maxPackedValue);
    if (min > maxInclusive || max < minInclusive) {
      // 1. cmp ==0 and pruning==Pruning.GREATER_THAN_OR_EQUAL_TO : if the sort is
      // ascending then maxValueAsLong is bottom's next less value, so it is competitive
      // 2. cmp ==0 and pruning==Pruning.GREATER_THAN: maxValueAsLong equals to
      // bottom, but there are multiple comparators, so it could be competitive
      return PointValues.Relation.CELL_OUTSIDE_QUERY;
    }
    if (min < minInclusive || max > maxInclusive) {
      return PointValues.Relation.CELL_CROSSES_QUERY;
    }
    return PointValues.Relation.CELL_INSIDE_QUERY;
  }

  /**
   * Receives every accepted doc. Subclasses that collect docs must override this.
   *
   * @throws UnsupportedOperationException always, in this base implementation
   */
  void consumeDoc(int doc) {
    throw new UnsupportedOperationException();
  }
}
/**
 * One clause of the competitive disjunction: an iterator over the docs of a block together with
 * the most competitive value that block can contain.
 */
private record DisiAndMostCompetitiveValue(DocIdSetIterator disi, long mostCompetitiveValue) {}
/**
 * Iterator over the union of several doc ID iterators. The clauses are kept both in a deque and
 * in a priority queue ordered by current doc ID; the comparator mutates both to drop clauses
 * that stopped being competitive.
 */
private static class CompetitiveIterator extends DocIdSetIterator {
  private final int maxDoc;
  private int doc = -1;
  private final Deque<DisiAndMostCompetitiveValue> disis;
  private final PriorityQueue<DisiAndMostCompetitiveValue> disjunction;

  CompetitiveIterator(
      int maxDoc,
      Deque<DisiAndMostCompetitiveValue> disis,
      PriorityQueue<DisiAndMostCompetitiveValue> disjunction) {
    this.maxDoc = maxDoc;
    this.disis = disis;
    this.disjunction = disjunction;
  }

  @Override
  public int docID() {
    return doc;
  }

  @Override
  public int nextDoc() throws IOException {
    return advance(doc + 1);
  }

  /** Moves every lagging clause up to {@code target} and returns the smallest resulting doc. */
  @Override
  public int advance(int target) throws IOException {
    if (target >= maxDoc) {
      doc = NO_MORE_DOCS;
      return doc;
    }
    DisiAndMostCompetitiveValue head = disjunction.top();
    if (head == null) {
      // priority queue is empty, none of the remaining documents are competitive
      doc = NO_MORE_DOCS;
      return doc;
    }
    for (; head.disi.docID() < target; head = disjunction.updateTop()) {
      head.disi.advance(target);
    }
    doc = head.disi.docID();
    return doc;
  }

  @Override
  public long cost() {
    return maxDoc;
  }
}
}

View File

@ -21,7 +21,12 @@ import java.util.Arrays;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
final class IntArrayDocIdSet extends DocIdSet {
/**
* A doc id set backed by a sorted int array.
*
* @lucene.internal
*/
public final class IntArrayDocIdSet extends DocIdSet {
private static final long BASE_RAM_BYTES_USED =
RamUsageEstimator.shallowSizeOfInstance(IntArrayDocIdSet.class);
@ -29,12 +34,32 @@ final class IntArrayDocIdSet extends DocIdSet {
private final int[] docs;
private final int length;
IntArrayDocIdSet(int[] docs, int length) {
if (docs[length] != DocIdSetIterator.NO_MORE_DOCS) {
/**
* Builds an IntArrayDocIdSet from an int array and a length.
*
* @param docs A docs array whose length must be greater than {@code len}. It must be sorted
* from index 0 (inclusive) to {@code len} (exclusive), and the entry at index {@code len} must
* be {@link DocIdSetIterator#NO_MORE_DOCS}.
* @param len The number of valid docs in the array.
*/
public IntArrayDocIdSet(int[] docs, int len) {
if (docs[len] != DocIdSetIterator.NO_MORE_DOCS) {
throw new IllegalArgumentException();
}
assert assertArraySorted(docs, len)
: "IntArrayDocIdSet need docs to be sorted"
+ Arrays.toString(ArrayUtil.copyOfSubArray(docs, 0, len));
this.docs = docs;
this.length = length;
this.length = len;
}
/**
 * Checks that the first {@code length} entries of {@code docs} are in non-decreasing order.
 *
 * @param docs array to inspect
 * @param length number of leading entries to check
 * @return true if no entry is smaller than its predecessor, false otherwise
 */
private static boolean assertArraySorted(int[] docs, int length) {
  int current = 1;
  while (current < length) {
    if (docs[current - 1] > docs[current]) {
      return false;
    }
    current++;
  }
  return true;
}
@Override