From 1ee4f8a1115d1de623f242014681032d87ed2c1e Mon Sep 17 00:00:00 2001 From: gf2121 <52390227+gf2121@users.noreply.github.com> Date: Mon, 20 May 2024 15:00:09 +0800 Subject: [PATCH] Disjunction as CompetitiveIterator for numeric dynamic pruning (#13221) // nightly-benchmarks-results-changed // --- .../search/comparators/NumericComparator.java | 466 ++++++++++++++---- .../apache/lucene/util/IntArrayDocIdSet.java | 33 +- 2 files changed, 392 insertions(+), 107 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java b/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java index a0cd748f94c..3d1a84ee645 100644 --- a/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java +++ b/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java @@ -18,6 +18,9 @@ package org.apache.lucene.search.comparators; import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.function.Consumer; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; @@ -29,7 +32,12 @@ import org.apache.lucene.search.LeafFieldComparator; import org.apache.lucene.search.Pruning; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.Scorer; -import org.apache.lucene.util.DocIdSetBuilder; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.IntArrayDocIdSet; +import org.apache.lucene.util.IntsRef; +import org.apache.lucene.util.LSBRadixSorter; +import org.apache.lucene.util.PriorityQueue; +import org.apache.lucene.util.packed.PackedInts; /** * Abstract numeric comparator for comparing numeric values. 
This comparator provides a skipping @@ -42,9 +50,6 @@ import org.apache.lucene.util.DocIdSetBuilder; */ public abstract class NumericComparator extends FieldComparator { - // MIN_SKIP_INTERVAL and MAX_SKIP_INTERVAL both should be powers of 2 - private static final int MIN_SKIP_INTERVAL = 32; - private static final int MAX_SKIP_INTERVAL = 8192; protected final T missingValue; private final long missingValueAsLong; protected final String field; @@ -92,10 +97,10 @@ public abstract class NumericComparator extends FieldComparato /** Leaf comparator for {@link NumericComparator} that provides skipping functionality */ public abstract class NumericLeafComparator implements LeafFieldComparator { + private static final long MAX_DISJUNCTION_CLAUSE = 128; private final LeafReaderContext context; protected final NumericDocValues docValues; private final PointValues pointValues; - private final PointValues.PointTree pointTree; // if skipping functionality should be enabled on this segment private final boolean enableSkipping; private final int maxDoc; @@ -105,14 +110,11 @@ public abstract class NumericComparator extends FieldComparato private long minValueAsLong = Long.MIN_VALUE; private long maxValueAsLong = Long.MAX_VALUE; + private Long thresholdAsLong; private DocIdSetIterator competitiveIterator; - private long iteratorCost = -1; + private long leadCost = -1; private int maxDocVisited = -1; - private int updateCounter = 0; - private int currentSkipInterval = MIN_SKIP_INTERVAL; - // helps to be conservative about increasing the sampling interval - private int tryUpdateFailCount = 0; public NumericLeafComparator(LeafReaderContext context) throws IOException { this.context = context; @@ -139,7 +141,6 @@ public abstract class NumericComparator extends FieldComparato + " expected " + bytesCount); } - this.pointTree = pointValues.getPointTree(); this.enableSkipping = true; // skipping is enabled when points are available this.maxDoc = context.reader().maxDoc(); 
this.competitiveIterator = DocIdSetIterator.all(maxDoc); @@ -147,7 +148,6 @@ public abstract class NumericComparator extends FieldComparato encodeTop(); } } else { - this.pointTree = null; this.enableSkipping = false; this.maxDoc = 0; } @@ -183,12 +183,12 @@ public abstract class NumericComparator extends FieldComparato @Override public void setScorer(Scorable scorer) throws IOException { - if (iteratorCost == -1) { + if (leadCost == -1) { if (scorer instanceof Scorer) { - iteratorCost = + leadCost = ((Scorer) scorer).iterator().cost(); // starting iterator cost is the scorer's cost } else { - iteratorCost = maxDoc; + leadCost = maxDoc; } updateCompetitiveIterator(); // update an iterator when we have a new segment } @@ -207,102 +207,91 @@ public abstract class NumericComparator extends FieldComparato || hitsThresholdReached == false || (leafTopSet == false && queueFull == false)) return; // if some documents have missing points, check that missing values prohibits optimization - if ((pointValues.getDocCount() < maxDoc) && isMissingValueCompetitive()) { + boolean dense = pointValues.getDocCount() == maxDoc; + if (dense == false && isMissingValueCompetitive()) { return; // we can't filter out documents, as documents with missing values are competitive } - updateCounter++; - // Start sampling if we get called too much - if (updateCounter > 256 - && (updateCounter & (currentSkipInterval - 1)) != currentSkipInterval - 1) { - return; - } - - if (queueFull) { - encodeBottom(); - } - - DocIdSetBuilder result = new DocIdSetBuilder(maxDoc); - PointValues.IntersectVisitor visitor = - new PointValues.IntersectVisitor() { - DocIdSetBuilder.BulkAdder adder; - - @Override - public void grow(int count) { - adder = result.grow(count); - } - - @Override - public void visit(int docID) { - if (docID <= maxDocVisited) { - return; // Already visited or skipped - } - adder.add(docID); - } - - @Override - public void visit(int docID, byte[] packedValue) { - if (docID <= maxDocVisited) { 
- return; // already visited or skipped - } - long l = sortableBytesToLong(packedValue); - if (l >= minValueAsLong && l <= maxValueAsLong) { - adder.add(docID); // doc is competitive - } - } - - @Override - public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - long min = sortableBytesToLong(minPackedValue); - long max = sortableBytesToLong(maxPackedValue); - - if (min > maxValueAsLong || max < minValueAsLong) { - // 1. cmp ==0 and pruning==Pruning.GREATER_THAN_OR_EQUAL_TO : if the sort is - // ascending then maxValueAsLong is bottom's next less value, so it is competitive - // 2. cmp ==0 and pruning==Pruning.GREATER_THAN: maxValueAsLong equals to - // bottom, but there are multiple comparators, so it could be competitive - return PointValues.Relation.CELL_OUTSIDE_QUERY; - } - - if (min < minValueAsLong || max > maxValueAsLong) { - return PointValues.Relation.CELL_CROSSES_QUERY; - } - return PointValues.Relation.CELL_INSIDE_QUERY; - } - }; - - final long threshold = iteratorCost >>> 3; - - if (PointValues.isEstimatedPointCountGreaterThanOrEqualTo(visitor, pointTree, threshold)) { - // the new range is not selective enough to be worth materializing, it doesn't reduce number - // of docs at least 8x - updateSkipInterval(false); - if (pointValues.getDocCount() < iteratorCost) { - // Use the set of doc with values to help drive iteration - competitiveIterator = getNumericDocValues(context, field); - iteratorCost = pointValues.getDocCount(); + if (competitiveIterator instanceof CompetitiveIterator iter) { + if (queueFull) { + encodeBottom(); } + // CompetitiveIterator already built, try to reduce clause. 
+ tryReduceDisjunctionClause(iter); return; } - pointValues.intersect(visitor); - competitiveIterator = result.build().iterator(); - iteratorCost = competitiveIterator.cost(); - updateSkipInterval(true); + + if (thresholdAsLong == null) { + if (dense == false) { + competitiveIterator = getNumericDocValues(context, field); + leadCost = Math.min(leadCost, competitiveIterator.cost()); + } + long threshold = Math.min(leadCost >> 3, maxDoc >> 5); + thresholdAsLong = intersectThresholdValue(threshold); + } + + if ((reverse == false && bottomAsComparableLong() <= thresholdAsLong) + || (reverse && bottomAsComparableLong() >= thresholdAsLong)) { + if (queueFull) { + encodeBottom(); + } + DisjunctionBuildVisitor visitor = new DisjunctionBuildVisitor(); + competitiveIterator = visitor.generateCompetitiveIterator(); + } } - private void updateSkipInterval(boolean success) { - if (updateCounter > 256) { - if (success) { - currentSkipInterval = Math.max(currentSkipInterval / 2, MIN_SKIP_INTERVAL); - tryUpdateFailCount = 0; - } else { - if (tryUpdateFailCount >= 3) { - currentSkipInterval = Math.min(currentSkipInterval * 2, MAX_SKIP_INTERVAL); - tryUpdateFailCount = 0; - } else { - tryUpdateFailCount++; - } - } + private void tryReduceDisjunctionClause(CompetitiveIterator iter) { + int originalSize = iter.disis.size(); + + while (iter.disis.isEmpty() == false + && (iter.disis.getFirst().mostCompetitiveValue > maxValueAsLong + || iter.disis.getFirst().mostCompetitiveValue < minValueAsLong)) { + iter.disis.removeFirst(); + } + + if (originalSize != iter.disis.size()) { + iter.disjunction.clear(); + iter.disjunction.addAll(iter.disis); + } + } + + /** Find the value that is threshold docs away from topValue/infinity. 
*/ + private long intersectThresholdValue(long threshold) throws IOException { + long thresholdValuePos; + if (leafTopSet) { + long topValue = topAsComparableLong(); + PointValues.IntersectVisitor visitor = new RangeVisitor(Long.MIN_VALUE, topValue, -1); + long topValuePos = pointValues.estimatePointCount(visitor); + thresholdValuePos = reverse == false ? topValuePos + threshold : topValuePos - threshold; + } else { + thresholdValuePos = reverse == false ? threshold : pointValues.size() - threshold; + } + if (thresholdValuePos <= 0) { + return sortableBytesToLong(pointValues.getMinPackedValue()); + } else if (thresholdValuePos >= pointValues.size()) { + return sortableBytesToLong(pointValues.getMaxPackedValue()); + } else { + return intersectValueByPos(pointValues.getPointTree(), thresholdValuePos); + } + } + + /** Get the point value by a left-to-right position. */ + private long intersectValueByPos(PointValues.PointTree pointTree, long pos) throws IOException { + assert pos > 0 : pos; + while (pointTree.size() < pos) { + pos -= pointTree.size(); + pointTree.moveToSibling(); + } + if (pointTree.size() == pos) { + return sortableBytesToLong(pointTree.getMaxPackedValue()); + } else if (pos == 0) { + return sortableBytesToLong(pointTree.getMinPackedValue()); + } else if (pointTree.moveToChild()) { + return intersectValueByPos(pointTree, pos); + } else { + return reverse == false + ? sortableBytesToLong(pointTree.getMaxPackedValue()) + : sortableBytesToLong(pointTree.getMinPackedValue()); } } @@ -405,5 +394,276 @@ public abstract class NumericComparator extends FieldComparato protected abstract long bottomAsComparableLong(); protected abstract long topAsComparableLong(); + + class DisjunctionBuildVisitor extends RangeVisitor { + + final Deque disis = new ArrayDeque<>(); + // most competitive entry stored last. + final Consumer adder = + reverse == false ? 
disis::addFirst : disis::addLast; + + final int minBlockLength = minBlockLength(); + + final LSBRadixSorter sorter = new LSBRadixSorter(); + int[] docs = IntsRef.EMPTY_INTS; + int index = 0; + int blockMaxDoc = -1; + boolean docsInOrder = true; + long blockMinValue = Long.MAX_VALUE; + long blockMaxValue = Long.MIN_VALUE; + + private DisjunctionBuildVisitor() { + super(minValueAsLong, maxValueAsLong, maxDocVisited); + } + + @Override + public void grow(int count) { + docs = ArrayUtil.grow(docs, index + count + 1); + } + + @Override + protected void consumeDoc(int doc) { + docs[index++] = doc; + if (doc >= blockMaxDoc) { + blockMaxDoc = doc; + } else { + docsInOrder = false; + } + } + + void intersectLeaves(PointValues.PointTree pointTree) throws IOException { + PointValues.Relation r = + compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + switch (r) { + case CELL_INSIDE_QUERY, CELL_CROSSES_QUERY -> { + if (pointTree.moveToChild()) { + do { + intersectLeaves(pointTree); + } while (pointTree.moveToSibling()); + pointTree.moveToParent(); + } else { + if (r == PointValues.Relation.CELL_CROSSES_QUERY) { + pointTree.visitDocValues(this); + } else { + pointTree.visitDocIDs(this); + } + updateMinMax( + sortableBytesToLong(pointTree.getMinPackedValue()), + sortableBytesToLong(pointTree.getMaxPackedValue())); + } + } + case CELL_OUTSIDE_QUERY -> {} + default -> throw new IllegalStateException("unreachable code"); + } + } + + void updateMinMax(long leafMinValue, long leafMaxValue) throws IOException { + this.blockMinValue = Math.min(blockMinValue, leafMinValue); + this.blockMaxValue = Math.max(blockMaxValue, leafMaxValue); + if (index >= minBlockLength) { + update(); + this.blockMinValue = Long.MAX_VALUE; + this.blockMaxValue = Long.MIN_VALUE; + } + } + + void update() throws IOException { + if (blockMinValue > blockMaxValue) { + return; + } + long mostCompetitiveValue = + reverse == false + ? 
Math.max(blockMinValue, minValueAsLong) + : Math.min(blockMaxValue, maxValueAsLong); + + if (docsInOrder == false) { + sorter.sort(PackedInts.bitsRequired(blockMaxDoc), docs, index); + } + docs[index] = DocIdSetIterator.NO_MORE_DOCS; + DocIdSetIterator iter = new IntArrayDocIdSet(docs, index).iterator(); + adder.accept(new DisiAndMostCompetitiveValue(iter, mostCompetitiveValue)); + docs = IntsRef.EMPTY_INTS; + index = 0; + blockMaxDoc = -1; + docsInOrder = true; + } + + DocIdSetIterator generateCompetitiveIterator() throws IOException { + intersectLeaves(pointValues.getPointTree()); + update(); + + if (disis.isEmpty()) { + return DocIdSetIterator.empty(); + } + assert assertMostCompetitiveValuesSorted(disis); + + PriorityQueue disjunction = + new PriorityQueue<>(disis.size()) { + @Override + protected boolean lessThan( + DisiAndMostCompetitiveValue a, DisiAndMostCompetitiveValue b) { + return a.disi.docID() < b.disi.docID(); + } + }; + disjunction.addAll(disis); + + return new CompetitiveIterator(maxDoc, disis, disjunction); + } + + /** + * Used for assert. When reverse is false, smaller values are more competitive, so + * mostCompetitiveValues should be in desc order. + */ + private boolean assertMostCompetitiveValuesSorted(Deque deque) { + long lastValue = reverse == false ? Long.MAX_VALUE : Long.MIN_VALUE; + for (DisiAndMostCompetitiveValue value : deque) { + if (reverse == false) { + assert value.mostCompetitiveValue <= lastValue + : deque.stream().map(d -> d.mostCompetitiveValue).toList().toString(); + } else { + assert value.mostCompetitiveValue >= lastValue + : deque.stream().map(d -> d.mostCompetitiveValue).toList().toString(); + } + lastValue = value.mostCompetitiveValue; + } + return true; + } + + private int minBlockLength() { + // bottom value can be much more competitive than thresholdAsLong, recompute the cost. 
+ long cost = + pointValues.estimatePointCount(new RangeVisitor(minValueAsLong, maxValueAsLong, -1)); + long disjunctionClause = Math.min(MAX_DISJUNCTION_CLAUSE, cost / 512 + 1); + return Math.toIntExact(cost / disjunctionClause); + } + } + } + + private class RangeVisitor implements PointValues.IntersectVisitor { + + private final long minInclusive; + private final long maxInclusive; + private final int docLowerBound; + + private RangeVisitor(long minInclusive, long maxInclusive, int docLowerBound) { + this.minInclusive = minInclusive; + this.maxInclusive = maxInclusive; + this.docLowerBound = docLowerBound; + } + + @Override + public void visit(int docID) throws IOException { + if (docID <= docLowerBound) { + return; // Already visited or skipped + } + consumeDoc(docID); + } + + @Override + public void visit(int docID, byte[] packedValue) throws IOException { + if (docID <= docLowerBound) { + return; // already visited or skipped + } + long l = sortableBytesToLong(packedValue); + if (l >= minInclusive && l <= maxInclusive) { + consumeDoc(docID); + } + } + + @Override + public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { + long l = sortableBytesToLong(packedValue); + if (l >= minInclusive && l <= maxInclusive) { + int doc = docLowerBound >= 0 ? iterator.advance(docLowerBound) : iterator.nextDoc(); + while (doc != DocIdSetIterator.NO_MORE_DOCS) { + consumeDoc(doc); + doc = iterator.nextDoc(); + } + } + } + + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + int doc = docLowerBound >= 0 ? iterator.advance(docLowerBound) : iterator.nextDoc(); + while (doc != DocIdSetIterator.NO_MORE_DOCS) { + consumeDoc(doc); + doc = iterator.nextDoc(); + } + } + + @Override + public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + long min = sortableBytesToLong(minPackedValue); + long max = sortableBytesToLong(maxPackedValue); + + if (min > maxInclusive || max < minInclusive) { + // 1. 
cmp ==0 and pruning==Pruning.GREATER_THAN_OR_EQUAL_TO : if the sort is + // ascending then maxValueAsLong is bottom's next less value, so it is competitive + // 2. cmp ==0 and pruning==Pruning.GREATER_THAN: maxValueAsLong equals to + // bottom, but there are multiple comparators, so it could be competitive + return PointValues.Relation.CELL_OUTSIDE_QUERY; + } + + if (min < minInclusive || max > maxInclusive) { + return PointValues.Relation.CELL_CROSSES_QUERY; + } + return PointValues.Relation.CELL_INSIDE_QUERY; + } + + void consumeDoc(int doc) { + throw new UnsupportedOperationException(); + } + } + + private record DisiAndMostCompetitiveValue(DocIdSetIterator disi, long mostCompetitiveValue) {} + + private static class CompetitiveIterator extends DocIdSetIterator { + + private final int maxDoc; + private int doc = -1; + private final Deque disis; + private final PriorityQueue disjunction; + + CompetitiveIterator( + int maxDoc, + Deque disis, + PriorityQueue disjunction) { + this.maxDoc = maxDoc; + this.disis = disis; + this.disjunction = disjunction; + } + + @Override + public int docID() { + return doc; + } + + @Override + public int nextDoc() throws IOException { + return advance(docID() + 1); + } + + @Override + public int advance(int target) throws IOException { + if (target >= maxDoc) { + return doc = NO_MORE_DOCS; + } else { + DisiAndMostCompetitiveValue top = disjunction.top(); + if (top == null) { + // priority queue is empty, none of the remaining documents are competitive + return doc = NO_MORE_DOCS; + } + while (top.disi.docID() < target) { + top.disi.advance(target); + top = disjunction.updateTop(); + } + return doc = top.disi.docID(); + } + } + + @Override + public long cost() { + return maxDoc; + } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java b/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java index 6475c745572..eb4b93f499e 100644 --- 
a/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java +++ b/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java @@ -21,7 +21,12 @@ import java.util.Arrays; import org.apache.lucene.search.DocIdSet; import org.apache.lucene.search.DocIdSetIterator; -final class IntArrayDocIdSet extends DocIdSet { +/** + * A doc id set backed by a sorted int array. + * + * @lucene.internal + */ +public final class IntArrayDocIdSet extends DocIdSet { private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(IntArrayDocIdSet.class); @@ -29,12 +34,32 @@ final class IntArrayDocIdSet extends DocIdSet { private final int[] docs; private final int length; - IntArrayDocIdSet(int[] docs, int length) { - if (docs[length] != DocIdSetIterator.NO_MORE_DOCS) { + /** + * Build an IntArrayDocIdSet from an int array and a length. + * + * @param docs A docs array whose length needs to be greater than the param len. It needs to be + * sorted from 0 (inclusive) to len (exclusive), and the len-th doc in docs needs to be + * {@link DocIdSetIterator#NO_MORE_DOCS}. + * @param len The number of valid docs in the array. + */ + public IntArrayDocIdSet(int[] docs, int len) { + if (docs[len] != DocIdSetIterator.NO_MORE_DOCS) { throw new IllegalArgumentException(); } + assert assertArraySorted(docs, len) + : "IntArrayDocIdSet need docs to be sorted" + + Arrays.toString(ArrayUtil.copyOfSubArray(docs, 0, len)); this.docs = docs; - this.length = length; + this.length = len; + } + + private static boolean assertArraySorted(int[] docs, int length) { + for (int i = 1; i < length; i++) { + if (docs[i] < docs[i - 1]) { + return false; + } + } + return true; } @Override