diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 04042ca7c51..4a21f6a57d2 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -16,7 +16,11 @@ Bug fixes: it's inaccessible. (Dawid Weiss) ======================= Lucene 8.1.0 ======================= -(No Changes) + +Improvements + +* LUCENE-8673: Use radix partitioning when merging dimensional points instead + of sorting all dimensions before hand. (Ignacio Vera, Adrien Grand) ======================= Lucene 8.0.0 ======================= diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java index f45209cc143..a6276ea4f7b 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java @@ -20,7 +20,6 @@ import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Comparator; import java.util.List; import java.util.function.IntFunction; @@ -36,18 +35,15 @@ import org.apache.lucene.store.TrackingDirectoryWrapper; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefBuilder; -import org.apache.lucene.util.BytesRefComparator; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.FutureArrays; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.LongBitSet; import org.apache.lucene.util.MSBRadixSorter; import org.apache.lucene.util.NumericUtils; -import org.apache.lucene.util.OfflineSorter; +import org.apache.lucene.util.bkd.BKDRadixSelector; import org.apache.lucene.util.bkd.BKDWriter; import org.apache.lucene.util.bkd.HeapPointWriter; import org.apache.lucene.util.bkd.MutablePointsReaderUtils; -import org.apache.lucene.util.bkd.OfflinePointReader; import org.apache.lucene.util.bkd.OfflinePointWriter; import org.apache.lucene.util.bkd.PointReader; import org.apache.lucene.util.bkd.PointWriter; @@ -148,32 +144,15 @@ final class SimpleTextBKDWriter implements Closeable { protected long pointCount; - /** true if we have so many values that we must write ords using long (8 bytes) instead of int (4 bytes) */ - protected final boolean longOrds; - /** An upper bound on how many points the caller will add (includes deletions) */ private final long totalPointCount; - /** True if every document has at most one value. We specialize this case by not bothering to store the ord since it's redundant with docID. 
*/ - protected final boolean singleValuePerDoc; - - /** How much heap OfflineSorter is allowed to use */ - protected final OfflineSorter.BufferSize offlineSorterBufferMB; - - /** How much heap OfflineSorter is allowed to use */ - protected final int offlineSorterMaxTempFiles; private final int maxDoc; - public SimpleTextBKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim, - int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount, boolean singleValuePerDoc) throws IOException { - this(maxDoc, tempDir, tempFileNamePrefix, numDataDims, numIndexDims, bytesPerDim, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount, singleValuePerDoc, - totalPointCount > Integer.MAX_VALUE, Math.max(1, (long) maxMBSortInHeap), OfflineSorter.MAX_TEMPFILES); - } - private SimpleTextBKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim, - int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount, - boolean singleValuePerDoc, boolean longOrds, long offlineSorterBufferMB, int offlineSorterMaxTempFiles) throws IOException { + public SimpleTextBKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim, + int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount) throws IOException { verifyParams(numDataDims, numIndexDims, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount); // We use tracking dir to deal with removing files on exception, so each place that // creates temp files doesn't need crazy try/finally/sucess logic: @@ -185,8 +164,6 @@ final class SimpleTextBKDWriter implements Closeable { this.bytesPerDim = bytesPerDim; this.totalPointCount = totalPointCount; this.maxDoc = maxDoc; - this.offlineSorterBufferMB = OfflineSorter.BufferSize.megabytes(offlineSorterBufferMB); - this.offlineSorterMaxTempFiles = offlineSorterMaxTempFiles; docsSeen = new FixedBitSet(maxDoc); packedBytesLength = numDataDims * bytesPerDim; packedIndexBytesLength = numIndexDims * bytesPerDim; @@ -199,21 +176,8 @@ final class SimpleTextBKDWriter implements Closeable { minPackedValue = new byte[packedIndexBytesLength]; maxPackedValue = new byte[packedIndexBytesLength]; - // If we may have more than 1+Integer.MAX_VALUE values, then we must encode ords with long (8 bytes), else we can use int (4 bytes). - this.longOrds = longOrds; - - this.singleValuePerDoc = singleValuePerDoc; - - // dimensional values (numDims * bytesPerDim) + ord (int or long) + docID (int) - if (singleValuePerDoc) { - // Lucene only supports up to 2.1 docs, so we better not need longOrds in this case: - assert longOrds == false; - bytesPerDoc = packedBytesLength + Integer.BYTES; - } else if (longOrds) { - bytesPerDoc = packedBytesLength + Long.BYTES + Integer.BYTES; - } else { - bytesPerDoc = packedBytesLength + Integer.BYTES + Integer.BYTES; - } + // dimensional values (numDims * bytesPerDim) + docID (int) + bytesPerDoc = packedBytesLength + Integer.BYTES; // As we recurse, we compute temporary partitions of the data, halving the // number of points at each recursion. Once there are few enough points, @@ -221,10 +185,10 @@ final class SimpleTextBKDWriter implements Closeable { // time in the recursion, we hold the number of points at that level, plus // all recursive halves (i.e. 16 + 8 + 4 + 2) so the memory usage is 2X // what that level would consume, so we multiply by 0.5 to convert from - // bytes to points here. 
Each dimension has its own sorted partition, so - // we must divide by numDims as wel. + // bytes to points here. In addition the radix partitioning may sort on memory + // double of this size so we multiply by another 0.5. - maxPointsSortInHeap = (int) (0.5 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDataDims)); + maxPointsSortInHeap = (int) (0.25 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDataDims)); // Finally, we must be able to hold at least the leaf node in heap during build: if (maxPointsSortInHeap < maxPointsInLeafNode) { @@ -232,7 +196,7 @@ final class SimpleTextBKDWriter implements Closeable { } // We write first maxPointsSortInHeap in heap, then cutover to offline for additional points: - heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength, longOrds, singleValuePerDoc); + heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength); this.maxMBSortInHeap = maxMBSortInHeap; } @@ -264,13 +228,14 @@ final class SimpleTextBKDWriter implements Closeable { private void spillToOffline() throws IOException { // For each .add we just append to this input file, then in .finish we sort this input and resursively build the tree: - offlinePointWriter = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, longOrds, "spill", 0, singleValuePerDoc); + offlinePointWriter = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "spill", 0); tempInput = offlinePointWriter.out; PointReader reader = heapPointWriter.getReader(0, pointCount); for(int i=0;i= 0 && dim < numDataDims; - - if (heapPointWriter != null) { - - assert tempInput == null; - - // We never spilled the incoming points to disk, so now we sort in heap: - HeapPointWriter sorted; - - if (dim == 0) { - // First dim can re-use the current heap writer - sorted = heapPointWriter; - } else { - // Subsequent dims need a private copy - sorted = new HeapPointWriter((int) pointCount, (int) pointCount, packedBytesLength, longOrds, singleValuePerDoc); - sorted.copyFrom(heapPointWriter); - } - - //long t0 = System.nanoTime(); - sortHeapPointWriter(sorted, dim); - //long t1 = System.nanoTime(); - //System.out.println("BKD: sort took " + ((t1-t0)/1000000.0) + " msec"); - - sorted.close(); - return sorted; - } else { - - // Offline sort: - assert tempInput != null; - - final int offset = bytesPerDim * dim; - - Comparator cmp; - if (dim == numDataDims - 1) { - // in that case the bytes for the dimension and for the doc id are contiguous, - // so we don't need a branch - cmp = new BytesRefComparator(bytesPerDim + Integer.BYTES) { - @Override - protected int byteAt(BytesRef ref, int i) { - return ref.bytes[ref.offset + offset + i] & 0xff; - } - }; - } else { - cmp = new BytesRefComparator(bytesPerDim + Integer.BYTES) { - @Override - protected int byteAt(BytesRef ref, int i) { - if (i < bytesPerDim) { - return ref.bytes[ref.offset + offset + i] & 0xff; - } else { - return ref.bytes[ref.offset + packedBytesLength + i - bytesPerDim] & 0xff; - } - } - }; - } - - OfflineSorter sorter = new OfflineSorter(tempDir, tempFileNamePrefix + "_bkd" + dim, cmp, offlineSorterBufferMB, offlineSorterMaxTempFiles, bytesPerDoc, null, 0) { - - /** We write/read fixed-byte-width file that {@link OfflinePointReader} can read. 
*/ - @Override - protected ByteSequencesWriter getWriter(IndexOutput out, long count) { - return new ByteSequencesWriter(out) { - @Override - public void write(byte[] bytes, int off, int len) throws IOException { - assert len == bytesPerDoc: "len=" + len + " bytesPerDoc=" + bytesPerDoc; - out.writeBytes(bytes, off, len); - } - }; - } - - /** We write/read fixed-byte-width file that {@link OfflinePointReader} can read. */ - @Override - protected ByteSequencesReader getReader(ChecksumIndexInput in, String name) throws IOException { - return new ByteSequencesReader(in, name) { - final BytesRef scratch = new BytesRef(new byte[bytesPerDoc]); - @Override - public BytesRef next() throws IOException { - if (in.getFilePointer() >= end) { - return null; - } - in.readBytes(scratch.bytes, 0, bytesPerDoc); - return scratch; - } - }; - } - }; - - String name = sorter.sort(tempInput.getName()); - - return new OfflinePointWriter(tempDir, name, packedBytesLength, pointCount, longOrds, singleValuePerDoc); - } - } - private void checkMaxLeafNodeCount(int numLeaves) { if ((1+bytesPerDim) * (long) numLeaves > ArrayUtil.MAX_ARRAY_LENGTH) { throw new IllegalStateException("too many nodes; increase maxPointsInLeafNode (currently " + maxPointsInLeafNode + ") and reindex"); @@ -779,25 +625,21 @@ final class SimpleTextBKDWriter implements Closeable { throw new IllegalStateException("already finished"); } + PointWriter data; + if (offlinePointWriter != null) { offlinePointWriter.close(); + data = offlinePointWriter; + tempInput = null; + } else { + data = heapPointWriter; + heapPointWriter = null; } if (pointCount == 0) { throw new IllegalStateException("must index at least one point"); } - LongBitSet ordBitSet; - if (numDataDims > 1) { - if (singleValuePerDoc) { - ordBitSet = new LongBitSet(maxDoc); - } else { - ordBitSet = new LongBitSet(pointCount); - } - } else { - ordBitSet = null; - } - long countPerLeaf = pointCount; long innerNodeCount = 1; @@ -822,39 +664,17 @@ final class SimpleTextBKDWriter implements Closeable { // Make sure the math above "worked": assert pointCount / numLeaves <= maxPointsInLeafNode: "pointCount=" + pointCount + " numLeaves=" + numLeaves + " maxPointsInLeafNode=" + maxPointsInLeafNode; - // Sort all docs once by each dimension: - PathSlice[] sortedPointWriters = new PathSlice[numDataDims]; - - // This is only used on exception; on normal code paths we close all files we opened: - List toCloseHeroically = new ArrayList<>(); + //We re-use the selector so we do not need to create an object every time. + BKDRadixSelector radixSelector = new BKDRadixSelector(numDataDims, bytesPerDim, maxPointsSortInHeap, tempDir, tempFileNamePrefix); boolean success = false; try { - //long t0 = System.nanoTime(); - for(int dim=0;dim= the splitValue). */ - private byte[] markRightTree(long rightCount, int splitDim, PathSlice source, LongBitSet ordBitSet) throws IOException { - - // Now we mark ords that fall into the right half, so we can partition on all other dims that are not the split dim: - - // Read the split value, then mark all ords in the right tree (larger than the split value): - - // TODO: find a way to also checksum this reader? If we changed to markLeftTree, and scanned the final chunk, it could work? 
- try (PointReader reader = source.writer.getReader(source.start + source.count - rightCount, rightCount)) { - boolean result = reader.next(); - assert result; - System.arraycopy(reader.packedValue(), splitDim*bytesPerDim, scratch1, 0, bytesPerDim); - if (numDataDims > 1) { - assert ordBitSet.get(reader.ord()) == false; - ordBitSet.set(reader.ord()); - // Subtract 1 from rightCount because we already did the first value above (so we could record the split value): - reader.markOrds(rightCount-1, ordBitSet); - } - } catch (Throwable t) { - throw verifyChecksum(t, source.writer); - } - - return scratch1; - } - /** Called only in assert */ private boolean valueInBounds(BytesRef packedValue, byte[] minPackedValue, byte[] maxPackedValue) { for(int dim=0;dim toCloseHeroically) throws IOException { - int count = Math.toIntExact(source.count); + private HeapPointWriter switchToHeap(PointWriter source) throws IOException { + int count = Math.toIntExact(source.count()); // Not inside the try because we don't want to close it here: - PointReader reader = source.writer.getSharedReader(source.start, source.count, toCloseHeroically); - try (PointWriter writer = new HeapPointWriter(count, count, packedBytesLength, longOrds, singleValuePerDoc)) { + + try (PointReader reader = source.getReader(0, count); + HeapPointWriter writer = new HeapPointWriter(count, count, packedBytesLength)) { for(int i=0;i toCloseHeroically) throws IOException { - - for(PathSlice slice : slices) { - assert slice.count == slices[0].count; - } - - if (numDataDims == 1 && slices[0].writer instanceof OfflinePointWriter && slices[0].count <= maxPointsSortInHeap) { - // Special case for 1D, to cutover to heap once we recurse deeply enough: - slices[0] = switchToHeap(slices[0], toCloseHeroically); - } + long[] leafBlockFPs) throws IOException { if (nodeID >= leafNodeOffset) { // Leaf node: write block // We can write the block in any order so by default we write it sorted by the dimension that has the // least number of unique bytes at commonPrefixLengths[dim], which makes compression more efficient + + if (data instanceof HeapPointWriter == false) { + // Adversarial cases can cause this, e.g. 
very lopsided data, all equal points, such that we started + // offline, but then kept splitting only in one dimension, and so never had to rewrite into heap writer + data = switchToHeap(data); + } + + // We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point: + HeapPointWriter heapSource = (HeapPointWriter) data; + + //we store common prefix on scratch1 + computeCommonPrefixLength(heapSource, scratch1); + int sortedDim = 0; int sortedDimCardinality = Integer.MAX_VALUE; - - for (int dim=0;dim= maxPointsInLeafNode, so we better be in heap at this point: - HeapPointWriter heapSource = (HeapPointWriter) source.writer; + sortHeapPointWriter(heapSource, sortedDim); // Save the block file pointer: leafBlockFPs[nodeID - leafNodeOffset] = out.getFilePointer(); @@ -1320,9 +1076,9 @@ final class SimpleTextBKDWriter implements Closeable { // Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o // loading the values: - int count = Math.toIntExact(source.count); + int count = Math.toIntExact(heapSource.count()); assert count > 0: "nodeID=" + nodeID + " leafNodeOffset=" + leafNodeOffset; - writeLeafBlockDocs(out, heapSource.docIDs, Math.toIntExact(source.start), count); + writeLeafBlockDocs(out, heapSource.docIDs, 0, count); // TODO: minor opto: we don't really have to write the actual common prefixes, because BKDReader on recursing can regenerate it for us // from the index, much like how terms dict does so from the FST: @@ -1337,12 +1093,12 @@ final class SimpleTextBKDWriter implements Closeable { @Override public BytesRef apply(int i) { - heapSource.getPackedValueSlice(Math.toIntExact(source.start + i), scratch); + heapSource.getPackedValueSlice(i, scratch); return scratch; } }; assert valuesInOrderAndBounds(count, sortedDim, minPackedValue, maxPackedValue, packedValues, - heapSource.docIDs, Math.toIntExact(source.start)); + heapSource.docIDs, Math.toIntExact(0)); writeLeafBlockPackedValues(out, commonPrefixLengths, count, sortedDim, packedValues); } else { @@ -1355,91 +1111,67 @@ final class SimpleTextBKDWriter implements Closeable { splitDim = 0; } - PathSlice source = slices[splitDim]; - assert nodeID < splitPackedValues.length: "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length; + assert nodeID < splitPackedValues.length : "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length; // How many points will be in the left tree: - long rightCount = source.count / 2; - long leftCount = source.count - rightCount; + long rightCount = data.count() / 2; + long leftCount = data.count() - rightCount; - byte[] splitValue = markRightTree(rightCount, splitDim, source, ordBitSet); - int address = nodeID * (1+bytesPerDim); + PointWriter leftPointWriter; + PointWriter rightPointWriter; + byte[] splitValue; + + try (PointWriter leftPointWriter2 = getPointWriter(leftCount, "left" + splitDim); + PointWriter rightPointWriter2 = getPointWriter(rightCount, "right" + splitDim)) { + splitValue = radixSelector.select(data, leftPointWriter2, rightPointWriter2, 0, data.count(), leftCount, splitDim); + leftPointWriter = leftPointWriter2; + rightPointWriter = rightPointWriter2; + } catch (Throwable t) { + throw verifyChecksum(t, data); + } + + int address = nodeID * (1 + bytesPerDim); splitPackedValues[address] = (byte) splitDim; System.arraycopy(splitValue, 0, splitPackedValues, address + 1, bytesPerDim); - // Partition all PathSlice that are not the split dim into sorted left and right sets, so we can 
recurse: - - PathSlice[] leftSlices = new PathSlice[numDataDims]; - PathSlice[] rightSlices = new PathSlice[numDataDims]; - byte[] minSplitPackedValue = new byte[packedIndexBytesLength]; System.arraycopy(minPackedValue, 0, minSplitPackedValue, 0, packedIndexBytesLength); byte[] maxSplitPackedValue = new byte[packedIndexBytesLength]; System.arraycopy(maxPackedValue, 0, maxSplitPackedValue, 0, packedIndexBytesLength); - // When we are on this dim, below, we clear the ordBitSet: - int dimToClear; - if (numDataDims - 1 == splitDim) { - dimToClear = numDataDims - 2; - } else { - dimToClear = numDataDims - 1; - } + System.arraycopy(splitValue, 0, minSplitPackedValue, splitDim * bytesPerDim, bytesPerDim); + System.arraycopy(splitValue, 0, maxSplitPackedValue, splitDim * bytesPerDim, bytesPerDim); - for(int dim=0;dim bkdReaders = new ArrayList<>(); List docMaps = new ArrayList<>(); for(int i=0;i= from"); + } + if (partitionPoint >= to) { + throw new IllegalArgumentException("partitionPoint must be < to"); + } + } + + private int findCommonPrefix(OfflinePointWriter points, long from, long to, int dim) throws IOException{ + //find common prefix + byte[] commonPrefix = new byte[bytesSorted]; + int commonPrefixPosition = bytesSorted; + try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) { + reader.next(); + reader.packedValueWithDocId(bytesRef1); + // copy dimension + System.arraycopy(bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim, commonPrefix, 0, bytesPerDim); + // copy docID + System.arraycopy(bytesRef1.bytes, bytesRef1.offset + packedBytesLength, commonPrefix, bytesPerDim, Integer.BYTES); + for (long i = from + 1; i< to; i++) { + reader.next(); + reader.packedValueWithDocId(bytesRef1); + int startIndex = dim * bytesPerDim; + int endIndex = (commonPrefixPosition > bytesPerDim) ? 
startIndex + bytesPerDim : startIndex + commonPrefixPosition; + int j = FutureArrays.mismatch(commonPrefix, 0, endIndex - startIndex, bytesRef1.bytes, bytesRef1.offset + startIndex, bytesRef1.offset + endIndex); + if (j == 0) { + return 0; + } else if (j == -1) { + if (commonPrefixPosition > bytesPerDim) { + //tie-break on docID + int k = FutureArrays.mismatch(commonPrefix, bytesPerDim, commonPrefixPosition, bytesRef1.bytes, bytesRef1.offset + packedBytesLength, bytesRef1.offset + packedBytesLength + commonPrefixPosition - bytesPerDim ); + if (k != -1) { + commonPrefixPosition = bytesPerDim + k; + } + } + } else { + commonPrefixPosition = j; + } + } + } + + //build histogram up to the common prefix + for (int i = 0; i < commonPrefixPosition; i++) { + partitionBucket[i] = commonPrefix[i] & 0xff; + histogram[i][partitionBucket[i]] = to - from; + } + return commonPrefixPosition; + } + + private byte[] buildHistogramAndPartition(OfflinePointWriter points, PointWriter left, PointWriter right, + long from, long to, long partitionPoint, int iteration, int commonPrefix, int dim) throws IOException { + + long leftCount = 0; + long rightCount = 0; + //build histogram at the commonPrefix byte + try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) { + while (reader.next()) { + reader.packedValueWithDocId(bytesRef1); + int bucket; + if (commonPrefix < bytesPerDim) { + bucket = bytesRef1.bytes[bytesRef1.offset + dim * bytesPerDim + commonPrefix] & 0xff; + } else { + bucket = bytesRef1.bytes[bytesRef1.offset + packedBytesLength + commonPrefix - bytesPerDim] & 0xff; + } + histogram[commonPrefix][bucket]++; + } + } + //Count left points and record the partition point + for(int i = 0; i < HISTOGRAM_SIZE; i++) { + long size = histogram[commonPrefix][i]; + if (leftCount + size > partitionPoint - from) { + partitionBucket[commonPrefix] = i; + break; + } + leftCount += size; + } + //Count right points + for(int i = partitionBucket[commonPrefix] + 1; i < HISTOGRAM_SIZE; i++) { + rightCount += histogram[commonPrefix][i]; + } + + long delta = histogram[commonPrefix][partitionBucket[commonPrefix]]; + assert leftCount + rightCount + delta == to - from; + + //special case when be have lot of points that are equal + if (commonPrefix == bytesSorted - 1) { + long tieBreakCount =(partitionPoint - from - leftCount); + partition(points, left, right, null, from, to, dim, commonPrefix, tieBreakCount); + return partitionPointFromCommonPrefix(); + } + + //create the delta points writer + PointWriter deltaPoints; + if (delta <= maxPointsSortInHeap) { + deltaPoints = new HeapPointWriter(Math.toIntExact(delta), Math.toIntExact(delta), packedBytesLength); + } else { + deltaPoints = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "delta" + iteration, delta); + } + //divide the points. 
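The bucket-selection loop above is easier to follow with concrete numbers. A toy, self-contained illustration (all counts and the class name are made up, not part of the patch) of how the per-byte histogram picks the bucket that straddles the partition point, and how many tie values still have to be carried to the left side:

class RadixBucketDemo {
  public static void main(String[] args) {
    // made-up histogram for one byte position: 100 points total
    long[] histogram = new long[256];
    histogram[17] = 40;
    histogram[99] = 35;
    histogram[200] = 25;
    long partitionPoint = 50;          // we want the 50 smallest points on the left side

    long leftCount = 0;
    int partitionBucket = 0;
    for (int b = 0; b < 256; b++) {
      if (leftCount + histogram[b] > partitionPoint) {
        partitionBucket = b;           // this bucket straddles the partition point
        break;
      }
      leftCount += histogram[b];       // the whole bucket goes to the left side
    }
    // Buckets below 99 go left (40 points), buckets above 99 go right (25 points); only the
    // 35 points in bucket 99 are ambiguous, and 50 - 40 = 10 of them must still go left, so
    // the next radix pass (or the docID tie-break) only needs to run on that "delta" set.
    System.out.println("partitionBucket=" + partitionBucket + " carriedLeft=" + (partitionPoint - leftCount));
  }
}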
This actually destroys the current writer + partition(points, left, right, deltaPoints, from, to, dim, commonPrefix, 0); + //close delta point writer + deltaPoints.close(); + + long newPartitionPoint = partitionPoint - from - leftCount; + + if (deltaPoints instanceof HeapPointWriter) { + return heapSelect((HeapPointWriter) deltaPoints, left, right, dim, 0, (int) deltaPoints.count(), Math.toIntExact(newPartitionPoint), ++commonPrefix); + } else { + return buildHistogramAndPartition((OfflinePointWriter) deltaPoints, left, right, 0, deltaPoints.count(), newPartitionPoint, ++iteration, ++commonPrefix, dim); + } + } + + private void partition(OfflinePointWriter points, PointWriter left, PointWriter right, PointWriter deltaPoints, + long from, long to, int dim, int bytePosition, long numDocsTiebreak) throws IOException { + assert bytePosition == bytesSorted -1 || deltaPoints != null; + long tiebreakCounter = 0; + try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) { + while (reader.next()) { + reader.packedValueWithDocId(bytesRef1); + reader.packedValue(bytesRef2); + int docID = reader.docID(); + int bucket; + if (bytePosition < bytesPerDim) { + bucket = bytesRef1.bytes[bytesRef1.offset + dim * bytesPerDim + bytePosition] & 0xff; + } else { + bucket = bytesRef1.bytes[bytesRef1.offset + packedBytesLength + bytePosition - bytesPerDim] & 0xff; + } + //int bucket = getBucket(bytesRef1, dim, thisCommonPrefix); + if (bucket < this.partitionBucket[bytePosition]) { + // to the left side + left.append(bytesRef2, docID); + } else if (bucket > this.partitionBucket[bytePosition]) { + // to the right side + right.append(bytesRef2, docID); + } else { + if (bytePosition == bytesSorted - 1) { + if (tiebreakCounter < numDocsTiebreak) { + left.append(bytesRef2, docID); + tiebreakCounter++; + } else { + right.append(bytesRef2, docID); + } + } else { + deltaPoints.append(bytesRef2, docID); + } + } + } + } + //Delete original file + points.destroy(); + } + + private byte[] partitionPointFromCommonPrefix() { + byte[] partition = new byte[bytesPerDim]; + for (int i = 0; i < bytesPerDim; i++) { + partition[i] = (byte)partitionBucket[i]; + } + return partition; + } + + private byte[] heapSelect(HeapPointWriter points, PointWriter left, PointWriter right, int dim, int from, int to, int partitionPoint, int commonPrefix) throws IOException { + final int offset = dim * bytesPerDim + commonPrefix; + new RadixSelector(bytesSorted - commonPrefix) { + + @Override + protected void swap(int i, int j) { + points.swap(i, j); + } + + @Override + protected int byteAt(int i, int k) { + assert k >= 0; + if (k + commonPrefix < bytesPerDim) { + // dim bytes + int block = i / points.valuesPerBlock; + int index = i % points.valuesPerBlock; + return points.blocks.get(block)[index * packedBytesLength + offset + k] & 0xff; + } else { + // doc id + int s = 3 - (k + commonPrefix - bytesPerDim); + return (points.docIDs[i] >>> (s * 8)) & 0xff; + } + } + }.select(from, to, partitionPoint); + + for (int i = from; i < to; i++) { + points.getPackedValueSlice(i, bytesRef1); + int docID = points.docIDs[i]; + if (i < partitionPoint) { + left.append(bytesRef1, docID); + } else { + right.append(bytesRef1, docID); + } + } + byte[] partition = new byte[bytesPerDim]; + points.getPackedValueSlice(partitionPoint, bytesRef1); + System.arraycopy(bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim, partition, 0, bytesPerDim); + return partition; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java 
b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java index 82c490ad646..a8ee7c5e399 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java @@ -20,7 +20,6 @@ import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Comparator; import java.util.List; import java.util.function.IntFunction; @@ -39,14 +38,11 @@ import org.apache.lucene.store.TrackingDirectoryWrapper; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefBuilder; -import org.apache.lucene.util.BytesRefComparator; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.FutureArrays; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.LongBitSet; import org.apache.lucene.util.MSBRadixSorter; import org.apache.lucene.util.NumericUtils; -import org.apache.lucene.util.OfflineSorter; import org.apache.lucene.util.PriorityQueue; // TODO @@ -59,7 +55,8 @@ import org.apache.lucene.util.PriorityQueue; // per leaf, and you can reduce that by putting more points per leaf // - we could use threads while building; the higher nodes are very parallelizable -/** Recursively builds a block KD-tree to assign all incoming points in N-dim space to smaller +/** + * Recursively builds a block KD-tree to assign all incoming points in N-dim space to smaller * and smaller N-dim rectangles (cells) until the number of points in a given * rectangle is <= maxPointsInLeafNode. The tree is * fully balanced, which means the leaf nodes will have between 50% and 100% of @@ -68,14 +65,13 @@ import org.apache.lucene.util.PriorityQueue; * *

 * The number of dimensions can be 1 to 8, but every byte[] value is fixed length.
 *
- *
- * See this paper for details.
- *
- * This consumes heap during writing: it allocates a LongBitSet(numPoints),
- * and then uses up to the specified {@code maxMBSortInHeap} heap space for writing.
+ * This consumes heap during writing: it allocates a Long[numLeaves],
+ * a byte[numLeaves*(1+bytesPerDim)] and then uses up to the specified
+ * {@code maxMBSortInHeap} heap space for writing.
 *
 *
- * NOTE: This can write at most Integer.MAX_VALUE * maxPointsInLeafNode total points. + * NOTE: This can write at most Integer.MAX_VALUE * maxPointsInLeafNode / (1+bytesPerDim) + * total points. * * @lucene.experimental */ @@ -143,32 +139,13 @@ public class BKDWriter implements Closeable { protected long pointCount; - /** true if we have so many values that we must write ords using long (8 bytes) instead of int (4 bytes) */ - protected final boolean longOrds; - /** An upper bound on how many points the caller will add (includes deletions) */ private final long totalPointCount; - /** True if every document has at most one value. We specialize this case by not bothering to store the ord since it's redundant with docID. */ - protected final boolean singleValuePerDoc; - - /** How much heap OfflineSorter is allowed to use */ - protected final OfflineSorter.BufferSize offlineSorterBufferMB; - - /** How much heap OfflineSorter is allowed to use */ - protected final int offlineSorterMaxTempFiles; - private final int maxDoc; public BKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim, - int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount, boolean singleValuePerDoc) throws IOException { - this(maxDoc, tempDir, tempFileNamePrefix, numDataDims, numIndexDims, bytesPerDim, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount, singleValuePerDoc, - totalPointCount > Integer.MAX_VALUE, Math.max(1, (long) maxMBSortInHeap), OfflineSorter.MAX_TEMPFILES); - } - - protected BKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim, - int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount, - boolean singleValuePerDoc, boolean longOrds, long offlineSorterBufferMB, int offlineSorterMaxTempFiles) throws IOException { + int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount) throws IOException { verifyParams(numDataDims, numIndexDims, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount); // We use tracking dir to deal with removing files on exception, so each place that // creates temp files doesn't need crazy try/finally/sucess logic: @@ -180,8 +157,6 @@ public class BKDWriter implements Closeable { this.bytesPerDim = bytesPerDim; this.totalPointCount = totalPointCount; this.maxDoc = maxDoc; - this.offlineSorterBufferMB = OfflineSorter.BufferSize.megabytes(offlineSorterBufferMB); - this.offlineSorterMaxTempFiles = offlineSorterMaxTempFiles; docsSeen = new FixedBitSet(maxDoc); packedBytesLength = numDataDims * bytesPerDim; packedIndexBytesLength = numIndexDims * bytesPerDim; @@ -194,21 +169,9 @@ public class BKDWriter implements Closeable { minPackedValue = new byte[packedIndexBytesLength]; maxPackedValue = new byte[packedIndexBytesLength]; - // If we may have more than 1+Integer.MAX_VALUE values, then we must encode ords with long (8 bytes), else we can use int (4 bytes). 
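The removed comment above is the heart of the simplification: without a per-point ord, every buffered point is just its packed value plus a docID. A rough back-of-the-envelope sketch of what that saves per point (the field shape is made up, not from the patch):

// Illustrative arithmetic only (hypothetical 2D field of 8-byte dimensions):
int numDataDims = 2, bytesPerDim = 8;
int packedBytesLength = numDataDims * bytesPerDim;                    // 16 bytes of values
int oldBytesPerDoc = packedBytesLength + Long.BYTES + Integer.BYTES;  // 28 bytes with a long ord
int newBytesPerDoc = packedBytesLength + Integer.BYTES;               // 20 bytes, docID only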
- this.longOrds = longOrds; + // dimensional values (numDims * bytesPerDim) + docID (int) + bytesPerDoc = packedBytesLength + Integer.BYTES; - this.singleValuePerDoc = singleValuePerDoc; - - // dimensional values (numDims * bytesPerDim) + ord (int or long) + docID (int) - if (singleValuePerDoc) { - // Lucene only supports up to 2.1 docs, so we better not need longOrds in this case: - assert longOrds == false; - bytesPerDoc = packedBytesLength + Integer.BYTES; - } else if (longOrds) { - bytesPerDoc = packedBytesLength + Long.BYTES + Integer.BYTES; - } else { - bytesPerDoc = packedBytesLength + Integer.BYTES + Integer.BYTES; - } // As we recurse, we compute temporary partitions of the data, halving the // number of points at each recursion. Once there are few enough points, @@ -216,10 +179,10 @@ public class BKDWriter implements Closeable { // time in the recursion, we hold the number of points at that level, plus // all recursive halves (i.e. 16 + 8 + 4 + 2) so the memory usage is 2X // what that level would consume, so we multiply by 0.5 to convert from - // bytes to points here. Each dimension has its own sorted partition, so - // we must divide by numDims as wel. + // bytes to points here. In addition the radix partitioning may sort on memory + // double of this size so we multiply by another 0.5. - maxPointsSortInHeap = (int) (0.5 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDataDims)); + maxPointsSortInHeap = (int) (0.25 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc)); // Finally, we must be able to hold at least the leaf node in heap during build: if (maxPointsSortInHeap < maxPointsInLeafNode) { @@ -227,7 +190,7 @@ public class BKDWriter implements Closeable { } // We write first maxPointsSortInHeap in heap, then cutover to offline for additional points: - heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength, longOrds, singleValuePerDoc); + heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength); this.maxMBSortInHeap = maxMBSortInHeap; } @@ -259,15 +222,13 @@ public class BKDWriter implements Closeable { private void spillToOffline() throws IOException { // For each .add we just append to this input file, then in .finish we sort this input and resursively build the tree: - offlinePointWriter = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, longOrds, "spill", 0, singleValuePerDoc); + offlinePointWriter = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "spill", 0); tempInput = offlinePointWriter.out; - PointReader reader = heapPointWriter.getReader(0, pointCount); + scratchBytesRef1.length = packedBytesLength; for(int i=0;i= 0; - if (k < bytesPerDim) { + if (k + commonPrefixLength < bytesPerDim) { // dim bytes int block = i / writer.valuesPerBlock; int index = i % writer.valuesPerBlock; - return writer.blocks.get(block)[index * packedBytesLength + dim * bytesPerDim + k] & 0xff; + return writer.blocks.get(block)[index * packedBytesLength + dim * bytesPerDim + k + commonPrefixLength] & 0xff; } else { // doc id - int s = 3 - (k - bytesPerDim); + int s = 3 - (k + commonPrefixLength - bytesPerDim); return (writer.docIDs[i] >>> (s * 8)) & 0xff; } } @Override protected void swap(int i, int j) { - int docID = writer.docIDs[i]; - writer.docIDs[i] = writer.docIDs[j]; - writer.docIDs[j] = docID; - - if (singleValuePerDoc == false) { - if (longOrds) { - long ord = writer.ordsLong[i]; - writer.ordsLong[i] = writer.ordsLong[j]; - writer.ordsLong[j] = ord; - } else { - int ord = 
writer.ords[i]; - writer.ords[i] = writer.ords[j]; - writer.ords[j] = ord; - } - } - - byte[] blockI = writer.blocks.get(i / writer.valuesPerBlock); - int indexI = (i % writer.valuesPerBlock) * packedBytesLength; - byte[] blockJ = writer.blocks.get(j / writer.valuesPerBlock); - int indexJ = (j % writer.valuesPerBlock) * packedBytesLength; - - // scratch1 = values[i] - System.arraycopy(blockI, indexI, scratch1, 0, packedBytesLength); - // values[i] = values[j] - System.arraycopy(blockJ, indexJ, blockI, indexI, packedBytesLength); - // values[j] = scratch1 - System.arraycopy(scratch1, 0, blockJ, indexJ, packedBytesLength); + writer.swap(i, j); } }.sort(0, pointCount); @@ -835,134 +767,6 @@ public class BKDWriter implements Closeable { } */ - //return a new point writer sort by the provided dimension from input data - private PointWriter sort(int dim) throws IOException { - assert dim >= 0 && dim < numDataDims; - - if (heapPointWriter != null) { - assert tempInput == null; - // We never spilled the incoming points to disk, so now we sort in heap: - HeapPointWriter sorted = heapPointWriter; - //long t0 = System.nanoTime(); - sortHeapPointWriter(sorted, Math.toIntExact(this.pointCount), dim); - //long t1 = System.nanoTime(); - //System.out.println("BKD: sort took " + ((t1-t0)/1000000.0) + " msec"); - sorted.close(); - heapPointWriter = null; - return sorted; - } else { - // Offline sort: - assert tempInput != null; - OfflinePointWriter sorted = sortOffLine(dim, tempInput.getName(), 0, pointCount); - tempDir.deleteFile(tempInput.getName()); - tempInput = null; - return sorted; - } - } - - //return a new point writer sort by the provided dimension from start to start + pointCount - private PointWriter sort(int dim, PointWriter writer, final long start, final long pointCount) throws IOException { - assert dim >= 0 && dim < numDataDims; - - if (writer instanceof HeapPointWriter) { - HeapPointWriter heapPointWriter = createHeapPointWriterCopy((HeapPointWriter) writer, start, pointCount); - sortHeapPointWriter(heapPointWriter, Math.toIntExact(pointCount), dim); - return heapPointWriter; - } else { - OfflinePointWriter offlinePointWriter = (OfflinePointWriter) writer; - return sortOffLine(dim, offlinePointWriter.name, start, pointCount); - } - } - - // sort a given file on a given dimension for start to start + point count - private OfflinePointWriter sortOffLine(int dim, String inputName, final long start, final long pointCount) throws IOException { - - final int offset = bytesPerDim * dim; - - Comparator cmp; - if (dim == numDataDims - 1) { - // in that case the bytes for the dimension and for the doc id are contiguous, - // so we don't need a branch - cmp = new BytesRefComparator(bytesPerDim + Integer.BYTES) { - @Override - protected int byteAt(BytesRef ref, int i) { - return ref.bytes[ref.offset + offset + i] & 0xff; - } - }; - } else { - cmp = new BytesRefComparator(bytesPerDim + Integer.BYTES) { - @Override - protected int byteAt(BytesRef ref, int i) { - if (i < bytesPerDim) { - return ref.bytes[ref.offset + offset + i] & 0xff; - } else { - return ref.bytes[ref.offset + packedBytesLength + i - bytesPerDim] & 0xff; - } - } - }; - } - - OfflineSorter sorter = new OfflineSorter(tempDir, tempFileNamePrefix + "_bkd" + dim, cmp, offlineSorterBufferMB, offlineSorterMaxTempFiles, bytesPerDoc, null, 0) { - /** - * We write/read fixed-byte-width file that {@link OfflinePointReader} can read. 
- */ - @Override - protected ByteSequencesWriter getWriter(IndexOutput out, long count) { - return new ByteSequencesWriter(out) { - @Override - public void write(byte[] bytes, int off, int len) throws IOException { - assert len == bytesPerDoc : "len=" + len + " bytesPerDoc=" + bytesPerDoc; - out.writeBytes(bytes, off, len); - } - }; - } - - /** - * We write/read fixed-byte-width file that {@link OfflinePointReader} can read. - */ - @Override - protected ByteSequencesReader getReader(ChecksumIndexInput in, String name) throws IOException { - //This allows to read only a subset of the original file - long startPointer = (name.equals(inputName)) ? bytesPerDoc * start : in.getFilePointer(); - long endPointer = (name.equals(inputName)) ? startPointer + bytesPerDoc * pointCount : Long.MAX_VALUE; - in.seek(startPointer); - return new ByteSequencesReader(in, name) { - final BytesRef scratch = new BytesRef(new byte[bytesPerDoc]); - - @Override - public BytesRef next() throws IOException { - if (in.getFilePointer() >= end) { - return null; - } else if (in.getFilePointer() >= endPointer) { - in.seek(end); - return null; - } - in.readBytes(scratch.bytes, 0, bytesPerDoc); - return scratch; - } - }; - } - }; - - String name = sorter.sort(inputName); - return new OfflinePointWriter(tempDir, name, packedBytesLength, pointCount, longOrds, singleValuePerDoc); - } - - private HeapPointWriter createHeapPointWriterCopy(HeapPointWriter writer, long start, long count) throws IOException { - //TODO: Can we do this faster? - int size = Math.toIntExact(count); - try (HeapPointWriter pointWriter = new HeapPointWriter(size, size, packedBytesLength, longOrds, singleValuePerDoc); - PointReader reader = writer.getReader(start, count)) { - for (long i =0; i < count; i++) { - reader.next(); - pointWriter.append(reader.packedValue(), reader.ord(), reader.docID()); - } - return pointWriter; - } catch (Throwable t) { - throw verifyChecksum(t, writer); - } - } - private void checkMaxLeafNodeCount(int numLeaves) { if ((1+bytesPerDim) * (long) numLeaves > ArrayUtil.MAX_ARRAY_LENGTH) { throw new IllegalStateException("too many nodes; increase maxPointsInLeafNode (currently " + maxPointsInLeafNode + ") and reindex"); @@ -980,25 +784,20 @@ public class BKDWriter implements Closeable { throw new IllegalStateException("already finished"); } + PointWriter writer; if (offlinePointWriter != null) { offlinePointWriter.close(); + writer = offlinePointWriter; + tempInput = null; + } else { + writer = heapPointWriter; + heapPointWriter = null; } if (pointCount == 0) { throw new IllegalStateException("must index at least one point"); } - LongBitSet ordBitSet; - if (numIndexDims > 1) { - if (singleValuePerDoc) { - ordBitSet = new LongBitSet(maxDoc); - } else { - ordBitSet = new LongBitSet(pointCount); - } - } else { - ordBitSet = null; - } - long countPerLeaf = pointCount; long innerNodeCount = 1; @@ -1023,23 +822,19 @@ public class BKDWriter implements Closeable { // Make sure the math above "worked": assert pointCount / numLeaves <= maxPointsInLeafNode: "pointCount=" + pointCount + " numLeaves=" + numLeaves + " maxPointsInLeafNode=" + maxPointsInLeafNode; - // Slices are created as they are needed - PathSlice[] sortedPointWriters = new PathSlice[numIndexDims]; - - // This is only used on exception; on normal code paths we close all files we opened: - List toCloseHeroically = new ArrayList<>(); + //We re-use the selector so we do not need to create an object every time. 
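One selector instance now drives every split of the recursion. As a reading aid, a condensed sketch of the call pattern that the build() hunks below introduce (the helper name is hypothetical; it leans on the writer's existing getPointWriter and omits the real error handling):

// Condensed sketch, not the literal patch: partition 'points' around its median in splitDim.
private byte[] splitWithRadixSelector(BKDRadixSelector radixSelector, PointWriter points, int splitDim) throws IOException {
  long rightCount = points.count() / 2;
  long leftCount = points.count() - rightCount;
  try (PointWriter left = getPointWriter(leftCount, "left" + splitDim);
       PointWriter right = getPointWriter(rightCount, "right" + splitDim)) {
    // The selector writes the leftCount smallest points (ordered by splitDim, ties broken by
    // docID) to 'left', the rest to 'right', and returns the value at the partition point,
    // which becomes the split value recorded in the inner node.
    return radixSelector.select(points, left, right, 0, points.count(), leftCount, splitDim);
  }
}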
+ BKDRadixSelector radixSelector = new BKDRadixSelector(numDataDims, bytesPerDim, maxPointsSortInHeap, tempDir, tempFileNamePrefix); boolean success = false; try { final int[] parentSplits = new int[numIndexDims]; - build(1, numLeaves, sortedPointWriters, - ordBitSet, out, + build(1, numLeaves, writer, + out, radixSelector, minPackedValue, maxPackedValue, parentSplits, splitPackedValues, - leafBlockFPs, - toCloseHeroically); + leafBlockFPs); assert Arrays.equals(parentSplits, new int[numIndexDims]); // If no exception, we should have cleaned everything up: @@ -1051,7 +846,6 @@ public class BKDWriter implements Closeable { } finally { if (success == false) { IOUtils.deleteFilesIgnoringExceptions(tempDir, tempDir.getCreatedFiles()); - IOUtils.closeWhileHandlingException(toCloseHeroically); } } @@ -1243,7 +1037,6 @@ public class BKDWriter implements Closeable { } private long getLeftMostLeafBlockFP(long[] leafBlockFPs, int nodeID) { - int nodeIDIn = nodeID; // TODO: can we do this cheaper, e.g. a closed form solution instead of while loop? Or // change the recursion while packing the index to return this left-most leaf block FP // from each recursion instead? @@ -1400,24 +1193,6 @@ public class BKDWriter implements Closeable { } } - /** Sliced reference to points in an OfflineSorter.ByteSequencesWriter file. */ - private static final class PathSlice { - final PointWriter writer; - final long start; - final long count; - - public PathSlice(PointWriter writer, long start, long count) { - this.writer = writer; - this.start = start; - this.count = count; - } - - @Override - public String toString() { - return "PathSlice(start=" + start + " count=" + count + " writer=" + writer + ")"; - } - } - /** Called on exception, to check whether the checksum is also corrupt in this source, and add that * information (checksum matched or didn't) as a suppressed exception. */ private Error verifyChecksum(Throwable priorException, PointWriter writer) throws IOException { @@ -1430,8 +1205,10 @@ public class BKDWriter implements Closeable { if (writer instanceof OfflinePointWriter) { // We are reading from a temp file; go verify the checksum: String tempFileName = ((OfflinePointWriter) writer).name; - try (ChecksumIndexInput in = tempDir.openChecksumInput(tempFileName, IOContext.READONCE)) { - CodecUtil.checkFooter(in, priorException); + if (tempDir.getCreatedFiles().contains(tempFileName)) { + try (ChecksumIndexInput in = tempDir.openChecksumInput(tempFileName, IOContext.READONCE)) { + CodecUtil.checkFooter(in, priorException); + } } } @@ -1439,31 +1216,6 @@ public class BKDWriter implements Closeable { throw IOUtils.rethrowAlways(priorException); } - /** Marks bits for the ords (points) that belong in the right sub tree (those docs that have values >= the splitValue). */ - private byte[] markRightTree(long rightCount, int splitDim, PathSlice source, LongBitSet ordBitSet) throws IOException { - - // Now we mark ords that fall into the right half, so we can partition on all other dims that are not the split dim: - - // Read the split value, then mark all ords in the right tree (larger than the split value): - - // TODO: find a way to also checksum this reader? If we changed to markLeftTree, and scanned the final chunk, it could work? 
- try (PointReader reader = source.writer.getReader(source.start + source.count - rightCount, rightCount)) { - boolean result = reader.next(); - assert result: "rightCount=" + rightCount + " source.count=" + source.count + " source.writer=" + source.writer; - System.arraycopy(reader.packedValue(), splitDim*bytesPerDim, scratch1, 0, bytesPerDim); - if (numIndexDims > 1 && ordBitSet != null) { - assert ordBitSet.get(reader.ord()) == false; - ordBitSet.set(reader.ord()); - // Subtract 1 from rightCount because we already did the first value above (so we could record the split value): - reader.markOrds(rightCount-1, ordBitSet); - } - } catch (Throwable t) { - throw verifyChecksum(t, source.writer); - } - - return scratch1; - } - /** Called only in assert */ private boolean valueInBounds(BytesRef packedValue, byte[] minPackedValue, byte[] maxPackedValue) { for(int dim=0;dim toCloseHeroically) throws IOException { - int count = Math.toIntExact(source.count); + private HeapPointWriter switchToHeap(PointWriter source) throws IOException { + int count = Math.toIntExact(source.count()); // Not inside the try because we don't want to close it here: - PointReader reader = source.writer.getSharedReader(source.start, source.count, toCloseHeroically); - try (PointWriter writer = new HeapPointWriter(count, count, packedBytesLength, longOrds, singleValuePerDoc)) { + + try (PointReader reader = source.getReader(0, source.count()); + HeapPointWriter writer = new HeapPointWriter(count, count, packedBytesLength)) { for(int i=0;i 1 case. */ private void build(int nodeID, int leafNodeOffset, - PathSlice[] slices, - LongBitSet ordBitSet, + PointWriter points, IndexOutput out, + BKDRadixSelector radixSelector, byte[] minPackedValue, byte[] maxPackedValue, int[] parentSplits, byte[] splitPackedValues, - long[] leafBlockFPs, - List toCloseHeroically) throws IOException { + long[] leafBlockFPs) throws IOException { if (nodeID >= leafNodeOffset) { // Leaf node: write block // We can write the block in any order so by default we write it sorted by the dimension that has the // least number of unique bytes at commonPrefixLengths[dim], which makes compression more efficient + + if (points instanceof HeapPointWriter == false) { + // Adversarial cases can cause this, e.g. 
very lopsided data, all equal points, such that we started + // offline, but then kept splitting only in one dimension, and so never had to rewrite into heap writer + points = switchToHeap(points); + } + + // We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point: + HeapPointWriter heapSource = (HeapPointWriter) points; + + //we store common prefix on scratch1 + computeCommonPrefixLength(heapSource, scratch1); + int sortedDim = 0; int sortedDimCardinality = Integer.MAX_VALUE; - - for (int dim=0;dim= maxPointsInLeafNode, so we better be in heap at this point: - HeapPointWriter heapSource = (HeapPointWriter) source.writer; + // sort the chosen dimension + sortHeapPointWriter(heapSource, Math.toIntExact(heapSource.count()), sortedDim, commonPrefixLengths[sortedDim]); // Save the block file pointer: leafBlockFPs[nodeID - leafNodeOffset] = out.getFilePointer(); @@ -1788,9 +1490,9 @@ public class BKDWriter implements Closeable { // Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o // loading the values: - int count = Math.toIntExact(source.count); + int count = Math.toIntExact(heapSource.count()); assert count > 0: "nodeID=" + nodeID + " leafNodeOffset=" + leafNodeOffset; - writeLeafBlockDocs(out, heapSource.docIDs, Math.toIntExact(source.start), count); + writeLeafBlockDocs(out, heapSource.docIDs, Math.toIntExact(0), count); // TODO: minor opto: we don't really have to write the actual common prefixes, because BKDReader on recursing can regenerate it for us // from the index, much like how terms dict does so from the FST: @@ -1808,12 +1510,12 @@ public class BKDWriter implements Closeable { @Override public BytesRef apply(int i) { - heapSource.getPackedValueSlice(Math.toIntExact(source.start + i), scratch); + heapSource.getPackedValueSlice(Math.toIntExact(i), scratch); return scratch; } }; assert valuesInOrderAndBounds(count, sortedDim, minPackedValue, maxPackedValue, packedValues, - heapSource.docIDs, Math.toIntExact(source.start)); + heapSource.docIDs, Math.toIntExact(0)); writeLeafBlockPackedValues(out, commonPrefixLengths, count, sortedDim, packedValues); } else { @@ -1826,124 +1528,70 @@ public class BKDWriter implements Closeable { splitDim = 0; } - //We delete the created path slices at the same level - boolean deleteSplitDim = false; - if (slices[splitDim] == null) { - createPathSlice(slices, splitDim); - deleteSplitDim = true; - } - PathSlice source = slices[splitDim]; - assert nodeID < splitPackedValues.length: "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length; + assert nodeID < splitPackedValues.length : "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length; // How many points will be in the left tree: - long rightCount = source.count / 2; - long leftCount = source.count - rightCount; + long rightCount = points.count() / 2; + long leftCount = points.count() - rightCount; - // When we are on this dim, below, we clear the ordBitSet: - int dimToClear = numIndexDims - 1; - while (dimToClear >= 0) { - if (slices[dimToClear] != null && splitDim != dimToClear) { - break; - } - dimToClear--; + PointWriter leftPointWriter; + PointWriter rightPointWriter; + byte[] splitValue; + try (PointWriter tempLeftPointWriter = getPointWriter(leftCount, "left" + splitDim); + PointWriter tempRightPointWriter = getPointWriter(rightCount, "right" + splitDim)) { + splitValue = radixSelector.select(points, tempLeftPointWriter, tempRightPointWriter, 0, points.count(), leftCount, 
splitDim); + leftPointWriter = tempLeftPointWriter; + rightPointWriter = tempRightPointWriter; + } catch (Throwable t) { + throw verifyChecksum(t, points); } - byte[] splitValue = (dimToClear == -1) ? markRightTree(rightCount, splitDim, source, null) : markRightTree(rightCount, splitDim, source, ordBitSet); - int address = nodeID * (1+bytesPerDim); + int address = nodeID * (1 + bytesPerDim); splitPackedValues[address] = (byte) splitDim; System.arraycopy(splitValue, 0, splitPackedValues, address + 1, bytesPerDim); - // Partition all PathSlice that are not the split dim into sorted left and right sets, so we can recurse: - - PathSlice[] leftSlices = new PathSlice[numIndexDims]; - PathSlice[] rightSlices = new PathSlice[numIndexDims]; - byte[] minSplitPackedValue = new byte[packedIndexBytesLength]; System.arraycopy(minPackedValue, 0, minSplitPackedValue, 0, packedIndexBytesLength); byte[] maxSplitPackedValue = new byte[packedIndexBytesLength]; System.arraycopy(maxPackedValue, 0, maxSplitPackedValue, 0, packedIndexBytesLength); - - for(int dim=0;dim blocks; final int valuesPerBlock; final int packedBytesLength; - final long[] ordsLong; - final int[] ords; final int[] docIDs; final int end; - final byte[] scratch; - final boolean singleValuePerDoc; - public HeapPointReader(List blocks, int valuesPerBlock, int packedBytesLength, int[] ords, long[] ordsLong, int[] docIDs, int start, int end, boolean singleValuePerDoc) { + public HeapPointReader(List blocks, int valuesPerBlock, int packedBytesLength, int[] docIDs, int start, int end) { this.blocks = blocks; this.valuesPerBlock = valuesPerBlock; - this.singleValuePerDoc = singleValuePerDoc; - this.ords = ords; - this.ordsLong = ordsLong; this.docIDs = docIDs; curRead = start-1; this.end = end; this.packedBytesLength = packedBytesLength; - scratch = new byte[packedBytesLength]; - } - - void writePackedValue(int index, byte[] bytes) { - int block = index / valuesPerBlock; - int blockIndex = index % valuesPerBlock; - while (blocks.size() <= block) { - blocks.add(new byte[valuesPerBlock*packedBytesLength]); - } - System.arraycopy(bytes, 0, blocks.get(blockIndex), blockIndex * packedBytesLength, packedBytesLength); - } - - void readPackedValue(int index, byte[] bytes) { - int block = index / valuesPerBlock; - int blockIndex = index % valuesPerBlock; - System.arraycopy(blocks.get(block), blockIndex * packedBytesLength, bytes, 0, packedBytesLength); } @Override @@ -68,9 +49,12 @@ public final class HeapPointReader extends PointReader { } @Override - public byte[] packedValue() { - readPackedValue(curRead, scratch); - return scratch; + public void packedValue(BytesRef bytesRef) { + int block = curRead / valuesPerBlock; + int blockIndex = curRead % valuesPerBlock; + bytesRef.bytes = blocks.get(block); + bytesRef.offset = blockIndex * packedBytesLength; + bytesRef.length = packedBytesLength; } @Override @@ -78,17 +62,6 @@ public final class HeapPointReader extends PointReader { return docIDs[curRead]; } - @Override - public long ord() { - if (singleValuePerDoc) { - return docIDs[curRead]; - } else if (ordsLong != null) { - return ordsLong[curRead]; - } else { - return ords[curRead]; - } - } - @Override public void close() { } diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java index eb1d48b9f12..0e4ad782f5f 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java +++ 
b/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java @@ -16,46 +16,36 @@ */ package org.apache.lucene.util.bkd; -import java.io.Closeable; import java.util.ArrayList; import java.util.List; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; -/** Utility class to write new points into in-heap arrays. +/** + * Utility class to write new points into in-heap arrays. * - * @lucene.internal */ + * @lucene.internal + * */ public final class HeapPointWriter implements PointWriter { public int[] docIDs; - public long[] ordsLong; - public int[] ords; private int nextWrite; private boolean closed; final int maxSize; public final int valuesPerBlock; final int packedBytesLength; - final boolean singleValuePerDoc; // NOTE: can't use ByteBlockPool because we need random-write access when sorting in heap public final List blocks = new ArrayList<>(); + private byte[] scratch; - public HeapPointWriter(int initSize, int maxSize, int packedBytesLength, boolean longOrds, boolean singleValuePerDoc) { + + public HeapPointWriter(int initSize, int maxSize, int packedBytesLength) { docIDs = new int[initSize]; this.maxSize = maxSize; this.packedBytesLength = packedBytesLength; - this.singleValuePerDoc = singleValuePerDoc; - if (singleValuePerDoc) { - this.ordsLong = null; - this.ords = null; - } else { - if (longOrds) { - this.ordsLong = new long[initSize]; - } else { - this.ords = new int[initSize]; - } - } // 4K per page, unless each value is > 4K: valuesPerBlock = Math.max(1, 4096/packedBytesLength); + scratch = new byte[packedBytesLength]; } public void copyFrom(HeapPointWriter other) { @@ -63,36 +53,19 @@ public final class HeapPointWriter implements PointWriter { throw new IllegalStateException("docIDs.length=" + docIDs.length + " other.nextWrite=" + other.nextWrite); } System.arraycopy(other.docIDs, 0, docIDs, 0, other.nextWrite); - if (singleValuePerDoc == false) { - if (other.ords != null) { - assert this.ords != null; - System.arraycopy(other.ords, 0, ords, 0, other.nextWrite); - } else { - assert this.ordsLong != null; - System.arraycopy(other.ordsLong, 0, ordsLong, 0, other.nextWrite); - } - } - for(byte[] block : other.blocks) { blocks.add(block.clone()); } nextWrite = other.nextWrite; } - public void readPackedValue(int index, byte[] bytes) { - assert bytes.length == packedBytesLength; - int block = index / valuesPerBlock; - int blockIndex = index % valuesPerBlock; - System.arraycopy(blocks.get(block), blockIndex * packedBytesLength, bytes, 0, packedBytesLength); - } - /** Returns a reference, in result, to the byte[] slice holding this value */ public void getPackedValueSlice(int index, BytesRef result) { int block = index / valuesPerBlock; int blockIndex = index % valuesPerBlock; result.bytes = blocks.get(block); result.offset = blockIndex * packedBytesLength; - assert result.length == packedBytesLength; + result.length = packedBytesLength; } void writePackedValue(int index, byte[] bytes) { @@ -108,45 +81,76 @@ public final class HeapPointWriter implements PointWriter { System.arraycopy(bytes, 0, blocks.get(block), blockIndex * packedBytesLength, packedBytesLength); } + void writePackedValue(int index, BytesRef bytes) { + assert bytes.length == packedBytesLength; + int block = index / valuesPerBlock; + int blockIndex = index % valuesPerBlock; + //System.out.println("writePackedValue: index=" + index + " bytes.length=" + bytes.length + " block=" + block + " blockIndex=" + blockIndex + " valuesPerBlock=" + valuesPerBlock); + while (blocks.size() <= block) 
{ + // If this is the last block, only allocate as large as necessary for maxSize: + int valuesInBlock = Math.min(valuesPerBlock, maxSize - (blocks.size() * valuesPerBlock)); + blocks.add(new byte[valuesInBlock*packedBytesLength]); + } + System.arraycopy(bytes.bytes, bytes.offset, blocks.get(block), blockIndex * packedBytesLength, packedBytesLength); + } + @Override - public void append(byte[] packedValue, long ord, int docID) { + public void append(byte[] packedValue, int docID) { assert closed == false; assert packedValue.length == packedBytesLength; if (docIDs.length == nextWrite) { int nextSize = Math.min(maxSize, ArrayUtil.oversize(nextWrite+1, Integer.BYTES)); assert nextSize > nextWrite: "nextSize=" + nextSize + " vs nextWrite=" + nextWrite; docIDs = ArrayUtil.growExact(docIDs, nextSize); - if (singleValuePerDoc == false) { - if (ordsLong != null) { - ordsLong = ArrayUtil.growExact(ordsLong, nextSize); - } else { - ords = ArrayUtil.growExact(ords, nextSize); - } - } } writePackedValue(nextWrite, packedValue); - if (singleValuePerDoc == false) { - if (ordsLong != null) { - ordsLong[nextWrite] = ord; - } else { - assert ord <= Integer.MAX_VALUE; - ords[nextWrite] = (int) ord; - } - } docIDs[nextWrite] = docID; nextWrite++; } + @Override + public void append(BytesRef packedValue, int docID) { + assert closed == false; + assert packedValue.length == packedBytesLength; + if (docIDs.length == nextWrite) { + int nextSize = Math.min(maxSize, ArrayUtil.oversize(nextWrite+1, Integer.BYTES)); + assert nextSize > nextWrite: "nextSize=" + nextSize + " vs nextWrite=" + nextWrite; + docIDs = ArrayUtil.growExact(docIDs, nextSize); + } + writePackedValue(nextWrite, packedValue); + docIDs[nextWrite] = docID; + nextWrite++; + } + + public void swap(int i, int j) { + int docID = docIDs[i]; + docIDs[i] = docIDs[j]; + docIDs[j] = docID; + + + byte[] blockI = blocks.get(i / valuesPerBlock); + int indexI = (i % valuesPerBlock) * packedBytesLength; + byte[] blockJ = blocks.get(j / valuesPerBlock); + int indexJ = (j % valuesPerBlock) * packedBytesLength; + + // scratch1 = values[i] + System.arraycopy(blockI, indexI, scratch, 0, packedBytesLength); + // values[i] = values[j] + System.arraycopy(blockJ, indexJ, blockI, indexI, packedBytesLength); + // values[j] = scratch1 + System.arraycopy(scratch, 0, blockJ, indexJ, packedBytesLength); + } + + @Override + public long count() { + return nextWrite; + } + @Override public PointReader getReader(long start, long length) { assert start + length <= docIDs.length: "start=" + start + " length=" + length + " docIDs.length=" + docIDs.length; assert start + length <= nextWrite: "start=" + start + " length=" + length + " nextWrite=" + nextWrite; - return new HeapPointReader(blocks, valuesPerBlock, packedBytesLength, ords, ordsLong, docIDs, (int) start, Math.toIntExact(start+length), singleValuePerDoc); - } - - @Override - public PointReader getSharedReader(long start, long length, List toCloseHeroically) { - return new HeapPointReader(blocks, valuesPerBlock, packedBytesLength, ords, ordsLong, docIDs, (int) start, nextWrite, singleValuePerDoc); + return new HeapPointReader(blocks, valuesPerBlock, packedBytesLength, docIDs, (int) start, Math.toIntExact(start+length)); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java index 2861d593deb..86afc790c62 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java +++ 
b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java @@ -24,44 +24,42 @@ import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.LongBitSet; +import org.apache.lucene.util.BytesRef; -/** Reads points from disk in a fixed-with format, previously written with {@link OfflinePointWriter}. +/** + * Reads points from disk in a fixed-with format, previously written with {@link OfflinePointWriter}. * - * @lucene.internal */ + * @lucene.internal + * */ public final class OfflinePointReader extends PointReader { + long countLeft; final IndexInput in; - private final byte[] packedValue; - final boolean singleValuePerDoc; + byte[] onHeapBuffer; + private int offset; final int bytesPerDoc; - private long ord; - private int docID; - // true if ords are written as long (8 bytes), else 4 bytes - private boolean longOrds; private boolean checked; - + private final int packedValueLength; + private int pointsInBuffer; + private final int maxPointOnHeap; // File name we are reading final String name; - public OfflinePointReader(Directory tempDir, String tempFileName, int packedBytesLength, long start, long length, - boolean longOrds, boolean singleValuePerDoc) throws IOException { - this.singleValuePerDoc = singleValuePerDoc; - int bytesPerDoc = packedBytesLength + Integer.BYTES; - if (singleValuePerDoc == false) { - if (longOrds) { - bytesPerDoc += Long.BYTES; - } else { - bytesPerDoc += Integer.BYTES; - } - } - this.bytesPerDoc = bytesPerDoc; + public OfflinePointReader(Directory tempDir, String tempFileName, int packedBytesLength, long start, long length, byte[] reusableBuffer) throws IOException { + this.bytesPerDoc = packedBytesLength + Integer.BYTES; + this.packedValueLength = packedBytesLength; if ((start + length) * bytesPerDoc + CodecUtil.footerLength() > tempDir.fileLength(tempFileName)) { throw new IllegalArgumentException("requested slice is beyond the length of this file: start=" + start + " length=" + length + " bytesPerDoc=" + bytesPerDoc + " fileLength=" + tempDir.fileLength(tempFileName) + " tempFileName=" + tempFileName); } + if (reusableBuffer == null) { + throw new IllegalArgumentException("[reusableBuffer] cannot be null"); + } + if (reusableBuffer.length < bytesPerDoc) { + throw new IllegalArgumentException("Length of [reusableBuffer] must be bigger than " + bytesPerDoc); + } + this.maxPointOnHeap = reusableBuffer.length / bytesPerDoc; // Best-effort checksumming: if (start == 0 && length*bytesPerDoc == tempDir.fileLength(tempFileName) - CodecUtil.footerLength()) { // If we are going to read the entire file, e.g. 
because BKDWriter is now @@ -74,55 +72,63 @@ public final class OfflinePointReader extends PointReader { // at another level of the BKDWriter recursion in = tempDir.openInput(tempFileName, IOContext.READONCE); } + name = tempFileName; long seekFP = start * bytesPerDoc; in.seek(seekFP); countLeft = length; - packedValue = new byte[packedBytesLength]; - this.longOrds = longOrds; + this.onHeapBuffer = reusableBuffer; } @Override public boolean next() throws IOException { - if (countLeft >= 0) { - if (countLeft == 0) { + if (this.pointsInBuffer == 0) { + if (countLeft >= 0) { + if (countLeft == 0) { + return false; + } + } + try { + if (countLeft > maxPointOnHeap) { + in.readBytes(onHeapBuffer, 0, maxPointOnHeap * bytesPerDoc); + pointsInBuffer = maxPointOnHeap - 1; + countLeft -= maxPointOnHeap; + } else { + in.readBytes(onHeapBuffer, 0, (int) countLeft * bytesPerDoc); + pointsInBuffer = Math.toIntExact(countLeft - 1); + countLeft = 0; + } + this.offset = 0; + } catch (EOFException eofe) { + assert countLeft == -1; return false; } - countLeft--; - } - try { - in.readBytes(packedValue, 0, packedValue.length); - } catch (EOFException eofe) { - assert countLeft == -1; - return false; - } - docID = in.readInt(); - if (singleValuePerDoc == false) { - if (longOrds) { - ord = in.readLong(); - } else { - ord = in.readInt(); - } } else { - ord = docID; + this.pointsInBuffer--; + this.offset += bytesPerDoc; } return true; } @Override - public byte[] packedValue() { - return packedValue; + public void packedValue(BytesRef bytesRef) { + bytesRef.bytes = onHeapBuffer; + bytesRef.offset = offset; + bytesRef.length = packedValueLength; } - @Override - public long ord() { - return ord; + protected void packedValueWithDocId(BytesRef bytesRef) { + bytesRef.bytes = onHeapBuffer; + bytesRef.offset = offset; + bytesRef.length = bytesPerDoc; } @Override public int docID() { - return docID; + int position = this.offset + packedValueLength; + return ((onHeapBuffer[position++] & 0xFF) << 24) | ((onHeapBuffer[position++] & 0xFF) << 16) + | ((onHeapBuffer[position++] & 0xFF) << 8) | (onHeapBuffer[position++] & 0xFF); } @Override @@ -137,112 +143,5 @@ public final class OfflinePointReader extends PointReader { in.close(); } } - - @Override - public void markOrds(long count, LongBitSet ordBitSet) throws IOException { - if (countLeft < count) { - throw new IllegalStateException("only " + countLeft + " points remain, but " + count + " were requested"); - } - long fp = in.getFilePointer() + packedValue.length; - if (singleValuePerDoc == false) { - fp += Integer.BYTES; - } - for(long i=0;i offline split since the default impl - // is somewhat wasteful otherwise (e.g. 
decoding docID when we don't - // need to) - - int packedBytesLength = packedValue.length; - - int bytesPerDoc = packedBytesLength + Integer.BYTES; - if (singleValuePerDoc == false) { - if (longOrds) { - bytesPerDoc += Long.BYTES; - } else { - bytesPerDoc += Integer.BYTES; - } - } - - long rightCount = 0; - - IndexOutput rightOut = ((OfflinePointWriter) right).out; - IndexOutput leftOut = ((OfflinePointWriter) left).out; - - assert count <= countLeft: "count=" + count + " countLeft=" + countLeft; - - countLeft -= count; - - long countStart = count; - - byte[] buffer = new byte[bytesPerDoc]; - while (count > 0) { - in.readBytes(buffer, 0, buffer.length); - - long ord; - if (longOrds) { - // A long ord, after the docID: - ord = readLong(buffer, packedBytesLength+Integer.BYTES); - } else if (singleValuePerDoc) { - // docID is the ord: - ord = readInt(buffer, packedBytesLength); - } else { - // An int ord, after the docID: - ord = readInt(buffer, packedBytesLength+Integer.BYTES); - } - - if (rightTree.get(ord)) { - rightOut.writeBytes(buffer, 0, bytesPerDoc); - if (doClearBits) { - rightTree.clear(ord); - } - rightCount++; - } else { - leftOut.writeBytes(buffer, 0, bytesPerDoc); - } - - count--; - } - - ((OfflinePointWriter) right).count = rightCount; - ((OfflinePointWriter) left).count = countStart-rightCount; - - return rightCount; - } - - // Poached from ByteArrayDataInput: - private static long readLong(byte[] bytes, int pos) { - final int i1 = ((bytes[pos++] & 0xff) << 24) | ((bytes[pos++] & 0xff) << 16) | - ((bytes[pos++] & 0xff) << 8) | (bytes[pos++] & 0xff); - final int i2 = ((bytes[pos++] & 0xff) << 24) | ((bytes[pos++] & 0xff) << 16) | - ((bytes[pos++] & 0xff) << 8) | (bytes[pos++] & 0xff); - return (((long)i1) << 32) | (i2 & 0xFFFFFFFFL); - } - - // Poached from ByteArrayDataInput: - private static int readInt(byte[] bytes, int pos) { - return ((bytes[pos++] & 0xFF) << 24) | ((bytes[pos++] & 0xFF) << 16) - | ((bytes[pos++] & 0xFF) << 8) | (bytes[pos++] & 0xFF); - } } diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java index 7e615a657ec..5479b531dbe 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java @@ -16,104 +16,79 @@ */ package org.apache.lucene.util.bkd; -import java.io.Closeable; import java.io.IOException; -import java.util.List; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.BytesRef; -/** Writes points to disk in a fixed-with format. +/** + * Writes points to disk in a fixed-with format. 
* - * @lucene.internal */ + * @lucene.internal + * */ public final class OfflinePointWriter implements PointWriter { final Directory tempDir; public final IndexOutput out; public final String name; final int packedBytesLength; - final boolean singleValuePerDoc; long count; private boolean closed; - // true if ords are written as long (8 bytes), else 4 bytes - private boolean longOrds; - private OfflinePointReader sharedReader; - private long nextSharedRead; final long expectedCount; /** Create a new writer with an unknown number of incoming points */ public OfflinePointWriter(Directory tempDir, String tempFileNamePrefix, int packedBytesLength, - boolean longOrds, String desc, long expectedCount, boolean singleValuePerDoc) throws IOException { + String desc, long expectedCount) throws IOException { this.out = tempDir.createTempOutput(tempFileNamePrefix, "bkd_" + desc, IOContext.DEFAULT); this.name = out.getName(); this.tempDir = tempDir; this.packedBytesLength = packedBytesLength; - this.longOrds = longOrds; - this.singleValuePerDoc = singleValuePerDoc; - this.expectedCount = expectedCount; - } - /** Initializes on an already written/closed file, just so consumers can use {@link #getReader} to read the file. */ - public OfflinePointWriter(Directory tempDir, String name, int packedBytesLength, long count, boolean longOrds, boolean singleValuePerDoc) { - this.out = null; - this.name = name; - this.tempDir = tempDir; - this.packedBytesLength = packedBytesLength; - this.count = count; - closed = true; - this.longOrds = longOrds; - this.singleValuePerDoc = singleValuePerDoc; - this.expectedCount = 0; + this.expectedCount = expectedCount; } @Override - public void append(byte[] packedValue, long ord, int docID) throws IOException { + public void append(byte[] packedValue, int docID) throws IOException { assert packedValue.length == packedBytesLength; out.writeBytes(packedValue, 0, packedValue.length); out.writeInt(docID); - if (singleValuePerDoc == false) { - if (longOrds) { - out.writeLong(ord); - } else { - assert ord <= Integer.MAX_VALUE; - out.writeInt((int) ord); - } - } + count++; + assert expectedCount == 0 || count <= expectedCount; + } + + @Override + public void append(BytesRef packedValue, int docID) throws IOException { + assert packedValue.length == packedBytesLength; + out.writeBytes(packedValue.bytes, packedValue.offset, packedValue.length); + out.writeInt(docID); count++; assert expectedCount == 0 || count <= expectedCount; } @Override public PointReader getReader(long start, long length) throws IOException { + byte[] buffer = new byte[packedBytesLength + Integer.BYTES]; + return getReader(start, length, buffer); + } + + protected OfflinePointReader getReader(long start, long length, byte[] reusableBuffer) throws IOException { assert closed; assert start + length <= count: "start=" + start + " length=" + length + " count=" + count; assert expectedCount == 0 || count == expectedCount; - return new OfflinePointReader(tempDir, name, packedBytesLength, start, length, longOrds, singleValuePerDoc); + return new OfflinePointReader(tempDir, name, packedBytesLength, start, length, reusableBuffer); } @Override - public PointReader getSharedReader(long start, long length, List toCloseHeroically) throws IOException { - if (sharedReader == null) { - assert start == 0; - assert length <= count; - sharedReader = new OfflinePointReader(tempDir, name, packedBytesLength, 0, count, longOrds, singleValuePerDoc); - toCloseHeroically.add(sharedReader); - // Make sure the OfflinePointReader intends to 
verify its checksum: - assert sharedReader.in instanceof ChecksumIndexInput; - } else { - assert start == nextSharedRead: "start=" + start + " length=" + length + " nextSharedRead=" + nextSharedRead; - } - nextSharedRead += length; - return sharedReader; + public long count() { + return count; } @Override public void close() throws IOException { if (closed == false) { - assert sharedReader == null; try { CodecUtil.writeFooter(out); } finally { @@ -125,12 +100,6 @@ public final class OfflinePointWriter implements PointWriter { @Override public void destroy() throws IOException { - if (sharedReader != null) { - // At this point, the shared reader should have done a full sweep of the file: - assert nextSharedRead == count; - sharedReader.close(); - sharedReader = null; - } tempDir.deleteFile(name); } diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java index 0c31275d02b..c0eaff880de 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java @@ -20,62 +20,24 @@ package org.apache.lucene.util.bkd; import java.io.Closeable; import java.io.IOException; -import org.apache.lucene.util.LongBitSet; +import org.apache.lucene.util.BytesRef; /** One pass iterator through all points previously written with a - * {@link PointWriter}, abstracting away whether points a read + * {@link PointWriter}, abstracting away whether points are read * from (offline) disk or simple arrays in heap. * - * @lucene.internal */ + * @lucene.internal + * */ public abstract class PointReader implements Closeable { /** Returns false once iteration is done, else true. */ public abstract boolean next() throws IOException; - /** Returns the packed byte[] value */ - public abstract byte[] packedValue(); - - /** Point ordinal */ - public abstract long ord(); + /** Sets the packed value in the provided ByteRef */ + public abstract void packedValue(BytesRef bytesRef); /** DocID for this point */ public abstract int docID(); - /** Iterates through the next {@code count} ords, marking them in the provided {@code ordBitSet}. 
*/ - public void markOrds(long count, LongBitSet ordBitSet) throws IOException { - for(int i=0;i toCloseHeroically) throws IOException; + /** Return the number of points in this writer */ + long count(); /** Removes any temp files behind this writer */ void destroy() throws IOException; + } diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java b/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java index 0d57bf83f51..570f95a1386 100644 --- a/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java +++ b/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java @@ -42,7 +42,7 @@ public class Test2BBKDPoints extends LuceneTestCase { final int numDocs = (Integer.MAX_VALUE / 26) + 100; BKDWriter w = new BKDWriter(numDocs, dir, "_0", 1, 1, Long.BYTES, - BKDWriter.DEFAULT_MAX_POINTS_IN_LEAF_NODE, BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, 26L * numDocs, false); + BKDWriter.DEFAULT_MAX_POINTS_IN_LEAF_NODE, BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, 26L * numDocs); int counter = 0; byte[] packedBytes = new byte[Long.BYTES]; for (int docID = 0; docID < numDocs; docID++) { @@ -79,7 +79,7 @@ public class Test2BBKDPoints extends LuceneTestCase { final int numDocs = (Integer.MAX_VALUE / 26) + 100; BKDWriter w = new BKDWriter(numDocs, dir, "_0", 2, 2, Long.BYTES, - BKDWriter.DEFAULT_MAX_POINTS_IN_LEAF_NODE, BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, 26L * numDocs, false); + BKDWriter.DEFAULT_MAX_POINTS_IN_LEAF_NODE, BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, 26L * numDocs); int counter = 0; byte[] packedBytes = new byte[2*Long.BYTES]; for (int docID = 0; docID < numDocs; docID++) { diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java index a01c92714cb..01d05a0f0d5 100644 --- a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java +++ b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java @@ -48,7 +48,7 @@ public class TestBKD extends LuceneTestCase { public void testBasicInts1D() throws Exception { try (Directory dir = getDirectory(100)) { - BKDWriter w = new BKDWriter(100, dir, "tmp", 1, 1, 4, 2, 1.0f, 100, true); + BKDWriter w = new BKDWriter(100, dir, "tmp", 1, 1, 4, 2, 1.0f, 100); byte[] scratch = new byte[4]; for(int docID=0;docID<100;docID++) { NumericUtils.intToSortableBytes(docID, scratch, 0); @@ -124,7 +124,7 @@ public class TestBKD extends LuceneTestCase { int numIndexDims = TestUtil.nextInt(random(), 1, numDims); int maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 100); float maxMB = (float) 3.0 + (3*random().nextFloat()); - BKDWriter w = new BKDWriter(numDocs, dir, "tmp", numDims, numIndexDims, 4, maxPointsInLeafNode, maxMB, numDocs, true); + BKDWriter w = new BKDWriter(numDocs, dir, "tmp", numDims, numIndexDims, 4, maxPointsInLeafNode, maxMB, numDocs); if (VERBOSE) { System.out.println("TEST: numDims=" + numDims + " numIndexDims=" + numIndexDims + " numDocs=" + numDocs); @@ -265,7 +265,7 @@ public class TestBKD extends LuceneTestCase { int numDims = TestUtil.nextInt(random(), 1, 5); int maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 100); float maxMB = (float) 3.0 + (3*random().nextFloat()); - BKDWriter w = new BKDWriter(numDocs, dir, "tmp", numDims, numDims, numBytesPerDim, maxPointsInLeafNode, maxMB, numDocs, true); + BKDWriter w = new BKDWriter(numDocs, dir, "tmp", numDims, numDims, numBytesPerDim, maxPointsInLeafNode, maxMB, numDocs); BigInteger[][] docs = new BigInteger[numDocs][]; byte[] scratch = new 
byte[numBytesPerDim*numDims]; @@ -441,7 +441,7 @@ public class TestBKD extends LuceneTestCase { public void testTooLittleHeap() throws Exception { try (Directory dir = getDirectory(0)) { IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> { - new BKDWriter(1, dir, "bkd", 1, 1, 16, 1000000, 0.001, 0, true); + new BKDWriter(1, dir, "bkd", 1, 1, 16, 1000000, 0.001, 0); }); assertTrue(expected.getMessage().contains("either increase maxMBSortInHeap or decrease maxPointsInLeafNode")); } @@ -668,7 +668,7 @@ public class TestBKD extends LuceneTestCase { List docMaps = null; int seg = 0; - BKDWriter w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length, false); + BKDWriter w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length); IndexOutput out = dir.createOutput("bkd", IOContext.DEFAULT); IndexInput in = null; @@ -728,7 +728,7 @@ public class TestBKD extends LuceneTestCase { seg++; maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 1000); maxMB = (float) 3.0 + (3*random().nextDouble()); - w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length, false); + w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length); lastDocIDBase = docID; } } @@ -749,7 +749,7 @@ public class TestBKD extends LuceneTestCase { out.close(); in = dir.openInput("bkd", IOContext.DEFAULT); seg++; - w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length, false); + w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length); List readers = new ArrayList<>(); for(long fp : toMerge) { in.seek(fp); @@ -924,7 +924,7 @@ public class TestBKD extends LuceneTestCase { @Override public IndexOutput createTempOutput(String prefix, String suffix, IOContext context) throws IOException { IndexOutput out = in.createTempOutput(prefix, suffix, context); - if (corrupted == false && prefix.equals("_0_bkd1") && suffix.equals("sort")) { + if (corrupted == false && prefix.equals("_0") && suffix.equals("bkd_left0")) { corrupted = true; return new CorruptingIndexOutput(dir0, 22, out); } else { @@ -1008,7 +1008,7 @@ public class TestBKD extends LuceneTestCase { public void testTieBreakOrder() throws Exception { try (Directory dir = newDirectory()) { int numDocs = 10000; - BKDWriter w = new BKDWriter(numDocs+1, dir, "tmp", 1, 1, Integer.BYTES, 2, 0.01f, numDocs, true); + BKDWriter w = new BKDWriter(numDocs+1, dir, "tmp", 1, 1, Integer.BYTES, 2, 0.01f, numDocs); for(int i=0;i= maxDocID); + } + assertTrue(Arrays.equals(partitionPoint, min)); + leftPointWriter.destroy(); + rightPointWriter.destroy(); + } + points.destroy(); + } + + private PointWriter copyPoints(Directory dir, PointWriter points, int packedLength) throws IOException { + BytesRef bytesRef = new BytesRef(); + + try (PointWriter copy = getRandomPointWriter(dir, points.count(), packedLength); + PointReader reader = points.getReader(0, points.count())) { + while (reader.next()) { + reader.packedValue(bytesRef); + copy.append(bytesRef, reader.docID()); + } + return copy; + } + } + + private PointWriter getRandomPointWriter(Directory dir, long numPoints, int packedBytesLength) throws IOException { + if 
(numPoints < 4096 && random().nextBoolean()) { + return new HeapPointWriter(Math.toIntExact(numPoints), Math.toIntExact(numPoints), packedBytesLength); + } else { + return new OfflinePointWriter(dir, "test", packedBytesLength, "data", numPoints); + } + } + + private Directory getDirectory(int numPoints) { + Directory dir; + if (numPoints > 100000) { + dir = newFSDirectory(createTempDir("TestBKDTRadixSelector")); + } else { + dir = newDirectory(); + } + return dir; + } + + private byte[] getMin(PointWriter p, long size, int bytesPerDimension, int dimension) throws IOException { + byte[] min = new byte[bytesPerDimension]; + Arrays.fill(min, (byte) 0xff); + try (PointReader reader = p.getReader(0, size)) { + byte[] value = new byte[bytesPerDimension]; + BytesRef packedValue = new BytesRef(); + while (reader.next()) { + reader.packedValue(packedValue); + System.arraycopy(packedValue.bytes, packedValue.offset + dimension * bytesPerDimension, value, 0, bytesPerDimension); + if (FutureArrays.compareUnsigned(min, 0, bytesPerDimension, value, 0, bytesPerDimension) > 0) { + System.arraycopy(value, 0, min, 0, bytesPerDimension); + } + } + } + return min; + } + + private int getMinDocId(PointWriter p, long size, int bytesPerDimension, int dimension, byte[] partitionPoint) throws IOException { + int docID = Integer.MAX_VALUE; + try (PointReader reader = p.getReader(0, size)) { + BytesRef packedValue = new BytesRef(); + while (reader.next()) { + reader.packedValue(packedValue); + int offset = dimension * bytesPerDimension; + if (FutureArrays.compareUnsigned(packedValue.bytes, packedValue.offset + offset, packedValue.offset + offset + bytesPerDimension, partitionPoint, 0, bytesPerDimension) == 0) { + int newDocID = reader.docID(); + if (newDocID < docID) { + docID = newDocID; + } + } + } + } + return docID; + } + + private byte[] getMax(PointWriter p, long size, int bytesPerDimension, int dimension) throws IOException { + byte[] max = new byte[bytesPerDimension]; + Arrays.fill(max, (byte) 0); + try (PointReader reader = p.getReader(0, size)) { + byte[] value = new byte[bytesPerDimension]; + BytesRef packedValue = new BytesRef(); + while (reader.next()) { + reader.packedValue(packedValue); + System.arraycopy(packedValue.bytes, packedValue.offset + dimension * bytesPerDimension, value, 0, bytesPerDimension); + if (FutureArrays.compareUnsigned(max, 0, bytesPerDimension, value, 0, bytesPerDimension) < 0) { + System.arraycopy(value, 0, max, 0, bytesPerDimension); + } + } + } + return max; + } + + private int getMaxDocId(PointWriter p, long size, int bytesPerDimension, int dimension, byte[] partitionPoint) throws IOException { + int docID = Integer.MIN_VALUE; + try (PointReader reader = p.getReader(0, size)) { + BytesRef packedValue = new BytesRef(); + while (reader.next()) { + reader.packedValue(packedValue); + int offset = dimension * bytesPerDimension; + if (FutureArrays.compareUnsigned(packedValue.bytes, packedValue.offset + offset, packedValue.offset + offset + bytesPerDimension, partitionPoint, 0, bytesPerDimension) == 0) { + int newDocID = reader.docID(); + if (newDocID > docID) { + docID = newDocID; + } + } + } + } + return docID; + } + +} diff --git a/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java b/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java index b5a728ba598..ec3c323acc4 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java +++ b/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java @@ -106,7 +106,6 
@@ public class RandomCodec extends AssertingCodec { public void writeField(FieldInfo fieldInfo, PointsReader reader) throws IOException { PointValues values = reader.getValues(fieldInfo.name); - boolean singleValuePerDoc = values.size() == values.getDocCount(); try (BKDWriter writer = new RandomlySplittingBKDWriter(writeState.segmentInfo.maxDoc(), writeState.directory, @@ -117,7 +116,6 @@ public class RandomCodec extends AssertingCodec { maxPointsInLeafNode, maxMBSortInHeap, values.size(), - singleValuePerDoc, bkdSplitRandomSeed ^ fieldInfo.name.hashCode())) { values.intersect(new IntersectVisitor() { @Override @@ -262,12 +260,8 @@ public class RandomCodec extends AssertingCodec { public RandomlySplittingBKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim, int maxPointsInLeafNode, double maxMBSortInHeap, - long totalPointCount, boolean singleValuePerDoc, int randomSeed) throws IOException { - super(maxDoc, tempDir, tempFileNamePrefix, numDataDims, numIndexDims, bytesPerDim, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount, - getRandomSingleValuePerDoc(singleValuePerDoc, randomSeed), - getRandomLongOrds(totalPointCount, singleValuePerDoc, randomSeed), - getRandomOfflineSorterBufferMB(randomSeed), - getRandomOfflineSorterMaxTempFiles(randomSeed)); + long totalPointCount, int randomSeed) throws IOException { + super(maxDoc, tempDir, tempFileNamePrefix, numDataDims, numIndexDims, bytesPerDim, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount); this.random = new Random(randomSeed); }
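
The HeapPointWriter changes above keep packed values in fixed-size byte[] pages (roughly 4 KB each) and add a swap(i, j) that goes through a shared scratch buffer, which is what lets a radix selector reorder points directly on the heap. The following is a minimal, self-contained sketch of that page-addressing arithmetic; the class and field names are illustrative, not the Lucene API.

import java.util.ArrayList;
import java.util.List;

/**
 * Illustrative sketch of block-based value storage as used by the in-heap point
 * writer: packed values live in ~4KB byte[] pages so they can be addressed (and
 * swapped) by index without one huge contiguous allocation.
 */
class BlockedValues {
  final int packedBytesLength;
  final int valuesPerBlock;          // 4K per page, unless a single value is larger
  final List<byte[]> blocks = new ArrayList<>();
  private final byte[] scratch;

  BlockedValues(int packedBytesLength) {
    this.packedBytesLength = packedBytesLength;
    this.valuesPerBlock = Math.max(1, 4096 / packedBytesLength);
    this.scratch = new byte[packedBytesLength];
  }

  void write(int index, byte[] value) {
    int block = index / valuesPerBlock;
    int blockIndex = index % valuesPerBlock;
    while (blocks.size() <= block) {
      blocks.add(new byte[valuesPerBlock * packedBytesLength]);
    }
    System.arraycopy(value, 0, blocks.get(block), blockIndex * packedBytesLength, packedBytesLength);
  }

  /** Swap the values at slots i and j via the shared scratch buffer. */
  void swap(int i, int j) {
    byte[] blockI = blocks.get(i / valuesPerBlock);
    byte[] blockJ = blocks.get(j / valuesPerBlock);
    int offsetI = (i % valuesPerBlock) * packedBytesLength;
    int offsetJ = (j % valuesPerBlock) * packedBytesLength;
    System.arraycopy(blockI, offsetI, scratch, 0, packedBytesLength);      // scratch = values[i]
    System.arraycopy(blockJ, offsetJ, blockI, offsetI, packedBytesLength); // values[i] = values[j]
    System.arraycopy(scratch, 0, blockJ, offsetJ, packedBytesLength);      // values[j] = scratch
  }
}

Storing values in pages rather than one contiguous array keeps allocations bounded while still giving the random write access that in-place partitioning needs.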
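
With ords removed, the offline format written by OfflinePointWriter and read back by OfflinePointReader is simply the packed value followed by a 4-byte big-endian docID, and the reader now fills a caller-supplied buffer with many such records per readBytes call, decoding the docID lazily. A small sketch of that record layout and of the decode shown in the patch; the helper names are made up for illustration.

/**
 * Sketch of the fixed-width record layout the offline writer and reader agree on
 * after this change: [packedValue (numDataDims * bytesPerDim bytes)][docID (4 bytes, big-endian)].
 */
class PointRecord {
  static int bytesPerDoc(int packedBytesLength) {
    return packedBytesLength + Integer.BYTES;
  }

  /** Decode the big-endian docID that follows the packed value at recordOffset. */
  static int docID(byte[] buffer, int recordOffset, int packedBytesLength) {
    int p = recordOffset + packedBytesLength;
    return ((buffer[p] & 0xFF) << 24)
        | ((buffer[p + 1] & 0xFF) << 16)
        | ((buffer[p + 2] & 0xFF) << 8)
        | (buffer[p + 3] & 0xFF);
  }

  public static void main(String[] args) {
    int packedBytesLength = 8;                        // e.g. one long dimension
    byte[] buffer = new byte[bytesPerDoc(packedBytesLength)];
    buffer[packedBytesLength + 3] = 42;               // docID 42, big-endian, after the value bytes
    System.out.println(docID(buffer, 0, packedBytesLength)); // prints 42
  }
}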
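
PointReader.packedValue(BytesRef) and PointWriter.append(BytesRef, int) let consumers move points between writers without copying each value into a fresh byte[], which is what copyPoints in TestBKDRadixSelector does. Below is a rough sketch of those slimmed-down contracts and of such a copy loop, assuming hypothetical ValueRef and Sketch* types in place of the real BytesRef-based interfaces.

import java.io.Closeable;
import java.io.IOException;

/** Reusable value holder, standing in for org.apache.lucene.util.BytesRef in this sketch. */
final class ValueRef {
  byte[] bytes;
  int offset;
  int length;
}

interface SketchPointReader extends Closeable {
  boolean next() throws IOException;
  void packedValue(ValueRef ref);   // fills ref with a slice of an internal buffer
  int docID();
}

interface SketchPointWriter extends Closeable {
  void append(ValueRef packedValue, int docID) throws IOException;
  long count();
}

class CopyPoints {
  /** Copy every point from reader to writer without materialising per-point byte[]s. */
  static void copy(SketchPointReader reader, SketchPointWriter writer) throws IOException {
    ValueRef ref = new ValueRef();
    while (reader.next()) {
      reader.packedValue(ref);            // no allocation: ref points into the reader's buffer
      writer.append(ref, reader.docID());
    }
  }
}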
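
The heart of LUCENE-8673 is replacing the up-front offline sort of every dimension with radix partitioning: at each BKD split only the value at the split position and a left/right partition of the points are needed, as exercised by TestBKDRadixSelector above. The sketch below shows the idea on an in-memory array of fixed-width keys; it is not the BKDRadixSelector API, just an MSB radix select written against plain arrays.

import java.util.Arrays;

/**
 * Illustrative sketch: select the k-th smallest fixed-width key by MSB radix
 * partitioning. Each pass histograms one byte, keeps only the bucket containing
 * the k-th entry, and recurses, so most points are partitioned but never sorted.
 */
public class RadixSelectSketch {

  /** Returns the k-th smallest (0-based) key among values[from..to), comparing bytes unsigned. */
  static byte[] select(byte[][] values, int from, int to, int k, int byteAt, int bytesPerKey) {
    if (byteAt == bytesPerKey || to - from <= 1) {
      // All processed bytes are equal (or one candidate is left): this entry is the answer.
      return values[k];
    }
    // Histogram the current byte.
    int[] counts = new int[256];
    for (int i = from; i < to; i++) {
      counts[values[i][byteAt] & 0xFF]++;
    }
    // Find the bucket that contains position k.
    int bucket = 0, start = from;
    while (start + counts[bucket] <= k) {
      start += counts[bucket++];
    }
    // Three-way partition: smaller buckets before, the chosen bucket in the middle, larger after.
    int lt = from, gt = to - 1, i = from;
    while (i <= gt) {
      int b = values[i][byteAt] & 0xFF;
      if (b < bucket) swap(values, lt++, i++);
      else if (b > bucket) swap(values, i, gt--);
      else i++;
    }
    // Recurse into the chosen bucket only; everything outside it is already on the correct side.
    return select(values, lt, gt + 1, k, byteAt + 1, bytesPerKey);
  }

  private static void swap(byte[][] values, int i, int j) {
    byte[] tmp = values[i]; values[i] = values[j]; values[j] = tmp;
  }

  public static void main(String[] args) {
    byte[][] points = { {0x30}, {0x10}, {0x7F}, {0x10}, {0x05} };
    byte[] median = select(points, 0, points.length, points.length / 2, 0, 1);
    System.out.println(Arrays.toString(median)); // prints [16] (0x10), the median of the five keys
  }
}

Because each pass keeps only the bucket that contains the target position, the expected cost is linear in the number of points, instead of the N log N of sorting every dimension before recursing.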
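
The test helpers getMin/getMax above scan the points and compare one dimension's byte slice unsigned to find the expected partition value. A compact version of that comparison, using java.util.Arrays.compareUnsigned (Java 9+) where the patch itself uses Lucene's FutureArrays backport:

import java.util.Arrays;

/**
 * Sketch: per-dimension minimum over packed values. Each value is
 * numDims * bytesPerDim bytes; the requested dimension's slice is compared as
 * unsigned bytes, matching the BKD ordering.
 */
class DimMin {
  static byte[] minForDim(byte[][] packedValues, int bytesPerDim, int dim) {
    byte[] min = new byte[bytesPerDim];
    Arrays.fill(min, (byte) 0xff);
    int offset = dim * bytesPerDim;
    for (byte[] value : packedValues) {
      if (Arrays.compareUnsigned(value, offset, offset + bytesPerDim, min, 0, bytesPerDim) < 0) {
        System.arraycopy(value, offset, min, 0, bytesPerDim);
      }
    }
    return min;
  }

  public static void main(String[] args) {
    byte[][] values = { {0x01, (byte) 0x90}, {0x01, 0x10}, {0x02, 0x00} };
    System.out.println(Arrays.toString(minForDim(values, 1, 1))); // dimension 1 minimum -> [0]
  }
}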