From 3ef6e015dd127b356e4163eb48411be94fb2d50c Mon Sep 17 00:00:00 2001
From: iverase
Date: Wed, 20 Feb 2019 12:14:58 +0100
Subject: [PATCH] LUCENE-8699: Change HeapPointWriter to use a single byte array instead of a list of byte arrays. In addition, a new interface PointValue is added to abstract out the different formats between offline and on-heap writers.

---
 lucene/CHANGES.txt                              |   4 +
 .../simpletext/SimpleTextBKDWriter.java         |  69 ++--
 .../lucene/util/bkd/BKDRadixSelector.java       | 338 ++++++++++++------
 .../org/apache/lucene/util/bkd/BKDWriter.java   |  62 +---
 .../lucene/util/bkd/HeapPointReader.java        |  70 +++-
 .../lucene/util/bkd/HeapPointWriter.java        | 124 ++-----
 .../lucene/util/bkd/OfflinePointReader.java     |  66 +++-
 .../lucene/util/bkd/OfflinePointWriter.java     |  22 +-
 .../apache/lucene/util/bkd/PointReader.java     |  11 +-
 .../apache/lucene/util/bkd/PointValue.java      |  36 ++
 .../apache/lucene/util/bkd/PointWriter.java     |   9 +-
 .../lucene/util/bkd/TestBKDRadixSelector.java   |  76 ++--
 12 files changed, 508 insertions(+), 379 deletions(-)
 create mode 100644 lucene/core/src/java/org/apache/lucene/util/bkd/PointValue.java

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 01849470bed..a339e04d142 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -24,6 +24,10 @@ Improvements

 * LUCENE-8687: Optimise radix partitioning for points on heap. (Ignacio Vera)

+* LUCENE-8699: Change HeapPointWriter to use a single byte array instead of a list
+  of byte arrays. In addition, a new interface PointValue is added to abstract out
+  the different formats between offline and on-heap writers. (Ignacio Vera)
+
 Other

 * LUCENE-8680: Refactor EdgeTree#relateTriangle method. (Ignacio Vera)

diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java
index 84642842b70..9f4938783d7 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java
@@ -38,7 +38,6 @@ import org.apache.lucene.util.BytesRefBuilder;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.FutureArrays;
 import org.apache.lucene.util.IOUtils;
-import org.apache.lucene.util.MSBRadixSorter;
 import org.apache.lucene.util.NumericUtils;
 import org.apache.lucene.util.bkd.BKDRadixSelector;
 import org.apache.lucene.util.bkd.BKDWriter;
@@ -46,6 +45,7 @@ import org.apache.lucene.util.bkd.HeapPointWriter;
 import org.apache.lucene.util.bkd.MutablePointsReaderUtils;
 import org.apache.lucene.util.bkd.OfflinePointWriter;
 import org.apache.lucene.util.bkd.PointReader;
+import org.apache.lucene.util.bkd.PointValue;
 import org.apache.lucene.util.bkd.PointWriter;

 import static org.apache.lucene.codecs.simpletext.SimpleTextPointsWriter.BLOCK_COUNT;
@@ -188,7 +188,7 @@ final class SimpleTextBKDWriter implements Closeable {
     }

     // We write first maxPointsSortInHeap in heap, then cutover to offline for additional points:
-    heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength);
+    heapPointWriter = new HeapPointWriter(maxPointsSortInHeap, packedBytesLength);

     this.maxMBSortInHeap = maxMBSortInHeap;
   }
@@ -226,8 +226,7 @@ final class SimpleTextBKDWriter implements Closeable {
     for(int i=0;i= maxPointsSortInHeap) {
       if (offlinePointWriter == null) {
         spillToOffline();
       }
@@ -565,37 +568,6 @@ final class SimpleTextBKDWriter implements Closeable {
     }
   }

-  // TODO: if we fixed each partition step to 
just record the file offset at the "split point", we could probably handle variable length
-  // encoding and not have our own ByteSequencesReader/Writer
-
-  /** Sort the heap writer by the specified dim */
-  private void sortHeapPointWriter(final HeapPointWriter writer, int from, int to, int dim, int commonPrefixLength) {
-    // Tie-break by docID:
-    new MSBRadixSorter(bytesPerDim + Integer.BYTES - commonPrefixLength) {
-
-      @Override
-      protected int byteAt(int i, int k) {
-        assert k >= 0;
-        if (k + commonPrefixLength < bytesPerDim) {
-          // dim bytes
-          int block = i / writer.valuesPerBlock;
-          int index = i % writer.valuesPerBlock;
-          return writer.blocks.get(block)[index * packedBytesLength + dim * bytesPerDim + k + commonPrefixLength] & 0xff;
-        } else {
-          // doc id
-          int s = 3 - (k + commonPrefixLength - bytesPerDim);
-          return (writer.docIDs[i] >>> (s * 8)) & 0xff;
-        }
-      }
-
-      @Override
-      protected void swap(int i, int j) {
-        writer.swap(i, j);
-      }
-
-    }.sort(from, to);
-  }
-
   private void checkMaxLeafNodeCount(int numLeaves) {
     if ((1+bytesPerDim) * (long) numLeaves > ArrayUtil.MAX_ARRAY_LENGTH) {
       throw new IllegalStateException("too many nodes; increase maxPointsInLeafNode (currently " + maxPointsInLeafNode + ") and reindex");
@@ -864,12 +836,11 @@ final class SimpleTextBKDWriter implements Closeable {

     // Not inside the try because we don't want to close it here:
     try (PointReader reader = source.getReader(0, count);
-        HeapPointWriter writer = new HeapPointWriter(count, count, packedBytesLength)) {
+        HeapPointWriter writer = new HeapPointWriter(count, packedBytesLength)) {
       for(int i=0;i 1;
+    assert partitionSlices.length > 1 : "[partition slices] must be > 1, got " + partitionSlices.length;

     //If we are on heap then we just select on heap
     if (points.writer instanceof HeapPointWriter) {
@@ -104,26 +107,13 @@ public final class BKDRadixSelector {
       return partition;
     }

-    //reset histogram
-    for (int i = 0; i < bytesSorted; i++) {
-      Arrays.fill(histogram[i], 0);
-    }
     OfflinePointWriter offlinePointWriter = (OfflinePointWriter) points.writer;
-    //find common prefix from dimCommonPrefix, it does already set histogram values if needed
-    int commonPrefix = findCommonPrefix(offlinePointWriter, from, to, dim, dimCommonPrefix);

     try (PointWriter left = getPointWriter(partitionPoint - from, "left" + dim);
          PointWriter right = getPointWriter(to - partitionPoint, "right" + dim)) {
       partitionSlices[0] = new PathSlice(left, 0, partitionPoint - from);
       partitionSlices[1] = new PathSlice(right, 0, to - partitionPoint);
-      //if all equals we just partition the points
-      if (commonPrefix == bytesSorted) {
-        offlinePartition(offlinePointWriter, left, right, null, from, to, dim, commonPrefix - 1, partitionPoint);
-        return partitionPointFromCommonPrefix();
-      }
-      //let's rock'n'roll
-      return buildHistogramAndPartition(offlinePointWriter, left, right, from, to, partitionPoint, 0, commonPrefix, dim);
+      return buildHistogramAndPartition(offlinePointWriter, left, right, from, to, partitionPoint, 0, dimCommonPrefix, dim);
     }
   }
@@ -136,70 +126,98 @@ public final class BKDRadixSelector {
     }
   }

-  private int findCommonPrefix(OfflinePointWriter points, long from, long to, int dim, int dimCommonPrefix) throws IOException{
+  private int findCommonPrefixAndHistogram(OfflinePointWriter points, long from, long to, int dim, int dimCommonPrefix) throws IOException {
     //find common prefix
-    byte[] commonPrefix = new byte[bytesSorted];
     int commonPrefixPosition = bytesSorted;
+    final int offset = dim * bytesPerDim;
     try (OfflinePointReader reader = 
points.getReader(from, to - from, offlineBuffer)) { assert commonPrefixPosition > dimCommonPrefix; reader.next(); - reader.packedValueWithDocId(bytesRef1); + PointValue pointValue = reader.pointValue(); // copy dimension - System.arraycopy(bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim, commonPrefix, 0, bytesPerDim); + BytesRef packedValue = pointValue.packedValue(); + System.arraycopy(packedValue.bytes, packedValue.offset + offset, scratch, 0, bytesPerDim); // copy docID - System.arraycopy(bytesRef1.bytes, bytesRef1.offset + packedBytesLength, commonPrefix, bytesPerDim, Integer.BYTES); - for (long i = from + 1; i< to; i++) { + BytesRef docIDBytes = pointValue.docIDBytes(); + System.arraycopy(docIDBytes.bytes, docIDBytes.offset, scratch, bytesPerDim, Integer.BYTES); + for (long i = from + 1; i < to; i++) { reader.next(); - reader.packedValueWithDocId(bytesRef1); - int startIndex = (dimCommonPrefix > bytesPerDim) ? bytesPerDim : dimCommonPrefix; - int endIndex = (commonPrefixPosition > bytesPerDim) ? bytesPerDim : commonPrefixPosition; - int j = FutureArrays.mismatch(commonPrefix, startIndex, endIndex, bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim + startIndex, bytesRef1.offset + dim * bytesPerDim + endIndex); - if (j == 0) { - commonPrefixPosition = dimCommonPrefix; - break; - } else if (j == -1) { - if (commonPrefixPosition > bytesPerDim) { - //tie-break on docID - int k = FutureArrays.mismatch(commonPrefix, bytesPerDim, commonPrefixPosition, bytesRef1.bytes, bytesRef1.offset + packedBytesLength, bytesRef1.offset + packedBytesLength + commonPrefixPosition - bytesPerDim); - if (k != -1) { - commonPrefixPosition = bytesPerDim + k; - } + pointValue = reader.pointValue(); + if (commonPrefixPosition == dimCommonPrefix) { + histogram[getBucket(offset, commonPrefixPosition, pointValue)]++; + // we do not need to check for common prefix anymore, + // just finish the histogram and break + for (long j = i + 1; j < to; j++) { + reader.next(); + pointValue = reader.pointValue(); + histogram[getBucket(offset, commonPrefixPosition, pointValue)]++; } + break; } else { - commonPrefixPosition = dimCommonPrefix + j; + //check common prefix and adjust histogram + final int startIndex = (dimCommonPrefix > bytesPerDim) ? bytesPerDim : dimCommonPrefix; + final int endIndex = (commonPrefixPosition > bytesPerDim) ? 
bytesPerDim : commonPrefixPosition;
+          packedValue = pointValue.packedValue();
+          int j = FutureArrays.mismatch(scratch, startIndex, endIndex, packedValue.bytes, packedValue.offset + offset + startIndex, packedValue.offset + offset + endIndex);
+          if (j == -1) {
+            if (commonPrefixPosition > bytesPerDim) {
+              //tie-break on docID
+              docIDBytes = pointValue.docIDBytes();
+              int k = FutureArrays.mismatch(scratch, bytesPerDim, commonPrefixPosition, docIDBytes.bytes, docIDBytes.offset, docIDBytes.offset + commonPrefixPosition - bytesPerDim);
+              if (k != -1) {
+                commonPrefixPosition = bytesPerDim + k;
+                Arrays.fill(histogram, 0);
+                histogram[scratch[commonPrefixPosition] & 0xff] = i - from;
+              }
+            }
+          } else {
+            commonPrefixPosition = dimCommonPrefix + j;
+            Arrays.fill(histogram, 0);
+            histogram[scratch[commonPrefixPosition] & 0xff] = i - from;
+          }
+          if (commonPrefixPosition != bytesSorted) {
+            histogram[getBucket(offset, commonPrefixPosition, pointValue)]++;
+          }
         }
       }
     }
-    //build histogram up to the common prefix
+    //build partition buckets up to commonPrefix
     for (int i = 0; i < commonPrefixPosition; i++) {
-      partitionBucket[i] = commonPrefix[i] & 0xff;
-      histogram[i][partitionBucket[i]] = to - from;
+      partitionBucket[i] = scratch[i] & 0xff;
     }
     return commonPrefixPosition;
   }

+  private int getBucket(int offset, int commonPrefixPosition, PointValue pointValue) {
+    int bucket;
+    if (commonPrefixPosition < bytesPerDim) {
+      BytesRef packedValue = pointValue.packedValue();
+      bucket = packedValue.bytes[packedValue.offset + offset + commonPrefixPosition] & 0xff;
+    } else {
+      BytesRef docIDValue = pointValue.docIDBytes();
+      bucket = docIDValue.bytes[docIDValue.offset + commonPrefixPosition - bytesPerDim] & 0xff;
+    }
+    return bucket;
+  }
+
   private byte[] buildHistogramAndPartition(OfflinePointWriter points, PointWriter left, PointWriter right,
-                                            long from, long to, long partitionPoint, int iteration, int commonPrefix, int dim) throws IOException {
+                                            long from, long to, long partitionPoint, int iteration, int baseCommonPrefix, int dim) throws IOException {
+    //find common prefix from baseCommonPrefix and build histogram
+    int commonPrefix = findCommonPrefixAndHistogram(points, from, to, dim, baseCommonPrefix);
+
+    //if all bytes are equal we just partition the points
+    if (commonPrefix == bytesSorted) {
+      offlinePartition(points, left, right, null, from, to, dim, commonPrefix - 1, partitionPoint);
+      return partitionPointFromCommonPrefix();
+    }

     long leftCount = 0;
     long rightCount = 0;

-    //build histogram at the commonPrefix byte
-    try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
-      while (reader.next()) {
-        reader.packedValueWithDocId(bytesRef1);
-        int bucket;
-        if (commonPrefix < bytesPerDim) {
-          bucket = bytesRef1.bytes[bytesRef1.offset + dim * bytesPerDim + commonPrefix] & 0xff;
-        } else {
-          bucket = bytesRef1.bytes[bytesRef1.offset + packedBytesLength + commonPrefix - bytesPerDim] & 0xff;
-        }
-        histogram[commonPrefix][bucket]++;
-      }
-    }
+    //Count left points and record the partition point
     for(int i = 0; i < HISTOGRAM_SIZE; i++) {
-      long size = histogram[commonPrefix][i];
+      long size = histogram[i];
       if (leftCount + size > partitionPoint - from) {
         partitionBucket[commonPrefix] = i;
         break;
@@ -208,13 +226,13 @@ public final class BKDRadixSelector {
     }
     //Count right points
     for(int i = partitionBucket[commonPrefix] + 1; i < HISTOGRAM_SIZE; i++) {
-      rightCount += histogram[commonPrefix][i];
+      rightCount += histogram[i];
     }

-    long delta = histogram[commonPrefix][partitionBucket[commonPrefix]];
-    
assert leftCount + rightCount + delta == to - from; + long delta = histogram[partitionBucket[commonPrefix]]; + assert leftCount + rightCount + delta == to - from : (leftCount + rightCount + delta) + " / " + (to - from); - //special case when be have lot of points that are equal + //special case when points are equal except last byte, we can just tie-break if (commonPrefix == bytesSorted - 1) { long tieBreakCount =(partitionPoint - from - leftCount); offlinePartition(points, left, right, null, from, to, dim, commonPrefix, tieBreakCount); @@ -241,35 +259,28 @@ public final class BKDRadixSelector { private void offlinePartition(OfflinePointWriter points, PointWriter left, PointWriter right, PointWriter deltaPoints, long from, long to, int dim, int bytePosition, long numDocsTiebreak) throws IOException { assert bytePosition == bytesSorted -1 || deltaPoints != null; + int offset = dim * bytesPerDim; long tiebreakCounter = 0; try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) { while (reader.next()) { - reader.packedValueWithDocId(bytesRef1); - reader.packedValue(bytesRef2); - int docID = reader.docID(); - int bucket; - if (bytePosition < bytesPerDim) { - bucket = bytesRef1.bytes[bytesRef1.offset + dim * bytesPerDim + bytePosition] & 0xff; - } else { - bucket = bytesRef1.bytes[bytesRef1.offset + packedBytesLength + bytePosition - bytesPerDim] & 0xff; - } - //int bucket = getBucket(bytesRef1, dim, thisCommonPrefix); + PointValue pointValue = reader.pointValue(); + int bucket = getBucket(offset, bytePosition, pointValue); if (bucket < this.partitionBucket[bytePosition]) { // to the left side - left.append(bytesRef2, docID); + left.append(pointValue); } else if (bucket > this.partitionBucket[bytePosition]) { // to the right side - right.append(bytesRef2, docID); + right.append(pointValue); } else { if (bytePosition == bytesSorted - 1) { if (tiebreakCounter < numDocsTiebreak) { - left.append(bytesRef2, docID); + left.append(pointValue); tiebreakCounter++; } else { - right.append(bytesRef2, docID); + right.append(pointValue); } } else { - deltaPoints.append(bytesRef2, docID); + deltaPoints.append(pointValue); } } } @@ -287,24 +298,21 @@ public final class BKDRadixSelector { } private byte[] heapPartition(HeapPointWriter points, PointWriter left, PointWriter right, int dim, int from, int to, int partitionPoint, int commonPrefix) throws IOException { - byte[] partition = heapRadixSelect(points, dim, from, to, partitionPoint, commonPrefix); - for (int i = from; i < to; i++) { - points.getPackedValueSlice(i, bytesRef1); - int docID = points.docIDs[i]; + PointValue value = points.getPackedValueSlice(i); if (i < partitionPoint) { - left.append(bytesRef1, docID); + left.append(value); } else { - right.append(bytesRef1, docID); + right.append(value); } } - return partition; } private byte[] heapRadixSelect(HeapPointWriter points, int dim, int from, int to, int partitionPoint, int commonPrefix) { final int offset = dim * bytesPerDim + commonPrefix; + final int dimCmpBytes = bytesPerDim - commonPrefix; new RadixSelector(bytesSorted - commonPrefix) { @Override @@ -314,29 +322,156 @@ public final class BKDRadixSelector { @Override protected int byteAt(int i, int k) { - assert k >= 0; - if (k + commonPrefix < bytesPerDim) { + assert k >= 0 : "negative prefix " + k; + if (k < dimCmpBytes) { // dim bytes - int block = i / points.valuesPerBlock; - int index = i % points.valuesPerBlock; - return points.blocks.get(block)[index * packedBytesLength + offset + k] & 0xff; + return 
points.block[i * packedBytesLength + offset + k] & 0xff;
         } else {
           // doc id
-          int s = 3 - (k + commonPrefix - bytesPerDim);
+          int s = 3 - (k - dimCmpBytes);
           return (points.docIDs[i] >>> (s * 8)) & 0xff;
         }
       }
+
+      @Override
+      protected Selector getFallbackSelector(int d) {
+        int skippedBytes = d + commonPrefix;
+        final int start = dim * bytesPerDim + skippedBytes;
+        final int end = dim * bytesPerDim + bytesPerDim;
+        return new IntroSelector() {
+
+          int pivotDoc = -1;
+
+          @Override
+          protected void swap(int i, int j) {
+            points.swap(i, j);
+          }
+
+          @Override
+          protected void setPivot(int i) {
+            if (skippedBytes < bytesPerDim) {
+              System.arraycopy(points.block, i * packedBytesLength + dim * bytesPerDim, scratch, 0, bytesPerDim);
+            }
+            pivotDoc = points.docIDs[i];
+          }
+
+          @Override
+          protected int compare(int i, int j) {
+            if (skippedBytes < bytesPerDim) {
+              int iOffset = i * packedBytesLength;
+              int jOffset = j * packedBytesLength;
+              int cmp = FutureArrays.compareUnsigned(points.block, iOffset + start, iOffset + end,
+                  points.block, jOffset + start, jOffset + end);
+              if (cmp != 0) {
+                return cmp;
+              }
+            }
+            return points.docIDs[i] - points.docIDs[j];
+          }
+
+          @Override
+          protected int comparePivot(int j) {
+            if (skippedBytes < bytesPerDim) {
+              int jOffset = j * packedBytesLength;
+              int cmp = FutureArrays.compareUnsigned(scratch, skippedBytes, bytesPerDim,
+                  points.block, jOffset + start, jOffset + end);
+              if (cmp != 0) {
+                return cmp;
+              }
+            }
+            return pivotDoc - points.docIDs[j];
+          }
+        };
+      }
     }.select(from, to, partitionPoint);

     byte[] partition = new byte[bytesPerDim];
-    points.getPackedValueSlice(partitionPoint, bytesRef1);
-    System.arraycopy(bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim, partition, 0, bytesPerDim);
+    PointValue pointValue = points.getPackedValueSlice(partitionPoint);
+    BytesRef packedValue = pointValue.packedValue();
+    System.arraycopy(packedValue.bytes, packedValue.offset + dim * bytesPerDim, partition, 0, bytesPerDim);
     return partition;
   }

+  /** Sort the heap writer by the specified dim. 
It is used to sort the leaves of the tree */
+  public void heapRadixSort(final HeapPointWriter points, int from, int to, int dim, int commonPrefixLength) {
+    final int offset = dim * bytesPerDim + commonPrefixLength;
+    final int dimCmpBytes = bytesPerDim - commonPrefixLength;
+    new MSBRadixSorter(bytesSorted - commonPrefixLength) {
+
+      @Override
+      protected int byteAt(int i, int k) {
+        assert k >= 0 : "negative prefix " + k;
+        if (k < dimCmpBytes) {
+          // dim bytes
+          return points.block[i * packedBytesLength + offset + k] & 0xff;
+        } else {
+          // doc id
+          int s = 3 - (k - dimCmpBytes);
+          return (points.docIDs[i] >>> (s * 8)) & 0xff;
+        }
+      }
+
+      @Override
+      protected void swap(int i, int j) {
+        points.swap(i, j);
+      }
+
+      @Override
+      protected Sorter getFallbackSorter(int k) {
+        int skippedBytes = k + commonPrefixLength;
+        final int start = dim * bytesPerDim + skippedBytes;
+        final int end = dim * bytesPerDim + bytesPerDim;
+        return new IntroSorter() {
+
+          int pivotDoc = -1;
+
+          @Override
+          protected void swap(int i, int j) {
+            points.swap(i, j);
+          }
+
+          @Override
+          protected void setPivot(int i) {
+            if (skippedBytes < bytesPerDim) {
+              System.arraycopy(points.block, i * packedBytesLength + dim * bytesPerDim, scratch, 0, bytesPerDim);
+            }
+            pivotDoc = points.docIDs[i];
+          }
+
+          @Override
+          protected int compare(int i, int j) {
+            if (skippedBytes < bytesPerDim) {
+              int iOffset = i * packedBytesLength;
+              int jOffset = j * packedBytesLength;
+              int cmp = FutureArrays.compareUnsigned(points.block, iOffset + start, iOffset + end,
+                  points.block, jOffset + start, jOffset + end);
+              if (cmp != 0) {
+                return cmp;
+              }
+            }
+            return points.docIDs[i] - points.docIDs[j];
+          }
+
+          @Override
+          protected int comparePivot(int j) {
+            if (skippedBytes < bytesPerDim) {
+              int jOffset = j * packedBytesLength;
+              int cmp = FutureArrays.compareUnsigned(scratch, skippedBytes, bytesPerDim,
+                  points.block, jOffset + start, jOffset + end);
+              if (cmp != 0) {
+                return cmp;
+              }
+            }
+            return pivotDoc - points.docIDs[j];
+          }
+        };
+      }
+    }.sort(from, to);
+  }
+
   private PointWriter getDeltaPointWriter(PointWriter left, PointWriter right, long delta, int iteration) throws IOException {
     if (delta <= getMaxPointsSortInHeap(left, right)) {
-      return new HeapPointWriter(Math.toIntExact(delta), Math.toIntExact(delta), packedBytesLength);
+      return new HeapPointWriter(Math.toIntExact(delta), packedBytesLength);
     } else {
       return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "delta" + iteration, delta);
     }
@@ -345,10 +480,10 @@ public final class BKDRadixSelector {

   private int getMaxPointsSortInHeap(PointWriter left, PointWriter right) {
     int pointsUsed = 0;
     if (left instanceof HeapPointWriter) {
-      pointsUsed += ((HeapPointWriter) left).maxSize;
+      pointsUsed += ((HeapPointWriter) left).size;
     }
     if (right instanceof HeapPointWriter) {
-      pointsUsed += ((HeapPointWriter) right).maxSize;
+      pointsUsed += ((HeapPointWriter) right).size;
     }
     assert maxPointsSortInHeap >= pointsUsed;
     return maxPointsSortInHeap - pointsUsed;
@@ -359,7 +494,7 @@ public final class BKDRadixSelector {

     //max size for these objects is half of the total points we can have on-heap. 
if (count <= maxPointsSortInHeap / 2) { int size = Math.toIntExact(count); - return new HeapPointWriter(size, size, packedBytesLength); + return new HeapPointWriter(size, packedBytesLength); } else { return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, desc, count); } @@ -382,5 +517,4 @@ public final class BKDRadixSelector { return "PathSlice(start=" + start + " count=" + count + " writer=" + writer + ")"; } } - } diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java index 3e873782658..a734a68ae7d 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java @@ -41,7 +41,6 @@ import org.apache.lucene.util.BytesRefBuilder; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.FutureArrays; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.MSBRadixSorter; import org.apache.lucene.util.NumericUtils; import org.apache.lucene.util.PriorityQueue; @@ -181,7 +180,7 @@ public class BKDWriter implements Closeable { } // We write first maxPointsSortInHeap in heap, then cutover to offline for additional points: - heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength); + heapPointWriter = new HeapPointWriter(maxPointsSortInHeap, packedBytesLength); this.maxMBSortInHeap = maxMBSortInHeap; } @@ -215,10 +214,8 @@ public class BKDWriter implements Closeable { // For each .add we just append to this input file, then in .finish we sort this input and resursively build the tree: offlinePointWriter = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "spill", 0); tempInput = offlinePointWriter.out; - scratchBytesRef1.length = packedBytesLength; for(int i=0;i= 0; - if (k + commonPrefixLength < bytesPerDim) { - // dim bytes - int block = i / writer.valuesPerBlock; - int index = i % writer.valuesPerBlock; - return writer.blocks.get(block)[index * packedBytesLength + dim * bytesPerDim + k + commonPrefixLength] & 0xff; - } else { - // doc id - int s = 3 - (k + commonPrefixLength - bytesPerDim); - return (writer.docIDs[i] >>> (s * 8)) & 0xff; - } - } - - @Override - protected void swap(int i, int j) { - writer.swap(i, j); - } - - }.sort(from, to); - } - // useful for debugging: /* private void printPathSlice(String desc, PathSlice slice, int dim) throws IOException { @@ -1264,12 +1233,11 @@ public class BKDWriter implements Closeable { // Not inside the try because we don't want to close it here: try (PointReader reader = source.getReader(0, source.count()); - HeapPointWriter writer = new HeapPointWriter(count, count, packedBytesLength)) { + HeapPointWriter writer = new HeapPointWriter(count, packedBytesLength)) { for(int i=0;i blocks; - final int valuesPerBlock; + final byte[] block; final int packedBytesLength; final int[] docIDs; final int end; + private final HeapPointValue pointValue; - public HeapPointReader(List blocks, int valuesPerBlock, int packedBytesLength, int[] docIDs, int start, int end) { - this.blocks = blocks; - this.valuesPerBlock = valuesPerBlock; + public HeapPointReader(byte[] block, int packedBytesLength, int[] docIDs, int start, int end) { + this.block = block; this.docIDs = docIDs; curRead = start-1; this.end = end; this.packedBytesLength = packedBytesLength; + this.pointValue = new HeapPointValue(block, packedBytesLength); } @Override @@ -49,20 +47,54 @@ public final class HeapPointReader extends PointReader { } 
  @Override
-  public void packedValue(BytesRef bytesRef) {
-    int block = curRead / valuesPerBlock;
-    int blockIndex = curRead % valuesPerBlock;
-    bytesRef.bytes = blocks.get(block);
-    bytesRef.offset = blockIndex * packedBytesLength;
-    bytesRef.length = packedBytesLength;
-  }
-
-  @Override
-  public int docID() {
-    return docIDs[curRead];
+  public PointValue pointValue() {
+    pointValue.setValue(curRead * packedBytesLength, docIDs[curRead]);
+    return pointValue;
   }

   @Override
   public void close() {
   }
+
+  /**
+   * Reusable implementation for an on-heap point value
+   */
+  static class HeapPointValue implements PointValue {
+
+    BytesRef packedValue;
+    BytesRef docIDBytes;
+    int docID;
+
+    public HeapPointValue(byte[] value, int packedLength) {
+      packedValue = new BytesRef(value, 0, packedLength);
+      docIDBytes = new BytesRef(new byte[4]);
+    }
+
+    /**
+     * Sets a new value by changing the offset and docID.
+     */
+    public void setValue(int offset, int docID) {
+      this.docID = docID;
+      packedValue.offset = offset;
+    }
+
+    @Override
+    public BytesRef packedValue() {
+      return packedValue;
+    }
+
+    @Override
+    public int docID() {
+      return docID;
+    }
+
+    @Override
+    public BytesRef docIDBytes() {
+      docIDBytes.bytes[0] = (byte) (docID >> 24);
+      docIDBytes.bytes[1] = (byte) (docID >> 16);
+      docIDBytes.bytes[2] = (byte) (docID >> 8);
+      docIDBytes.bytes[3] = (byte) (docID >> 0);
+      return docIDBytes;
+    }
+  }
 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java
index 0e4ad782f5f..8915b0c27e3 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java
@@ -16,10 +16,6 @@
  */
 package org.apache.lucene.util.bkd;

-import java.util.ArrayList;
-import java.util.List;
-
-import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;

 /**
@@ -28,97 +24,51 @@ import org.apache.lucene.util.BytesRef;
  * @lucene.internal
  * */
 public final class HeapPointWriter implements PointWriter {
-  public int[] docIDs;
+  public final int[] docIDs;
+  public final byte[] block;
+  final int size;
   final int packedBytesLength;
+  private final byte[] scratch;
   private int nextWrite;
   private boolean closed;
-  final int maxSize;
-  public final int valuesPerBlock;
-  final int packedBytesLength;
-  // NOTE: can't use ByteBlockPool because we need random-write access when sorting in heap
-  public final List blocks = new ArrayList<>();
-  private byte[] scratch;
+
+  private HeapPointReader.HeapPointValue pointValue;

-  public HeapPointWriter(int initSize, int maxSize, int packedBytesLength) {
-    docIDs = new int[initSize];
-    this.maxSize = maxSize;
+  public HeapPointWriter(int size, int packedBytesLength) {
+    this.docIDs = new int[size];
+    this.block = new byte[packedBytesLength * size];
+    this.size = size;
     this.packedBytesLength = packedBytesLength;
-    // 4K per page, unless each value is > 4K:
-    valuesPerBlock = Math.max(1, 4096/packedBytesLength);
-    scratch = new byte[packedBytesLength];
-  }
-
-  public void copyFrom(HeapPointWriter other) {
-    if (docIDs.length < other.nextWrite) {
-      throw new IllegalStateException("docIDs.length=" + docIDs.length + " other.nextWrite=" + other.nextWrite);
-    }
-    System.arraycopy(other.docIDs, 0, docIDs, 0, other.nextWrite);
-    for(byte[] block : other.blocks) {
-      blocks.add(block.clone());
-    }
-    nextWrite = other.nextWrite;
+    this.scratch = new byte[packedBytesLength];
+    pointValue = new 
HeapPointReader.HeapPointValue(block, packedBytesLength);
   }

-  /** Returns a reference, in result, to the byte[] slice holding this value */
-  public void getPackedValueSlice(int index, BytesRef result) {
-    int block = index / valuesPerBlock;
-    int blockIndex = index % valuesPerBlock;
-    result.bytes = blocks.get(block);
-    result.offset = blockIndex * packedBytesLength;
-    result.length = packedBytesLength;
-  }
-
-  void writePackedValue(int index, byte[] bytes) {
-    assert bytes.length == packedBytesLength;
-    int block = index / valuesPerBlock;
-    int blockIndex = index % valuesPerBlock;
-    //System.out.println("writePackedValue: index=" + index + " bytes.length=" + bytes.length + " block=" + block + " blockIndex=" + blockIndex + " valuesPerBlock=" + valuesPerBlock);
-    while (blocks.size() <= block) {
-      // If this is the last block, only allocate as large as necessary for maxSize:
-      int valuesInBlock = Math.min(valuesPerBlock, maxSize - (blocks.size() * valuesPerBlock));
-      blocks.add(new byte[valuesInBlock*packedBytesLength]);
-    }
-    System.arraycopy(bytes, 0, blocks.get(block), blockIndex * packedBytesLength, packedBytesLength);
-  }
-
-  void writePackedValue(int index, BytesRef bytes) {
-    assert bytes.length == packedBytesLength;
-    int block = index / valuesPerBlock;
-    int blockIndex = index % valuesPerBlock;
-    //System.out.println("writePackedValue: index=" + index + " bytes.length=" + bytes.length + " block=" + block + " blockIndex=" + blockIndex + " valuesPerBlock=" + valuesPerBlock);
-    while (blocks.size() <= block) {
-      // If this is the last block, only allocate as large as necessary for maxSize:
-      int valuesInBlock = Math.min(valuesPerBlock, maxSize - (blocks.size() * valuesPerBlock));
-      blocks.add(new byte[valuesInBlock*packedBytesLength]);
-    }
-    System.arraycopy(bytes.bytes, bytes.offset, blocks.get(block), blockIndex * packedBytesLength, packedBytesLength);
+  /** Returns a reference to the byte[] slice holding this value */
+  public PointValue getPackedValueSlice(int index) {
+    assert index < nextWrite : "nextWrite=" + (nextWrite) + " vs index=" + index;
+    pointValue.setValue(index * packedBytesLength, docIDs[index]);
+    return pointValue;
   }

   @Override
   public void append(byte[] packedValue, int docID) {
-    assert closed == false;
-    assert packedValue.length == packedBytesLength;
-    if (docIDs.length == nextWrite) {
-      int nextSize = Math.min(maxSize, ArrayUtil.oversize(nextWrite+1, Integer.BYTES));
-      assert nextSize > nextWrite: "nextSize=" + nextSize + " vs nextWrite=" + nextWrite;
-      docIDs = ArrayUtil.growExact(docIDs, nextSize);
-    }
-    writePackedValue(nextWrite, packedValue);
+    assert closed == false : "point writer is already closed";
+    assert packedValue.length == packedBytesLength : "[packedValue] must have length [" + packedBytesLength + "] but was [" + packedValue.length + "]";
+    assert nextWrite < size : "nextWrite=" + (nextWrite + 1) + " vs size=" + size;
+    System.arraycopy(packedValue, 0, block, nextWrite * packedBytesLength, packedBytesLength);
     docIDs[nextWrite] = docID;
     nextWrite++;
   }

   @Override
-  public void append(BytesRef packedValue, int docID) {
-    assert closed == false;
-    assert packedValue.length == packedBytesLength;
-    if (docIDs.length == nextWrite) {
-      int nextSize = Math.min(maxSize, ArrayUtil.oversize(nextWrite+1, Integer.BYTES));
-      assert nextSize > nextWrite: "nextSize=" + nextSize + " vs nextWrite=" + nextWrite;
-      docIDs = ArrayUtil.growExact(docIDs, nextSize);
-    }
-    writePackedValue(nextWrite, packedValue);
-    docIDs[nextWrite] = docID;
+  public void append(PointValue pointValue) {
+    assert closed == false : "point 
writer is already closed"; + assert nextWrite < size : "nextWrite=" + (nextWrite + 1) + " vs size=" + size; + BytesRef packedValue = pointValue.packedValue(); + assert packedValue.length == packedBytesLength : "[packedValue] must have length [" + (packedBytesLength) + "] but was [" + packedValue.length + "]"; + System.arraycopy(packedValue.bytes, packedValue.offset, block, nextWrite * packedBytesLength, packedBytesLength); + docIDs[nextWrite] = pointValue.docID(); nextWrite++; } @@ -127,18 +77,15 @@ public final class HeapPointWriter implements PointWriter { docIDs[i] = docIDs[j]; docIDs[j] = docID; - - byte[] blockI = blocks.get(i / valuesPerBlock); - int indexI = (i % valuesPerBlock) * packedBytesLength; - byte[] blockJ = blocks.get(j / valuesPerBlock); - int indexJ = (j % valuesPerBlock) * packedBytesLength; + int indexI = i * packedBytesLength; + int indexJ = j * packedBytesLength; // scratch1 = values[i] - System.arraycopy(blockI, indexI, scratch, 0, packedBytesLength); + System.arraycopy(block, indexI, scratch, 0, packedBytesLength); // values[i] = values[j] - System.arraycopy(blockJ, indexJ, blockI, indexI, packedBytesLength); + System.arraycopy(block, indexJ, block, indexI, packedBytesLength); // values[j] = scratch1 - System.arraycopy(scratch, 0, blockJ, indexJ, packedBytesLength); + System.arraycopy(scratch, 0, block, indexJ, packedBytesLength); } @Override @@ -148,9 +95,10 @@ public final class HeapPointWriter implements PointWriter { @Override public PointReader getReader(long start, long length) { + assert closed : "point writer is still open and trying to get a reader"; assert start + length <= docIDs.length: "start=" + start + " length=" + length + " docIDs.length=" + docIDs.length; assert start + length <= nextWrite: "start=" + start + " length=" + length + " nextWrite=" + nextWrite; - return new HeapPointReader(blocks, valuesPerBlock, packedBytesLength, docIDs, (int) start, Math.toIntExact(start+length)); + return new HeapPointReader(block, packedBytesLength, docIDs, (int) start, Math.toIntExact(start+length)); } @Override @@ -164,6 +112,6 @@ public final class HeapPointWriter implements PointWriter { @Override public String toString() { - return "HeapPointWriter(count=" + nextWrite + " alloc=" + docIDs.length + ")"; + return "HeapPointWriter(count=" + nextWrite + " size=" + docIDs.length + ")"; } } diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java index 86afc790c62..6218429b071 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java @@ -31,12 +31,12 @@ import org.apache.lucene.util.BytesRef; * * @lucene.internal * */ -public final class OfflinePointReader extends PointReader { +public final class OfflinePointReader implements PointReader { long countLeft; final IndexInput in; byte[] onHeapBuffer; - private int offset; + int offset; final int bytesPerDoc; private boolean checked; private final int packedValueLength; @@ -44,6 +44,7 @@ public final class OfflinePointReader extends PointReader { private final int maxPointOnHeap; // File name we are reading final String name; + private final OfflinePointValue pointValue; public OfflinePointReader(Directory tempDir, String tempFileName, int packedBytesLength, long start, long length, byte[] reusableBuffer) throws IOException { this.bytesPerDoc = packedBytesLength + Integer.BYTES; @@ -79,6 +80,7 @@ public final class 
OfflinePointReader extends PointReader {
     in.seek(seekFP);
     countLeft = length;
     this.onHeapBuffer = reusableBuffer;
+    this.pointValue = new OfflinePointValue(onHeapBuffer, packedValueLength);
   }

   @Override
@@ -112,23 +114,9 @@ public final class OfflinePointReader extends PointReader {
   }

   @Override
-  public void packedValue(BytesRef bytesRef) {
-    bytesRef.bytes = onHeapBuffer;
-    bytesRef.offset = offset;
-    bytesRef.length = packedValueLength;
-  }
-
-  protected void packedValueWithDocId(BytesRef bytesRef) {
-    bytesRef.bytes = onHeapBuffer;
-    bytesRef.offset = offset;
-    bytesRef.length = bytesPerDoc;
-  }
-
-  @Override
-  public int docID() {
-    int position = this.offset + packedValueLength;
-    return ((onHeapBuffer[position++] & 0xFF) << 24) | ((onHeapBuffer[position++] & 0xFF) << 16)
-      | ((onHeapBuffer[position++] & 0xFF) << 8) | (onHeapBuffer[position++] & 0xFF);
+  public PointValue pointValue() {
+    pointValue.setOffset(offset);
+    return pointValue;
   }

   @Override
@@ -143,5 +131,45 @@ public final class OfflinePointReader extends PointReader {
       in.close();
     }
   }
+
+  /**
+   * Reusable implementation for an offline point value
+   */
+  static class OfflinePointValue implements PointValue {
+
+    BytesRef packedValue;
+    BytesRef docIDBytes;
+
+    OfflinePointValue(byte[] value, int packedValueLength) {
+      packedValue = new BytesRef(value, 0, packedValueLength);
+      docIDBytes = new BytesRef(value, packedValueLength, Integer.BYTES);
+    }
+
+    /**
+     * Sets a new value by changing the offset.
+     */
+    public void setOffset(int offset) {
+      packedValue.offset = offset;
+      docIDBytes.offset = offset + packedValue.length;
+    }
+
+    @Override
+    public BytesRef packedValue() {
+      return packedValue;
+    }
+
+    @Override
+    public int docID() {
+      int position = docIDBytes.offset;
+      return ((docIDBytes.bytes[position] & 0xFF) << 24) | ((docIDBytes.bytes[++position] & 0xFF) << 16)
+        | ((docIDBytes.bytes[++position] & 0xFF) << 8) | (docIDBytes.bytes[++position] & 0xFF);
+    }
+
+    @Override
+    public BytesRef docIDBytes() {
+      return docIDBytes;
+    }
+  }
+
 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java
index 5479b531dbe..4f4ce9ee740 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java
@@ -46,26 +46,30 @@ public final class OfflinePointWriter implements PointWriter {
     this.name = out.getName();
     this.tempDir = tempDir;
     this.packedBytesLength = packedBytesLength;
-    this.expectedCount = expectedCount;
   }
-
+
   @Override
   public void append(byte[] packedValue, int docID) throws IOException {
-    assert packedValue.length == packedBytesLength;
+    assert closed == false : "Point writer is already closed";
+    assert packedValue.length == packedBytesLength : "[packedValue] must have length [" + packedBytesLength + "] but was [" + packedValue.length + "]";
     out.writeBytes(packedValue, 0, packedValue.length);
     out.writeInt(docID);
     count++;
-    assert expectedCount == 0 || count <= expectedCount;
+    assert expectedCount == 0 || count <= expectedCount: "expectedCount=" + expectedCount + " vs count=" + count;
   }

   @Override
-  public void append(BytesRef packedValue, int docID) throws IOException {
-    assert packedValue.length == packedBytesLength;
+  public void append(PointValue pointValue) throws IOException {
+    assert closed == false : "Point writer is already closed";
+    BytesRef packedValue = pointValue.packedValue();
+    assert packedValue.length == 
packedBytesLength : "[packedValue] must have length [" + packedBytesLength + "] but was [" + packedValue.length + "]";
     out.writeBytes(packedValue.bytes, packedValue.offset, packedValue.length);
-    out.writeInt(docID);
+    BytesRef docIDBytes = pointValue.docIDBytes();
+    assert docIDBytes.length == Integer.BYTES : "[docIDBytes] must have length [" + Integer.BYTES + "] but was [" + docIDBytes.length + "]";
+    out.writeBytes(docIDBytes.bytes, docIDBytes.offset, docIDBytes.length);
     count++;
-    assert expectedCount == 0 || count <= expectedCount;
+    assert expectedCount == 0 || count <= expectedCount : "expectedCount=" + expectedCount + " vs count=" + count;
   }

   @Override
@@ -75,7 +79,7 @@ public final class OfflinePointWriter implements PointWriter {
   }

   protected OfflinePointReader getReader(long start, long length, byte[] reusableBuffer) throws IOException {
-    assert closed;
+    assert closed: "point writer is still open and trying to get a reader";
     assert start + length <= count: "start=" + start + " length=" + length + " count=" + count;
     assert expectedCount == 0 || count == expectedCount;
     return new OfflinePointReader(tempDir, name, packedBytesLength, start, length, reusableBuffer);
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java
index c0eaff880de..631d004b052 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java
@@ -20,24 +20,19 @@ package org.apache.lucene.util.bkd;

 import java.io.Closeable;
 import java.io.IOException;

-import org.apache.lucene.util.BytesRef;
-
 /** One pass iterator through all points previously written with a
  *  {@link PointWriter}, abstracting away whether points are read
  *  from (offline) disk or simple arrays in heap.
  *
  * @lucene.internal
  * */
-public abstract class PointReader implements Closeable {
+public interface PointReader extends Closeable {

   /** Returns false once iteration is done, else true. */
-  public abstract boolean next() throws IOException;
+  boolean next() throws IOException;

-  /** Sets the packed value in the provided ByteRef */
-  public abstract void packedValue(BytesRef bytesRef);
-
-  /** DocID for this point */
-  public abstract int docID();
+  /** Returns the {@link PointValue} at the current position */
+  PointValue pointValue();

 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/PointValue.java b/lucene/core/src/java/org/apache/lucene/util/bkd/PointValue.java
new file mode 100644
index 00000000000..79c3efa16c1
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/PointValue.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
+ */ + +package org.apache.lucene.util.bkd; + +import org.apache.lucene.util.BytesRef; + +/** + * Represent a dimensional point value written in the BKD tree. + */ +public interface PointValue { + + /** Return the packed values for the dimensions */ + BytesRef packedValue(); + + /** The document id */ + int docID(); + + /** The byte representation of the document id */ + BytesRef docIDBytes(); + +} diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/PointWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/PointWriter.java index 194c8267752..10cc02cdea2 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/PointWriter.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/PointWriter.java @@ -20,8 +20,6 @@ package org.apache.lucene.util.bkd; import java.io.Closeable; import java.io.IOException; -import org.apache.lucene.util.BytesRef; - /** Appends many points, and then at the end provides a {@link PointReader} to iterate * those points. This abstracts away whether we write to disk, or use simple arrays * in heap. @@ -29,11 +27,12 @@ import org.apache.lucene.util.BytesRef; * @lucene.internal * */ public interface PointWriter extends Closeable { - /** Add a new point from byte array*/ + + /** Add a new point from the packed value and docId */ void append(byte[] packedValue, int docID) throws IOException; - /** Add a new point from byteRef */ - void append(BytesRef packedValue, int docID) throws IOException; + /** Add a new point from a {@link PointValue} */ + void append(PointValue pointValue) throws IOException; /** Returns a {@link PointReader} iterator to step through all previously added points */ PointReader getReader(long startPoint, long length) throws IOException; diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java index 558b9f2dc8d..c2908258d51 100644 --- a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java +++ b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java @@ -38,15 +38,15 @@ public class TestBKDRadixSelector extends LuceneTestCase { int bytesPerDimensions = Integer.BYTES; int packedLength = dimensions * bytesPerDimensions; PointWriter points = getRandomPointWriter(dir, values, packedLength); - byte[] bytes = new byte[Integer.BYTES]; - NumericUtils.intToSortableBytes(1, bytes, 0); - points.append(bytes, 0); - NumericUtils.intToSortableBytes(2, bytes, 0); - points.append(bytes, 1); - NumericUtils.intToSortableBytes(3, bytes, 0); - points.append(bytes, 2); - NumericUtils.intToSortableBytes(4, bytes, 0); - points.append(bytes, 3); + byte[] value = new byte[packedLength]; + NumericUtils.intToSortableBytes(1, value, 0); + points.append(value, 0); + NumericUtils.intToSortableBytes(2, value, 0); + points.append(value, 1); + NumericUtils.intToSortableBytes(3, value, 0); + points.append(value, 2); + NumericUtils.intToSortableBytes(4, value, 0); + points.append(value, 3); points.close(); PointWriter copy = copyPoints(dir,points, packedLength); verify(dir, copy, dimensions, 0, values, middle, packedLength, bytesPerDimensions, 0); @@ -182,21 +182,17 @@ public class TestBKDRadixSelector extends LuceneTestCase { } private void verify(Directory dir, PointWriter points, int dimensions, long start, long end, long middle, int packedLength, int bytesPerDimensions, int sortedOnHeap) throws IOException{ + BKDRadixSelector radixSelector = new BKDRadixSelector(dimensions, bytesPerDimensions, sortedOnHeap, dir, 
"test"); + //we check for each dimension for (int splitDim =0; splitDim < dimensions; splitDim++) { - PointWriter copy = copyPoints(dir, points, packedLength); + //We need to make a copy of the data as it is deleted in the process + BKDRadixSelector.PathSlice inputSlice = new BKDRadixSelector.PathSlice(copyPoints(dir, points, packedLength), 0, points.count()); + int commonPrefixLengthInput = getRandomCommonPrefix(inputSlice, bytesPerDimensions, splitDim); BKDRadixSelector.PathSlice[] slices = new BKDRadixSelector.PathSlice[2]; - BKDRadixSelector radixSelector = new BKDRadixSelector(dimensions, bytesPerDimensions, sortedOnHeap, dir, "test"); - BKDRadixSelector.PathSlice copySlice = new BKDRadixSelector.PathSlice(copy, 0, copy.count()); - byte[] pointsMax = getMax(copySlice, bytesPerDimensions, splitDim); - byte[] pointsMin = getMin(copySlice, bytesPerDimensions, splitDim); - int commonPrefixLength = FutureArrays.mismatch(pointsMin, 0, bytesPerDimensions, pointsMax, 0, bytesPerDimensions); - if (commonPrefixLength == -1) { - commonPrefixLength = bytesPerDimensions; - } - int commonPrefixLengthInput = (random().nextBoolean()) ? commonPrefixLength : commonPrefixLength == 0 ? 0 : random().nextInt(commonPrefixLength); - byte[] partitionPoint = radixSelector.select(copySlice, slices, start, end, middle, splitDim, commonPrefixLengthInput); + byte[] partitionPoint = radixSelector.select(inputSlice, slices, start, end, middle, splitDim, commonPrefixLengthInput); assertEquals(middle - start, slices[0].count); assertEquals(end - middle, slices[1].count); + //check that left and right slices contain the correct points byte[] max = getMax(slices[0], bytesPerDimensions, splitDim); byte[] min = getMin(slices[1], bytesPerDimensions, splitDim); int cmp = FutureArrays.compareUnsigned(max, 0, bytesPerDimensions, min, 0, bytesPerDimensions); @@ -213,22 +209,31 @@ public class TestBKDRadixSelector extends LuceneTestCase { points.destroy(); } - private PointWriter copyPoints(Directory dir, PointWriter points, int packedLength) throws IOException { - BytesRef bytesRef = new BytesRef(); + private PointWriter copyPoints(Directory dir, PointWriter points, int packedLength) throws IOException { try (PointWriter copy = getRandomPointWriter(dir, points.count(), packedLength); PointReader reader = points.getReader(0, points.count())) { while (reader.next()) { - reader.packedValue(bytesRef); - copy.append(bytesRef, reader.docID()); + copy.append(reader.pointValue()); } return copy; } } + /** returns a common prefix length equal or lower than the current one */ + private int getRandomCommonPrefix(BKDRadixSelector.PathSlice inputSlice, int bytesPerDimension, int splitDim) throws IOException { + byte[] pointsMax = getMax(inputSlice, bytesPerDimension, splitDim); + byte[] pointsMin = getMin(inputSlice, bytesPerDimension, splitDim); + int commonPrefixLength = FutureArrays.mismatch(pointsMin, 0, bytesPerDimension, pointsMax, 0, bytesPerDimension); + if (commonPrefixLength == -1) { + commonPrefixLength = bytesPerDimension; + } + return (random().nextBoolean()) ? commonPrefixLength : commonPrefixLength == 0 ? 
0 : random().nextInt(commonPrefixLength); + } + private PointWriter getRandomPointWriter(Directory dir, long numPoints, int packedBytesLength) throws IOException { if (numPoints < 4096 && random().nextBoolean()) { - return new HeapPointWriter(Math.toIntExact(numPoints), Math.toIntExact(numPoints), packedBytesLength); + return new HeapPointWriter(Math.toIntExact(numPoints), packedBytesLength); } else { return new OfflinePointWriter(dir, "test", packedBytesLength, "data", numPoints); } @@ -249,9 +254,10 @@ public class TestBKDRadixSelector extends LuceneTestCase { Arrays.fill(min, (byte) 0xff); try (PointReader reader = p.writer.getReader(p.start, p.count)) { byte[] value = new byte[bytesPerDimension]; - BytesRef packedValue = new BytesRef(); + while (reader.next()) { - reader.packedValue(packedValue); + PointValue pointValue = reader.pointValue(); + BytesRef packedValue = pointValue.packedValue(); System.arraycopy(packedValue.bytes, packedValue.offset + dimension * bytesPerDimension, value, 0, bytesPerDimension); if (FutureArrays.compareUnsigned(min, 0, bytesPerDimension, value, 0, bytesPerDimension) > 0) { System.arraycopy(value, 0, min, 0, bytesPerDimension); @@ -264,12 +270,12 @@ public class TestBKDRadixSelector extends LuceneTestCase { private int getMinDocId(BKDRadixSelector.PathSlice p, int bytesPerDimension, int dimension, byte[] partitionPoint) throws IOException { int docID = Integer.MAX_VALUE; try (PointReader reader = p.writer.getReader(p.start, p.count)) { - BytesRef packedValue = new BytesRef(); while (reader.next()) { - reader.packedValue(packedValue); + PointValue pointValue = reader.pointValue(); + BytesRef packedValue = pointValue.packedValue(); int offset = dimension * bytesPerDimension; if (FutureArrays.compareUnsigned(packedValue.bytes, packedValue.offset + offset, packedValue.offset + offset + bytesPerDimension, partitionPoint, 0, bytesPerDimension) == 0) { - int newDocID = reader.docID(); + int newDocID = pointValue.docID(); if (newDocID < docID) { docID = newDocID; } @@ -284,9 +290,9 @@ public class TestBKDRadixSelector extends LuceneTestCase { Arrays.fill(max, (byte) 0); try (PointReader reader = p.writer.getReader(p.start, p.count)) { byte[] value = new byte[bytesPerDimension]; - BytesRef packedValue = new BytesRef(); while (reader.next()) { - reader.packedValue(packedValue); + PointValue pointValue = reader.pointValue(); + BytesRef packedValue = pointValue.packedValue(); System.arraycopy(packedValue.bytes, packedValue.offset + dimension * bytesPerDimension, value, 0, bytesPerDimension); if (FutureArrays.compareUnsigned(max, 0, bytesPerDimension, value, 0, bytesPerDimension) < 0) { System.arraycopy(value, 0, max, 0, bytesPerDimension); @@ -299,12 +305,12 @@ public class TestBKDRadixSelector extends LuceneTestCase { private int getMaxDocId(BKDRadixSelector.PathSlice p, int bytesPerDimension, int dimension, byte[] partitionPoint) throws IOException { int docID = Integer.MIN_VALUE; try (PointReader reader = p.writer.getReader(p.start, p.count)) { - BytesRef packedValue = new BytesRef(); while (reader.next()) { - reader.packedValue(packedValue); + PointValue pointValue = reader.pointValue(); + BytesRef packedValue = pointValue.packedValue(); int offset = dimension * bytesPerDimension; if (FutureArrays.compareUnsigned(packedValue.bytes, packedValue.offset + offset, packedValue.offset + offset + bytesPerDimension, partitionPoint, 0, bytesPerDimension) == 0) { - int newDocID = reader.docID(); + int newDocID = pointValue.docID(); if (newDocID > docID) { docID = newDocID; }