diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 64d6adf7a27..38c001cb890 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -11,6 +11,8 @@ Improvements * LUCENE-8673: Use radix partitioning when merging dimensional points instead of sorting all dimensions before hand. (Ignacio Vera, Adrien Grand) +* LUCENE-8687: Optimise radix partitioning for points on heap. (Ignacio Vera) + Other * LUCENE-8680: Refactor EdgeTree#relateTriangle method. (Ignacio Vera) diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java index a6276ea4f7b..84642842b70 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java @@ -179,16 +179,8 @@ final class SimpleTextBKDWriter implements Closeable { // dimensional values (numDims * bytesPerDim) + docID (int) bytesPerDoc = packedBytesLength + Integer.BYTES; - // As we recurse, we compute temporary partitions of the data, halving the - // number of points at each recursion. Once there are few enough points, - // we can switch to sorting in heap instead of offline (on disk). At any - // time in the recursion, we hold the number of points at that level, plus - // all recursive halves (i.e. 16 + 8 + 4 + 2) so the memory usage is 2X - // what that level would consume, so we multiply by 0.5 to convert from - // bytes to points here. In addition the radix partitioning may sort on memory - // double of this size so we multiply by another 0.5. 
- - maxPointsSortInHeap = (int) (0.25 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDataDims)); + // Maximum number of points we hold in memory at any time + maxPointsSortInHeap = (int) ((maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDataDims)); // Finally, we must be able to hold at least the leaf node in heap during build: if (maxPointsSortInHeap < maxPointsInLeafNode) { @@ -577,25 +569,21 @@ final class SimpleTextBKDWriter implements Closeable { // encoding and not have our own ByteSequencesReader/Writer /** Sort the heap writer by the specified dim */ - private void sortHeapPointWriter(final HeapPointWriter writer, int dim) { - final int pointCount = Math.toIntExact(writer.count()); + private void sortHeapPointWriter(final HeapPointWriter writer, int from, int to, int dim, int commonPrefixLength) { // Tie-break by docID: - - // No need to tie break on ord, for the case where the same doc has the same value in a given dimension indexed more than once: it - // can't matter at search time since we don't write ords into the index: - new MSBRadixSorter(bytesPerDim + Integer.BYTES) { + new MSBRadixSorter(bytesPerDim + Integer.BYTES - commonPrefixLength) { @Override protected int byteAt(int i, int k) { assert k >= 0; - if (k < bytesPerDim) { + if (k + commonPrefixLength < bytesPerDim) { // dim bytes int block = i / writer.valuesPerBlock; int index = i % writer.valuesPerBlock; - return writer.blocks.get(block)[index * packedBytesLength + dim * bytesPerDim + k] & 0xff; + return writer.blocks.get(block)[index * packedBytesLength + dim * bytesPerDim + k + commonPrefixLength] & 0xff; } else { // doc id - int s = 3 - (k - bytesPerDim); + int s = 3 - (k + commonPrefixLength - bytesPerDim); return (writer.docIDs[i] >>> (s * 8)) & 0xff; } } @@ -605,7 +593,7 @@ final class SimpleTextBKDWriter implements Closeable { writer.swap(i, j); } - }.sort(0, pointCount); + }.sort(from, to); } private void checkMaxLeafNodeCount(int numLeaves) { @@ -625,14 +613,13 @@ final 
class SimpleTextBKDWriter implements Closeable { throw new IllegalStateException("already finished"); } - PointWriter data; - + BKDRadixSelector.PathSlice writer; if (offlinePointWriter != null) { offlinePointWriter.close(); - data = offlinePointWriter; + writer = new BKDRadixSelector.PathSlice(offlinePointWriter, 0, pointCount); tempInput = null; } else { - data = heapPointWriter; + writer = new BKDRadixSelector.PathSlice(heapPointWriter, 0, pointCount); heapPointWriter = null; } @@ -671,7 +658,7 @@ final class SimpleTextBKDWriter implements Closeable { try { - build(1, numLeaves, data, out, + build(1, numLeaves, writer, out, radixSelector, minPackedValue, maxPackedValue, splitPackedValues, leafBlockFPs); @@ -1017,7 +1004,7 @@ final class SimpleTextBKDWriter implements Closeable { /** The array (sized numDims) of PathSlice describe the cell we have currently recursed to. */ private void build(int nodeID, int leafNodeOffset, - PointWriter data, + BKDRadixSelector.PathSlice points, IndexOutput out, BKDRadixSelector radixSelector, byte[] minPackedValue, byte[] maxPackedValue, @@ -1030,14 +1017,17 @@ final class SimpleTextBKDWriter implements Closeable { // We can write the block in any order so by default we write it sorted by the dimension that has the // least number of unique bytes at commonPrefixLengths[dim], which makes compression more efficient - if (data instanceof HeapPointWriter == false) { + HeapPointWriter heapSource; + if (points.writer instanceof HeapPointWriter == false) { // Adversarial cases can cause this, e.g. 
very lopsided data, all equal points, such that we started // offline, but then kept splitting only in one dimension, and so never had to rewrite into heap writer - data = switchToHeap(data); + heapSource = switchToHeap(points.writer); + } else { + heapSource = (HeapPointWriter) points.writer; } - // We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point: - HeapPointWriter heapSource = (HeapPointWriter) data; + int from = Math.toIntExact(points.start); + int to = Math.toIntExact(points.start + points.count); //we store common prefix on scratch1 computeCommonPrefixLength(heapSource, scratch1); @@ -1068,7 +1058,8 @@ final class SimpleTextBKDWriter implements Closeable { } } - sortHeapPointWriter(heapSource, sortedDim); + // sort the chosen dimension + sortHeapPointWriter(heapSource, from, to, sortedDim, commonPrefixLengths[sortedDim]); // Save the block file pointer: leafBlockFPs[nodeID - leafNodeOffset] = out.getFilePointer(); @@ -1076,9 +1067,9 @@ final class SimpleTextBKDWriter implements Closeable { // Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o // loading the values: - int count = Math.toIntExact(heapSource.count()); + int count = to - from; assert count > 0: "nodeID=" + nodeID + " leafNodeOffset=" + leafNodeOffset; - writeLeafBlockDocs(out, heapSource.docIDs, 0, count); + writeLeafBlockDocs(out, heapSource.docIDs, from, count); // TODO: minor opto: we don't really have to write the actual common prefixes, because BKDReader on recursing can regenerate it for us // from the index, much like how terms dict does so from the FST: @@ -1093,12 +1084,12 @@ final class SimpleTextBKDWriter implements Closeable { @Override public BytesRef apply(int i) { - heapSource.getPackedValueSlice(i, scratch); + heapSource.getPackedValueSlice(from + i, scratch); return scratch; } }; assert valuesInOrderAndBounds(count, sortedDim, minPackedValue, maxPackedValue, packedValues, - 
heapSource.docIDs, Math.toIntExact(0)); + heapSource.docIDs, from); writeLeafBlockPackedValues(out, commonPrefixLengths, count, sortedDim, packedValues); } else { @@ -1111,26 +1102,23 @@ final class SimpleTextBKDWriter implements Closeable { splitDim = 0; } - assert nodeID < splitPackedValues.length : "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length; // How many points will be in the left tree: - long rightCount = data.count() / 2; - long leftCount = data.count() - rightCount; + long rightCount = points.count / 2; + long leftCount = points.count - rightCount; - PointWriter leftPointWriter; - PointWriter rightPointWriter; - byte[] splitValue; - - try (PointWriter leftPointWriter2 = getPointWriter(leftCount, "left" + splitDim); - PointWriter rightPointWriter2 = getPointWriter(rightCount, "right" + splitDim)) { - splitValue = radixSelector.select(data, leftPointWriter2, rightPointWriter2, 0, data.count(), leftCount, splitDim); - leftPointWriter = leftPointWriter2; - rightPointWriter = rightPointWriter2; - } catch (Throwable t) { - throw verifyChecksum(t, data); + int commonPrefixLen = FutureArrays.mismatch(minPackedValue, splitDim * bytesPerDim, + splitDim * bytesPerDim + bytesPerDim, maxPackedValue, splitDim * bytesPerDim, + splitDim * bytesPerDim + bytesPerDim); + if (commonPrefixLen == -1) { + commonPrefixLen = bytesPerDim; } + BKDRadixSelector.PathSlice[] pathSlices = new BKDRadixSelector.PathSlice[2]; + + byte[] splitValue = radixSelector.select(points, pathSlices, points.start, points.start + points.count, points.start + leftCount, splitDim, commonPrefixLen); + int address = nodeID * (1 + bytesPerDim); splitPackedValues[address] = (byte) splitDim; System.arraycopy(splitValue, 0, splitPackedValues, address + 1, bytesPerDim); @@ -1144,15 +1132,13 @@ final class SimpleTextBKDWriter implements Closeable { System.arraycopy(splitValue, 0, minSplitPackedValue, splitDim * bytesPerDim, bytesPerDim); System.arraycopy(splitValue, 0, 
maxSplitPackedValue, splitDim * bytesPerDim, bytesPerDim); - - // Recurse on left tree: - build(2*nodeID, leafNodeOffset, leftPointWriter, out, radixSelector, + build(2*nodeID, leafNodeOffset, pathSlices[0], out, radixSelector, minPackedValue, maxSplitPackedValue, splitPackedValues, leafBlockFPs); // TODO: we could "tail recurse" here? have our parent discard its refs as we recurse right? // Recurse on right tree: - build(2*nodeID+1, leafNodeOffset, rightPointWriter, out, radixSelector, + build(2*nodeID+1, leafNodeOffset, pathSlices[1], out, radixSelector, minSplitPackedValue, maxPackedValue, splitPackedValues, leafBlockFPs); } } @@ -1212,15 +1198,6 @@ final class SimpleTextBKDWriter implements Closeable { return true; } - PointWriter getPointWriter(long count, String desc) throws IOException { - if (count <= maxPointsSortInHeap) { - int size = Math.toIntExact(count); - return new HeapPointWriter(size, size, packedBytesLength); - } else { - return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, desc, count); - } - } - private void write(IndexOutput out, String s) throws IOException { SimpleTextUtil.write(out, s, scratch); } diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java index 8d6c852fc45..3bc025c0ee9 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java @@ -66,7 +66,7 @@ public final class BKDRadixSelector { this.bytesPerDim = bytesPerDim; this.packedBytesLength = numDim * bytesPerDim; this.bytesSorted = bytesPerDim + Integer.BYTES; - this.maxPointsSortInHeap = 2 * maxPointsSortInHeap; + this.maxPointsSortInHeap = maxPointsSortInHeap; int numberOfPointsOffline = MAX_SIZE_OFFLINE_BUFFER / (packedBytesLength + Integer.BYTES); this.offlineBuffer = new byte[numberOfPointsOffline * (packedBytesLength + Integer.BYTES)]; 
this.partitionBucket = new int[bytesSorted]; @@ -77,35 +77,54 @@ public final class BKDRadixSelector { } /** + * It uses the provided {@code points} from the given {@code from} to the given {@code to} + * to populate the {@code partitionSlices} array holder (length > 1) with two path slices + * so the path slice at position 0 contains {@code partitionPoint - from} points + * where the value of the {@code dim} is lower or equal to the {@code to - partitionPoint} + * points on the slice at position 1. * - * Method to partition the input data. It returns the value of the dimension where - * the split happens. The method destroys the original writer. + * The {@code dimCommonPrefix} provides a hint for the length of the common prefix for + * the {@code dim} where we are partitioning the points. * + * It returns the value of the {@code dim} at the partition point. + * + * If the provided {@code points} is wrapping an {@link OfflinePointWriter}, the + * writer is destroyed in the process to save disk space. 
*/ - public byte[] select(PointWriter points, PointWriter left, PointWriter right, long from, long to, long partitionPoint, int dim) throws IOException { + public byte[] select(PathSlice points, PathSlice[] partitionSlices, long from, long to, long partitionPoint, int dim, int dimCommonPrefix) throws IOException { checkArgs(from, to, partitionPoint); + assert partitionSlices.length > 1; + //If we are on heap then we just select on heap - if (points instanceof HeapPointWriter) { - return heapSelect((HeapPointWriter) points, left, right, dim, Math.toIntExact(from), Math.toIntExact(to), Math.toIntExact(partitionPoint), 0); + if (points.writer instanceof HeapPointWriter) { + byte[] partition = heapRadixSelect((HeapPointWriter) points.writer, dim, Math.toIntExact(from), Math.toIntExact(to), Math.toIntExact(partitionPoint), dimCommonPrefix); + partitionSlices[0] = new PathSlice(points.writer, from, partitionPoint - from); + partitionSlices[1] = new PathSlice(points.writer, partitionPoint, to - partitionPoint); + return partition; } //reset histogram for (int i = 0; i < bytesSorted; i++) { Arrays.fill(histogram[i], 0); } - OfflinePointWriter offlinePointWriter = (OfflinePointWriter) points; + OfflinePointWriter offlinePointWriter = (OfflinePointWriter) points.writer; - //find common prefix, it does already set histogram values if needed - int commonPrefix = findCommonPrefix(offlinePointWriter, from, to, dim); + //find common prefix from dimCommonPrefix, it does already set histogram values if needed + int commonPrefix = findCommonPrefix(offlinePointWriter, from, to, dim, dimCommonPrefix); - //if all equals we just partition the data - if (commonPrefix == bytesSorted) { - partition(offlinePointWriter, left, right, null, from, to, dim, commonPrefix - 1, partitionPoint); - return partitionPointFromCommonPrefix(); + try (PointWriter left = getPointWriter(partitionPoint - from, "left" + dim); + PointWriter right = getPointWriter(to - partitionPoint, "right" + dim)) { + 
partitionSlices[0] = new PathSlice(left, 0, partitionPoint - from); + partitionSlices[1] = new PathSlice(right, 0, to - partitionPoint); + //if all equals we just partition the points + if (commonPrefix == bytesSorted) { + offlinePartition(offlinePointWriter, left, right, null, from, to, dim, commonPrefix - 1, partitionPoint); + return partitionPointFromCommonPrefix(); + } + //let's rock'n'roll + return buildHistogramAndPartition(offlinePointWriter, left, right, from, to, partitionPoint, 0, commonPrefix, dim); } - //let's rock'n'roll - return buildHistogramAndPartition(offlinePointWriter, left, right, from, to, partitionPoint, 0, commonPrefix, dim); } void checkArgs(long from, long to, long partitionPoint) { @@ -117,11 +136,12 @@ public final class BKDRadixSelector { } } - private int findCommonPrefix(OfflinePointWriter points, long from, long to, int dim) throws IOException{ + private int findCommonPrefix(OfflinePointWriter points, long from, long to, int dim, int dimCommonPrefix) throws IOException{ //find common prefix byte[] commonPrefix = new byte[bytesSorted]; int commonPrefixPosition = bytesSorted; try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) { + assert commonPrefixPosition > dimCommonPrefix; reader.next(); reader.packedValueWithDocId(bytesRef1); // copy dimension @@ -131,21 +151,22 @@ public final class BKDRadixSelector { for (long i = from + 1; i< to; i++) { reader.next(); reader.packedValueWithDocId(bytesRef1); - int startIndex = dim * bytesPerDim; - int endIndex = (commonPrefixPosition > bytesPerDim) ? startIndex + bytesPerDim : startIndex + commonPrefixPosition; - int j = FutureArrays.mismatch(commonPrefix, 0, endIndex - startIndex, bytesRef1.bytes, bytesRef1.offset + startIndex, bytesRef1.offset + endIndex); + int startIndex = (dimCommonPrefix > bytesPerDim) ? bytesPerDim : dimCommonPrefix; + int endIndex = (commonPrefixPosition > bytesPerDim) ? 
bytesPerDim : commonPrefixPosition; + int j = FutureArrays.mismatch(commonPrefix, startIndex, endIndex, bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim + startIndex, bytesRef1.offset + dim * bytesPerDim + endIndex); if (j == 0) { - return 0; + commonPrefixPosition = dimCommonPrefix; + break; } else if (j == -1) { if (commonPrefixPosition > bytesPerDim) { //tie-break on docID - int k = FutureArrays.mismatch(commonPrefix, bytesPerDim, commonPrefixPosition, bytesRef1.bytes, bytesRef1.offset + packedBytesLength, bytesRef1.offset + packedBytesLength + commonPrefixPosition - bytesPerDim ); + int k = FutureArrays.mismatch(commonPrefix, bytesPerDim, commonPrefixPosition, bytesRef1.bytes, bytesRef1.offset + packedBytesLength, bytesRef1.offset + packedBytesLength + commonPrefixPosition - bytesPerDim); if (k != -1) { commonPrefixPosition = bytesPerDim + k; } } } else { - commonPrefixPosition = j; + commonPrefixPosition = dimCommonPrefix + j; } } } @@ -196,33 +217,29 @@ public final class BKDRadixSelector { //special case when be have lot of points that are equal if (commonPrefix == bytesSorted - 1) { long tieBreakCount =(partitionPoint - from - leftCount); - partition(points, left, right, null, from, to, dim, commonPrefix, tieBreakCount); + offlinePartition(points, left, right, null, from, to, dim, commonPrefix, tieBreakCount); return partitionPointFromCommonPrefix(); } //create the delta points writer PointWriter deltaPoints; - if (delta <= maxPointsSortInHeap) { - deltaPoints = new HeapPointWriter(Math.toIntExact(delta), Math.toIntExact(delta), packedBytesLength); - } else { - deltaPoints = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "delta" + iteration, delta); + try (PointWriter tempDeltaPoints = getDeltaPointWriter(left, right, delta, iteration)) { + //divide the points. 
This actually destroys the current writer + offlinePartition(points, left, right, tempDeltaPoints, from, to, dim, commonPrefix, 0); + deltaPoints = tempDeltaPoints; } - //divide the points. This actually destroys the current writer - partition(points, left, right, deltaPoints, from, to, dim, commonPrefix, 0); - //close delta point writer - deltaPoints.close(); long newPartitionPoint = partitionPoint - from - leftCount; if (deltaPoints instanceof HeapPointWriter) { - return heapSelect((HeapPointWriter) deltaPoints, left, right, dim, 0, (int) deltaPoints.count(), Math.toIntExact(newPartitionPoint), ++commonPrefix); + return heapPartition((HeapPointWriter) deltaPoints, left, right, dim, 0, (int) deltaPoints.count(), Math.toIntExact(newPartitionPoint), ++commonPrefix); } else { return buildHistogramAndPartition((OfflinePointWriter) deltaPoints, left, right, 0, deltaPoints.count(), newPartitionPoint, ++iteration, ++commonPrefix, dim); } } - private void partition(OfflinePointWriter points, PointWriter left, PointWriter right, PointWriter deltaPoints, - long from, long to, int dim, int bytePosition, long numDocsTiebreak) throws IOException { + private void offlinePartition(OfflinePointWriter points, PointWriter left, PointWriter right, PointWriter deltaPoints, + long from, long to, int dim, int bytePosition, long numDocsTiebreak) throws IOException { assert bytePosition == bytesSorted -1 || deltaPoints != null; long tiebreakCounter = 0; try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) { @@ -269,7 +286,24 @@ public final class BKDRadixSelector { return partition; } - private byte[] heapSelect(HeapPointWriter points, PointWriter left, PointWriter right, int dim, int from, int to, int partitionPoint, int commonPrefix) throws IOException { + private byte[] heapPartition(HeapPointWriter points, PointWriter left, PointWriter right, int dim, int from, int to, int partitionPoint, int commonPrefix) throws IOException { + + byte[] partition = 
heapRadixSelect(points, dim, from, to, partitionPoint, commonPrefix); + + for (int i = from; i < to; i++) { + points.getPackedValueSlice(i, bytesRef1); + int docID = points.docIDs[i]; + if (i < partitionPoint) { + left.append(bytesRef1, docID); + } else { + right.append(bytesRef1, docID); + } + } + + return partition; + } + + private byte[] heapRadixSelect(HeapPointWriter points, int dim, int from, int to, int partitionPoint, int commonPrefix) { final int offset = dim * bytesPerDim + commonPrefix; new RadixSelector(bytesSorted - commonPrefix) { @@ -294,18 +328,59 @@ public final class BKDRadixSelector { } }.select(from, to, partitionPoint); - for (int i = from; i < to; i++) { - points.getPackedValueSlice(i, bytesRef1); - int docID = points.docIDs[i]; - if (i < partitionPoint) { - left.append(bytesRef1, docID); - } else { - right.append(bytesRef1, docID); - } - } byte[] partition = new byte[bytesPerDim]; points.getPackedValueSlice(partitionPoint, bytesRef1); System.arraycopy(bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim, partition, 0, bytesPerDim); return partition; } + + private PointWriter getDeltaPointWriter(PointWriter left, PointWriter right, long delta, int iteration) throws IOException { + if (delta <= getMaxPointsSortInHeap(left, right)) { + return new HeapPointWriter(Math.toIntExact(delta), Math.toIntExact(delta), packedBytesLength); + } else { + return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "delta" + iteration, delta); + } + } + + private int getMaxPointsSortInHeap(PointWriter left, PointWriter right) { + int pointsUsed = 0; + if (left instanceof HeapPointWriter) { + pointsUsed += ((HeapPointWriter) left).maxSize; + } + if (right instanceof HeapPointWriter) { + pointsUsed += ((HeapPointWriter) right).maxSize; + } + assert maxPointsSortInHeap >= pointsUsed; + return maxPointsSortInHeap - pointsUsed; + } + + PointWriter getPointWriter(long count, String desc) throws IOException { + //As we recurse, we hold two on-heap 
point writers at any point. Therefore the + //max size for these objects is half of the total points we can have on-heap. + if (count <= maxPointsSortInHeap / 2) { + int size = Math.toIntExact(count); + return new HeapPointWriter(size, size, packedBytesLength); + } else { + return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, desc, count); + } + } + + /** Sliced reference to points in a PointWriter. */ + public static final class PathSlice { + public final PointWriter writer; + public final long start; + public final long count; + + public PathSlice(PointWriter writer, long start, long count) { + this.writer = writer; + this.start = start; + this.count = count; + } + + @Override + public String toString() { + return "PathSlice(start=" + start + " count=" + count + " writer=" + writer + ")"; + } + } + } diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java index 8b66de470c6..654648badbb 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java @@ -173,17 +173,8 @@ public class BKDWriter implements Closeable { // dimensional values (numDims * bytesPerDim) + docID (int) bytesPerDoc = packedBytesLength + Integer.BYTES; - - // As we recurse, we compute temporary partitions of the data, halving the - // number of points at each recursion. Once there are few enough points, - // we can switch to sorting in heap instead of offline (on disk). At any - // time in the recursion, we hold the number of points at that level, plus - // all recursive halves (i.e. 16 + 8 + 4 + 2) so the memory usage is 2X - // what that level would consume, so we multiply by 0.5 to convert from - // bytes to points here. In addition the radix partitioning may sort on memory - // double of this size so we multiply by another 0.5. 
- - maxPointsSortInHeap = (int) (0.25 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc)); + // Maximum number of points we hold in memory at any time + maxPointsSortInHeap = (int) ((maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc)); // Finally, we must be able to hold at least the leaf node in heap during build: if (maxPointsSortInHeap < maxPointsInLeafNode) { @@ -403,7 +394,6 @@ public class BKDWriter implements Closeable { } } - /* In the 2+D case, we recursively pick the split dimension, compute the * median value and partition other values around it. */ private long writeFieldNDims(IndexOutput out, String fieldName, MutablePointValues values) throws IOException { @@ -723,7 +713,7 @@ public class BKDWriter implements Closeable { // encoding and not have our own ByteSequencesReader/Writer /** Sort the heap writer by the specified dim */ - private void sortHeapPointWriter(final HeapPointWriter writer, int pointCount, int dim, int commonPrefixLength) { + private void sortHeapPointWriter(final HeapPointWriter writer, int from, int to, int dim, int commonPrefixLength) { // Tie-break by docID: new MSBRadixSorter(bytesPerDim + Integer.BYTES - commonPrefixLength) { @@ -747,7 +737,7 @@ public class BKDWriter implements Closeable { writer.swap(i, j); } - }.sort(0, pointCount); + }.sort(from, to); } // useful for debugging: @@ -785,20 +775,20 @@ public class BKDWriter implements Closeable { throw new IllegalStateException("already finished"); } - PointWriter writer; - if (offlinePointWriter != null) { - offlinePointWriter.close(); - writer = offlinePointWriter; - tempInput = null; - } else { - writer = heapPointWriter; - heapPointWriter = null; - } - if (pointCount == 0) { throw new IllegalStateException("must index at least one point"); } + BKDRadixSelector.PathSlice points; + if (offlinePointWriter != null) { + offlinePointWriter.close(); + points = new BKDRadixSelector.PathSlice(offlinePointWriter, 0, pointCount); + tempInput = null; + } else { + points = new 
BKDRadixSelector.PathSlice(heapPointWriter, 0, pointCount); + heapPointWriter = null; + } + long countPerLeaf = pointCount; long innerNodeCount = 1; @@ -830,7 +820,7 @@ public class BKDWriter implements Closeable { try { final int[] parentSplits = new int[numIndexDims]; - build(1, numLeaves, writer, + build(1, numLeaves, points, out, radixSelector, minPackedValue, maxPackedValue, parentSplits, @@ -1435,7 +1425,7 @@ public class BKDWriter implements Closeable { /** The point writer contains the data that is going to be splitted using radix selection. /* This method is used when we are merging previously written segments, in the numDims > 1 case. */ private void build(int nodeID, int leafNodeOffset, - PointWriter points, + BKDRadixSelector.PathSlice points, IndexOutput out, BKDRadixSelector radixSelector, byte[] minPackedValue, byte[] maxPackedValue, @@ -1448,18 +1438,19 @@ public class BKDWriter implements Closeable { // Leaf node: write block // We can write the block in any order so by default we write it sorted by the dimension that has the // least number of unique bytes at commonPrefixLengths[dim], which makes compression more efficient - - if (points instanceof HeapPointWriter == false) { + HeapPointWriter heapSource; + if (points.writer instanceof HeapPointWriter == false) { // Adversarial cases can cause this, e.g. 
very lopsided data, all equal points, such that we started // offline, but then kept splitting only in one dimension, and so never had to rewrite into heap writer - points = switchToHeap(points); + heapSource = switchToHeap(points.writer); + } else { + heapSource = (HeapPointWriter) points.writer; } - // We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point: - HeapPointWriter heapSource = (HeapPointWriter) points; - + int from = Math.toIntExact(points.start); + int to = Math.toIntExact(points.start + points.count); //we store common prefix on scratch1 - computeCommonPrefixLength(heapSource, scratch1); + computeCommonPrefixLength(heapSource, scratch1, from, to); int sortedDim = 0; int sortedDimCardinality = Integer.MAX_VALUE; @@ -1474,7 +1465,7 @@ public class BKDWriter implements Closeable { int prefix = commonPrefixLengths[dim]; if (prefix < bytesPerDim) { int offset = dim * bytesPerDim; - for (int i = 0; i < heapSource.count(); ++i) { + for (int i = from; i < to; ++i) { heapSource.getPackedValueSlice(i, scratchBytesRef1); int bucket = scratchBytesRef1.bytes[scratchBytesRef1.offset + offset + prefix] & 0xff; usedBytes[dim].set(bucket); @@ -1488,7 +1479,7 @@ public class BKDWriter implements Closeable { } // sort the chosen dimension - sortHeapPointWriter(heapSource, Math.toIntExact(heapSource.count()), sortedDim, commonPrefixLengths[sortedDim]); + sortHeapPointWriter(heapSource, from, to, sortedDim, commonPrefixLengths[sortedDim]); // Save the block file pointer: leafBlockFPs[nodeID - leafNodeOffset] = out.getFilePointer(); @@ -1496,9 +1487,9 @@ public class BKDWriter implements Closeable { // Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o // loading the values: - int count = Math.toIntExact(heapSource.count()); + int count = to - from; assert count > 0: "nodeID=" + nodeID + " leafNodeOffset=" + leafNodeOffset; - writeLeafBlockDocs(out, heapSource.docIDs, 
Math.toIntExact(0), count); + writeLeafBlockDocs(out, heapSource.docIDs, from, count); // TODO: minor opto: we don't really have to write the actual common prefixes, because BKDReader on recursing can regenerate it for us // from the index, much like how terms dict does so from the FST: @@ -1516,12 +1507,12 @@ public class BKDWriter implements Closeable { @Override public BytesRef apply(int i) { - heapSource.getPackedValueSlice(Math.toIntExact(i), scratch); + heapSource.getPackedValueSlice(from + i, scratch); return scratch; } }; assert valuesInOrderAndBounds(count, sortedDim, minPackedValue, maxPackedValue, packedValues, - heapSource.docIDs, Math.toIntExact(0)); + heapSource.docIDs, from); writeLeafBlockPackedValues(out, commonPrefixLengths, count, sortedDim, packedValues); } else { @@ -1534,25 +1525,23 @@ public class BKDWriter implements Closeable { splitDim = 0; } - assert nodeID < splitPackedValues.length : "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length; // How many points will be in the left tree: - long rightCount = points.count() / 2; - long leftCount = points.count() - rightCount; + long rightCount = points.count / 2; + long leftCount = points.count - rightCount; - PointWriter leftPointWriter; - PointWriter rightPointWriter; - byte[] splitValue; - try (PointWriter tempLeftPointWriter = getPointWriter(leftCount, "left" + splitDim); - PointWriter tempRightPointWriter = getPointWriter(rightCount, "right" + splitDim)) { - splitValue = radixSelector.select(points, tempLeftPointWriter, tempRightPointWriter, 0, points.count(), leftCount, splitDim); - leftPointWriter = tempLeftPointWriter; - rightPointWriter = tempRightPointWriter; - } catch (Throwable t) { - throw verifyChecksum(t, points); + BKDRadixSelector.PathSlice[] slices = new BKDRadixSelector.PathSlice[2]; + + int commonPrefixLen = FutureArrays.mismatch(minPackedValue, splitDim * bytesPerDim, + splitDim * bytesPerDim + bytesPerDim, maxPackedValue, splitDim * bytesPerDim, + splitDim 
* bytesPerDim + bytesPerDim); + if (commonPrefixLen == -1) { + commonPrefixLen = bytesPerDim; } + byte[] splitValue = radixSelector.select(points, slices, points.start, points.start + points.count, points.start + leftCount, splitDim, commonPrefixLen); + int address = nodeID * (1 + bytesPerDim); splitPackedValues[address] = (byte) splitDim; System.arraycopy(splitValue, 0, splitPackedValues, address + 1, bytesPerDim); @@ -1568,12 +1557,12 @@ public class BKDWriter implements Closeable { parentSplits[splitDim]++; // Recurse on left tree: - build(2 * nodeID, leafNodeOffset, leftPointWriter, + build(2 * nodeID, leafNodeOffset, slices[0], out, radixSelector, minPackedValue, maxSplitPackedValue, parentSplits, splitPackedValues, leafBlockFPs); // Recurse on right tree: - build(2 * nodeID + 1, leafNodeOffset, rightPointWriter, + build(2 * nodeID + 1, leafNodeOffset, slices[1], out, radixSelector, minSplitPackedValue, maxPackedValue , parentSplits, splitPackedValues, leafBlockFPs); @@ -1581,14 +1570,14 @@ public class BKDWriter implements Closeable { } } - private void computeCommonPrefixLength(HeapPointWriter heapPointWriter, byte[] commonPrefix) { + private void computeCommonPrefixLength(HeapPointWriter heapPointWriter, byte[] commonPrefix, int from, int to) { Arrays.fill(commonPrefixLengths, bytesPerDim); scratchBytesRef1.length = packedBytesLength; - heapPointWriter.getPackedValueSlice(0, scratchBytesRef1); + heapPointWriter.getPackedValueSlice(from, scratchBytesRef1); for (int dim = 0; dim < numDataDims; dim++) { System.arraycopy(scratchBytesRef1.bytes, scratchBytesRef1.offset + dim * bytesPerDim, commonPrefix, dim * bytesPerDim, bytesPerDim); } - for (int i = 1; i < heapPointWriter.count(); i++) { + for (int i = from + 1; i < to; i++) { heapPointWriter.getPackedValueSlice(i, scratchBytesRef1); for (int dim = 0; dim < numDataDims; dim++) { if (commonPrefixLengths[dim] != 0) { @@ -1635,14 +1624,4 @@ public class BKDWriter implements Closeable { 
System.arraycopy(packedValue, packedValueOffset, lastPackedValue, 0, packedBytesLength); return true; } - - PointWriter getPointWriter(long count, String desc) throws IOException { - if (count <= maxPointsSortInHeap) { - int size = Math.toIntExact(count); - return new HeapPointWriter(size, size, packedBytesLength); - } else { - return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, desc, count); - } - } - } diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java index 01d05a0f0d5..6e3863f9cb5 100644 --- a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java +++ b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java @@ -971,7 +971,7 @@ public class TestBKD extends LuceneTestCase { public IndexOutput createTempOutput(String prefix, String suffix, IOContext context) throws IOException { IndexOutput out = in.createTempOutput(prefix, suffix, context); //System.out.println("prefix=" + prefix + " suffix=" + suffix); - if (corrupted == false && suffix.equals("bkd_left1")) { + if (corrupted == false && suffix.equals("bkd_left0")) { //System.out.println("now corrupt byte=" + x + " prefix=" + prefix + " suffix=" + suffix); corrupted = true; return new CorruptingIndexOutput(dir0, 22072, out); diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java index ca61b02b3f0..558b9f2dc8d 100644 --- a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java +++ b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java @@ -48,7 +48,8 @@ public class TestBKDRadixSelector extends LuceneTestCase { NumericUtils.intToSortableBytes(4, bytes, 0); points.append(bytes, 3); points.close(); - verify(dir, points, dimensions, 0, values, middle, packedLength, bytesPerDimensions, 0); + PointWriter copy = copyPoints(dir,points, packedLength); + 
verify(dir, copy, dimensions, 0, values, middle, packedLength, bytesPerDimensions, 0); dir.close(); } @@ -183,24 +184,31 @@ public class TestBKDRadixSelector extends LuceneTestCase { private void verify(Directory dir, PointWriter points, int dimensions, long start, long end, long middle, int packedLength, int bytesPerDimensions, int sortedOnHeap) throws IOException{ for (int splitDim =0; splitDim < dimensions; splitDim++) { PointWriter copy = copyPoints(dir, points, packedLength); - PointWriter leftPointWriter = getRandomPointWriter(dir, middle - start, packedLength); - PointWriter rightPointWriter = getRandomPointWriter(dir, end - middle, packedLength); + BKDRadixSelector.PathSlice[] slices = new BKDRadixSelector.PathSlice[2]; BKDRadixSelector radixSelector = new BKDRadixSelector(dimensions, bytesPerDimensions, sortedOnHeap, dir, "test"); - byte[] partitionPoint = radixSelector.select(copy, leftPointWriter, rightPointWriter, start, end, middle, splitDim); - leftPointWriter.close(); - rightPointWriter.close(); - byte[] max = getMax(leftPointWriter, middle - start, bytesPerDimensions, splitDim); - byte[] min = getMin(rightPointWriter, end - middle, bytesPerDimensions, splitDim); + BKDRadixSelector.PathSlice copySlice = new BKDRadixSelector.PathSlice(copy, 0, copy.count()); + byte[] pointsMax = getMax(copySlice, bytesPerDimensions, splitDim); + byte[] pointsMin = getMin(copySlice, bytesPerDimensions, splitDim); + int commonPrefixLength = FutureArrays.mismatch(pointsMin, 0, bytesPerDimensions, pointsMax, 0, bytesPerDimensions); + if (commonPrefixLength == -1) { + commonPrefixLength = bytesPerDimensions; + } + int commonPrefixLengthInput = (random().nextBoolean()) ? commonPrefixLength : commonPrefixLength == 0 ? 
0 : random().nextInt(commonPrefixLength); + byte[] partitionPoint = radixSelector.select(copySlice, slices, start, end, middle, splitDim, commonPrefixLengthInput); + assertEquals(middle - start, slices[0].count); + assertEquals(end - middle, slices[1].count); + byte[] max = getMax(slices[0], bytesPerDimensions, splitDim); + byte[] min = getMin(slices[1], bytesPerDimensions, splitDim); int cmp = FutureArrays.compareUnsigned(max, 0, bytesPerDimensions, min, 0, bytesPerDimensions); assertTrue(cmp <= 0); if (cmp == 0) { - int maxDocID = getMaxDocId(leftPointWriter, middle - start, bytesPerDimensions, splitDim, partitionPoint); - int minDocId = getMinDocId(rightPointWriter, end - middle, bytesPerDimensions, splitDim, partitionPoint); + int maxDocID = getMaxDocId(slices[0], bytesPerDimensions, splitDim, partitionPoint); + int minDocId = getMinDocId(slices[1], bytesPerDimensions, splitDim, partitionPoint); assertTrue(minDocId >= maxDocID); } assertTrue(Arrays.equals(partitionPoint, min)); - leftPointWriter.destroy(); - rightPointWriter.destroy(); + slices[0].writer.destroy(); + slices[1].writer.destroy(); } points.destroy(); } @@ -236,10 +244,10 @@ public class TestBKDRadixSelector extends LuceneTestCase { return dir; } - private byte[] getMin(PointWriter p, long size, int bytesPerDimension, int dimension) throws IOException { + private byte[] getMin(BKDRadixSelector.PathSlice p, int bytesPerDimension, int dimension) throws IOException { byte[] min = new byte[bytesPerDimension]; Arrays.fill(min, (byte) 0xff); - try (PointReader reader = p.getReader(0, size)) { + try (PointReader reader = p.writer.getReader(p.start, p.count)) { byte[] value = new byte[bytesPerDimension]; BytesRef packedValue = new BytesRef(); while (reader.next()) { @@ -253,9 +261,9 @@ public class TestBKDRadixSelector extends LuceneTestCase { return min; } - private int getMinDocId(PointWriter p, long size, int bytesPerDimension, int dimension, byte[] partitionPoint) throws IOException { + private int 
getMinDocId(BKDRadixSelector.PathSlice p, int bytesPerDimension, int dimension, byte[] partitionPoint) throws IOException { int docID = Integer.MAX_VALUE; - try (PointReader reader = p.getReader(0, size)) { + try (PointReader reader = p.writer.getReader(p.start, p.count)) { BytesRef packedValue = new BytesRef(); while (reader.next()) { reader.packedValue(packedValue); @@ -271,10 +279,10 @@ public class TestBKDRadixSelector extends LuceneTestCase { return docID; } - private byte[] getMax(PointWriter p, long size, int bytesPerDimension, int dimension) throws IOException { + private byte[] getMax(BKDRadixSelector.PathSlice p, int bytesPerDimension, int dimension) throws IOException { byte[] max = new byte[bytesPerDimension]; Arrays.fill(max, (byte) 0); - try (PointReader reader = p.getReader(0, size)) { + try (PointReader reader = p.writer.getReader(p.start, p.count)) { byte[] value = new byte[bytesPerDimension]; BytesRef packedValue = new BytesRef(); while (reader.next()) { @@ -288,9 +296,9 @@ public class TestBKDRadixSelector extends LuceneTestCase { return max; } - private int getMaxDocId(PointWriter p, long size, int bytesPerDimension, int dimension, byte[] partitionPoint) throws IOException { + private int getMaxDocId(BKDRadixSelector.PathSlice p, int bytesPerDimension, int dimension, byte[] partitionPoint) throws IOException { int docID = Integer.MIN_VALUE; - try (PointReader reader = p.getReader(0, size)) { + try (PointReader reader = p.writer.getReader(p.start, p.count)) { BytesRef packedValue = new BytesRef(); while (reader.next()) { reader.packedValue(packedValue);