LUCENE-6825: add low-level support for block-KD trees

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1709783 13f79535-47bb-0310-9956-ffa450edef68
Michael McCandless 2015-10-21 09:56:42 +00:00
parent d40aef4aca
commit f3f232f775
16 changed files with 2392 additions and 58 deletions


@ -37,6 +37,8 @@ New Features
that can return false to exit from ExitableDirectoryReader wrapping at
the point fields() is called. (yonik)
* LUCENE-6825: Add low-level support for block-KD trees (Mike McCandless)
API Changes
* LUCENE-3312: The API of oal.document was restructured to


@ -186,6 +186,7 @@ public final class ByteBlockPool {
}
}
}
/**
* Advances the pool to its next buffer. This method should be called once
* after the constructor to initialize the pool. In contrast to the


@ -43,7 +43,7 @@ import org.apache.lucene.store.TrackingDirectoryWrapper;
* @lucene.experimental
* @lucene.internal
*/
public final class OfflineSorter {
public class OfflineSorter {
/** Convenience constant for megabytes */
public final static long MB = 1024 * 1024;
@ -237,7 +237,7 @@ public final class OfflineSorter {
TrackingDirectoryWrapper trackingDir = new TrackingDirectoryWrapper(dir);
boolean success = false;
try (ByteSequencesReader is = new ByteSequencesReader(dir.openInput(inputFileName, IOContext.READONCE))) {
try (ByteSequencesReader is = getReader(dir.openInput(inputFileName, IOContext.READONCE))) {
int lineCount;
while ((lineCount = readPartition(is)) > 0) {
@ -284,8 +284,9 @@ public final class OfflineSorter {
protected String sortPartition(TrackingDirectoryWrapper trackingDir) throws IOException {
BytesRefArray data = this.buffer;
try (IndexOutput tempFile = trackingDir.createTempOutput(tempFileNamePrefix, "sort", IOContext.DEFAULT)) {
ByteSequencesWriter out = new ByteSequencesWriter(tempFile);
try (IndexOutput tempFile = trackingDir.createTempOutput(tempFileNamePrefix, "sort", IOContext.DEFAULT);
ByteSequencesWriter out = getWriter(tempFile);) {
BytesRef spare;
long start = System.currentTimeMillis();
@ -320,16 +321,18 @@ public final class OfflineSorter {
String newSegmentName = null;
try (IndexOutput out = trackingDir.createTempOutput(tempFileNamePrefix, "sort", IOContext.DEFAULT)) {
newSegmentName = out.getName();
ByteSequencesWriter writer = new ByteSequencesWriter(out);
try (IndexOutput out = trackingDir.createTempOutput(tempFileNamePrefix, "sort", IOContext.DEFAULT);
ByteSequencesWriter writer = getWriter(out);) {
newSegmentName = out.getName();
// Open streams and read the top for each file
for (int i = 0; i < segments.size(); i++) {
streams[i] = new ByteSequencesReader(dir.openInput(segments.get(i), IOContext.READONCE));
byte[] line = streams[i].read();
assert line != null;
queue.insertWithOverflow(new FileAndTop(i, line));
streams[i] = getReader(dir.openInput(segments.get(i), IOContext.READONCE));
BytesRefBuilder bytes = new BytesRefBuilder();
boolean result = streams[i].read(bytes);
assert result;
queue.insertWithOverflow(new FileAndTop(i, bytes));
}
// Unix utility sort() uses ordered array of files to pick the next line from, updating
@ -363,13 +366,12 @@ public final class OfflineSorter {
/** Read in a single partition of data */
int readPartition(ByteSequencesReader reader) throws IOException {
long start = System.currentTimeMillis();
final BytesRef scratch = new BytesRef();
while ((scratch.bytes = reader.read()) != null) {
scratch.length = scratch.bytes.length;
buffer.append(scratch);
final BytesRefBuilder scratch = new BytesRefBuilder();
while (reader.read(scratch)) {
buffer.append(scratch.get());
// Account for the created objects.
// (buffer slots do not count against the buffer size.)
if (ramBufferSize.bytes < bufferBytesUsed.get()) {
if (bufferBytesUsed.get() > ramBufferSize.bytes) {
break;
}
}
@ -381,19 +383,28 @@ public final class OfflineSorter {
final int fd;
final BytesRefBuilder current;
FileAndTop(int fd, byte[] firstLine) {
FileAndTop(int fd, BytesRefBuilder firstLine) {
this.fd = fd;
this.current = new BytesRefBuilder();
this.current.copyBytes(firstLine, 0, firstLine.length);
this.current = firstLine;
}
}
/** Subclasses can override to change how byte sequences are written to disk. */
protected ByteSequencesWriter getWriter(IndexOutput out) throws IOException {
return new ByteSequencesWriter(out);
}
/** Subclasses can override to change how byte sequences are read from disk. */
protected ByteSequencesReader getReader(IndexInput in) throws IOException {
return new ByteSequencesReader(in);
}
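These hooks are exercised by BKDWriter later in this commit to read and write fixed-width entries; with the default length-prefixed format, driving the sorter looks roughly like the sketch below (dir, cmp and unsortedName are placeholders, and sort() is assumed to return the name of the sorted temp file in this Directory-based API):
// Sort a previously written file of length-prefixed entries, then scan the result:
OfflineSorter sorter = new OfflineSorter(dir, "sort", cmp);
String sortedName = sorter.sort(unsortedName);
try (ByteSequencesReader reader =
         new ByteSequencesReader(dir.openInput(sortedName, IOContext.READONCE))) {
  BytesRefBuilder scratch = new BytesRefBuilder();
  while (reader.read(scratch)) {
    // scratch.get() holds the next entry, in cmp order
  }
}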
/**
* Utility class to emit length-prefixed byte[] entries to an output stream for sorting.
* Complementary to {@link ByteSequencesReader}.
*/
public static class ByteSequencesWriter implements Closeable {
private final IndexOutput out;
protected final IndexOutput out;
/** Constructs a ByteSequencesWriter to the provided DataOutput */
public ByteSequencesWriter(IndexOutput out) {
@ -404,7 +415,7 @@ public final class OfflineSorter {
* Writes a BytesRef.
* @see #write(byte[], int, int)
*/
public void write(BytesRef ref) throws IOException {
public final void write(BytesRef ref) throws IOException {
assert ref != null;
write(ref.bytes, ref.offset, ref.length);
}
@ -413,7 +424,7 @@ public final class OfflineSorter {
* Writes a byte array.
* @see #write(byte[], int, int)
*/
public void write(byte[] bytes) throws IOException {
public final void write(byte[] bytes) throws IOException {
write(bytes, 0, bytes.length);
}
@ -448,7 +459,7 @@ public final class OfflineSorter {
* Complementary to {@link ByteSequencesWriter}.
*/
public static class ByteSequencesReader implements Closeable {
private final IndexInput in;
protected final IndexInput in;
/** Constructs a ByteSequencesReader from the provided IndexInput */
public ByteSequencesReader(IndexInput in) {
@ -477,29 +488,6 @@ public final class OfflineSorter {
return true;
}
/**
* Reads the next entry and returns it if successful.
*
* @see #read(BytesRefBuilder)
*
* @return Returns <code>null</code> if EOF occurred before the next entry
* could be read.
* @throws EOFException if the file ends before the full sequence is read.
*/
public byte[] read() throws IOException {
short length;
try {
length = in.readShort();
} catch (EOFException e) {
return null;
}
assert length >= 0 : "Sanity: sequence length < 0: " + length;
byte[] result = new byte[length];
in.readBytes(result, 0, length);
return result;
}
/**
* Closes the provided {@link IndexInput}.
*/


@ -0,0 +1,250 @@
package org.apache.lucene.util.bkd;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
/** Handles intersection of a multi-dimensional shape in byte[] space with a block KD-tree previously written with {@link BKDWriter}.
*
* @lucene.experimental */
public final class BKDReader implements Accountable {
// Packed array of byte[] holding all split values in the full binary tree:
final private byte[] splitPackedValues;
final private long[] leafBlockFPs;
final private int leafNodeOffset;
final int numDims;
final int bytesPerDim;
final IndexInput in;
final int packedBytesLength;
final int maxPointsInLeafNode;
enum Relation {CELL_INSIDE_QUERY, QUERY_CROSSES_CELL, QUERY_OUTSIDE_CELL};
/** We recurse the BKD tree, using a provided instance of this to guide the recursion.
*
* @lucene.experimental */
public interface IntersectVisitor {
/** Called for all docs in a leaf cell that's fully contained by the query. The
* consumer should blindly accept the docID. */
void visit(int docID);
/** Called for all docs in a leaf cell that crosses the query. The consumer
* should scrutinize the packedValue to decide whether to accept it. */
void visit(int docID, byte[] packedValue);
/** Called for non-leaf cells to test how the cell relates to the query, to
* determine how to further recurse down the tree. */
Relation compare(byte[] minPackedValue, byte[] maxPackedValue);
}
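For example, a visitor that collects docIDs whose single 4-byte dimension falls in [min, max] might look like the sketch below; rangeVisitor and decodeInt are hypothetical helpers, assumed to live in the same package (Relation is package-private), and the int encoding is assumed to match BKDUtil.intToBytes:
// Hypothetical helper: collect docIDs whose 1D int value is in [min, max].
static BKDReader.IntersectVisitor rangeVisitor(final int min, final int max, final List<Integer> hits) {
  return new BKDReader.IntersectVisitor() {
    @Override
    public void visit(int docID) {
      hits.add(docID);                       // leaf cell fully inside the query: accept blindly
    }
    @Override
    public void visit(int docID, byte[] packedValue) {
      int v = decodeInt(packedValue);        // crossing cell: check each value
      if (v >= min && v <= max) {
        hits.add(docID);
      }
    }
    @Override
    public BKDReader.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
      int cellMin = decodeInt(minPackedValue);
      int cellMax = decodeInt(maxPackedValue);
      if (min <= cellMin && cellMax <= max) {
        return BKDReader.Relation.CELL_INSIDE_QUERY;
      } else if (cellMax < min || cellMin > max) {
        return BKDReader.Relation.QUERY_OUTSIDE_CELL;
      } else {
        return BKDReader.Relation.QUERY_CROSSES_CELL;
      }
    }
  };
}
// Hypothetical helper mirroring BKDUtil.bytesToInt: undo the sign-bit flip.
static int decodeInt(byte[] src) {
  int x = 0;
  for (int i = 0; i < 4; i++) {
    x |= (src[i] & 0xff) << (24 - i * 8);
  }
  return x ^ 0x80000000;
}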
/** Caller must pre-seek the provided {@link IndexInput} to the index location that {@link BKDWriter#finish} returned */
public BKDReader(IndexInput in) throws IOException {
CodecUtil.checkHeader(in, BKDWriter.CODEC_NAME, BKDWriter.VERSION_START, BKDWriter.VERSION_START);
numDims = in.readVInt();
maxPointsInLeafNode = in.readVInt();
bytesPerDim = in.readVInt();
packedBytesLength = numDims * bytesPerDim;
// Read index:
int numLeaves = in.readVInt();
leafNodeOffset = numLeaves;
splitPackedValues = new byte[(1+bytesPerDim)*numLeaves];
in.readBytes(splitPackedValues, 0, splitPackedValues.length);
// The tree is a fully balanced binary tree, so the number of inner nodes is numLeaves-1; our nodeIDs are 1-based, so splitPackedValues[0] is unused:
leafBlockFPs = new long[numLeaves];
for(int i=0;i<numLeaves;i++) {
leafBlockFPs[i] = in.readVLong();
}
this.in = in;
}
private static final class IntersectState {
final IndexInput in;
final int[] scratchDocIDs;
final byte[] scratchPackedValue;
// Minimum point of the N-dim rect containing the query shape:
final byte[] minPacked;
// Maximum point of the N-dim rect containing the query shape:
final byte[] maxPacked;
final IntersectVisitor visitor;
public IntersectState(IndexInput in, int packedBytesLength,
int maxPointsInLeafNode, byte[] minPacked, byte[] maxPacked,
IntersectVisitor visitor) {
this.in = in;
this.minPacked = minPacked;
this.maxPacked = maxPacked;
this.visitor = visitor;
this.scratchDocIDs = new int[maxPointsInLeafNode];
this.scratchPackedValue = new byte[packedBytesLength];
}
}
public void intersect(IntersectVisitor visitor) throws IOException {
byte[] minPacked = new byte[packedBytesLength];
byte[] maxPacked = new byte[packedBytesLength];
Arrays.fill(maxPacked, (byte) 0xff);
intersect(minPacked, maxPacked, visitor);
}
public void intersect(byte[] minPacked, byte[] maxPacked, IntersectVisitor visitor) throws IOException {
IntersectState state = new IntersectState(in.clone(), packedBytesLength,
maxPointsInLeafNode, minPacked, maxPacked,
visitor);
byte[] rootMinPacked = new byte[packedBytesLength];
byte[] rootMaxPacked = new byte[packedBytesLength];
Arrays.fill(rootMaxPacked, (byte) 0xff);
intersect(state, 1, rootMinPacked, rootMaxPacked);
}
/** Fast path: this is called when the query box fully encompasses all cells under this node. */
private void addAll(IntersectState state, int nodeID) throws IOException {
//System.out.println("R: addAll nodeID=" + nodeID);
if (nodeID >= leafNodeOffset) {
//System.out.println("R: leaf");
// Leaf node
state.in.seek(leafBlockFPs[nodeID-leafNodeOffset]);
// How many points are stored in this leaf cell:
int count = state.in.readVInt();
// TODO: especially for the 1D case, this was a decent speedup, because caller could know it should budget for around XXX docs:
//state.docs.grow(count);
int docID = 0;
for(int i=0;i<count;i++) {
docID += state.in.readVInt();
state.visitor.visit(docID);
}
} else {
addAll(state, 2*nodeID);
addAll(state, 2*nodeID+1);
}
}
private void intersect(IntersectState state,
int nodeID,
byte[] cellMinPacked, byte[] cellMaxPacked)
throws IOException {
//System.out.println("\nR: intersect nodeID=" + nodeID + " cellMin=" + BKDUtil.bytesToInt(cellMinPacked, 0) + " cellMax=" + BKDUtil.bytesToInt(cellMaxPacked, 0));
// Optimization: only check the visitor when the current cell does not fully contain the bbox. E.g. if the
// query is a small area around London, UK, most of the high nodes in the BKD tree as we recurse will fully
// contain the query, so we quickly recurse down until the nodes cross the query:
boolean cellContainsQuery = BKDUtil.contains(bytesPerDim,
cellMinPacked, cellMaxPacked,
state.minPacked, state.maxPacked);
//System.out.println("R: cellContainsQuery=" + cellContainsQuery);
if (cellContainsQuery == false) {
Relation r = state.visitor.compare(cellMinPacked, cellMaxPacked);
//System.out.println("R: relation=" + r);
if (r == Relation.QUERY_OUTSIDE_CELL) {
// This cell is fully outside of the query shape: stop recursing
return;
} else if (r == Relation.CELL_INSIDE_QUERY) {
// This cell is fully inside of the query shape: recursively add all points in this cell without filtering
addAll(state, nodeID);
return;
} else {
// The cell crosses the shape boundary, so we fall through and do full filtering
}
}
if (nodeID >= leafNodeOffset) {
// Leaf node; scan and filter all points in this block:
//System.out.println(" intersect leaf nodeID=" + nodeID + " vs leafNodeOffset=" + leafNodeOffset + " fp=" + leafBlockFPs[nodeID-leafNodeOffset]);
state.in.seek(leafBlockFPs[nodeID-leafNodeOffset]);
// How many points are stored in this leaf cell:
int count = state.in.readVInt();
// TODO: we could maybe pollute the IntersectVisitor API with a "grow" method if this maybe helps perf
// enough (it did before, esp. for the 1D case):
//state.docs.grow(count);
int docID = 0;
for(int i=0;i<count;i++) {
docID += state.in.readVInt();
state.scratchDocIDs[i] = docID;
}
// Again, this time reading values and checking with the visitor
for(int i=0;i<count;i++) {
state.in.readBytes(state.scratchPackedValue, 0, state.scratchPackedValue.length);
state.visitor.visit(state.scratchDocIDs[i], state.scratchPackedValue);
}
} else {
// Non-leaf node: recurse on the split left and right nodes
int address = nodeID * (bytesPerDim+1);
int splitDim = splitPackedValues[address] & 0xff;
assert splitDim < numDims;
// TODO: can we alloc & reuse this up front?
byte[] splitValue = new byte[bytesPerDim];
System.arraycopy(splitPackedValues, address+1, splitValue, 0, bytesPerDim);
// TODO: can we alloc & reuse this up front?
byte[] splitPackedValue = new byte[packedBytesLength];
if (BKDUtil.compare(bytesPerDim, state.minPacked, splitDim, splitValue, 0) <= 0) {
// The query bbox overlaps our left cell, so we must recurse:
System.arraycopy(state.maxPacked, 0, splitPackedValue, 0, packedBytesLength);
System.arraycopy(splitValue, 0, splitPackedValue, splitDim*bytesPerDim, bytesPerDim);
intersect(state,
2*nodeID,
cellMinPacked, splitPackedValue);
}
if (BKDUtil.compare(bytesPerDim, state.maxPacked, splitDim, splitValue, 0) >= 0) {
// The query bbox overlaps our right cell, so we must recurse:
System.arraycopy(state.minPacked, 0, splitPackedValue, 0, packedBytesLength);
System.arraycopy(splitValue, 0, splitPackedValue, splitDim*bytesPerDim, bytesPerDim);
intersect(state,
2*nodeID+1,
splitPackedValue, cellMaxPacked);
}
}
}
@Override
public long ramBytesUsed() {
return splitPackedValues.length +
leafBlockFPs.length * RamUsageEstimator.NUM_BYTES_LONG;
}
}


@ -0,0 +1,131 @@
package org.apache.lucene.util.bkd;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.math.BigInteger;
import java.util.Arrays;
/** Utility methods to handle N-dimensional packed byte[] as if they were numbers! */
final class BKDUtil {
private BKDUtil() {
// No instance
}
/** result = a - b, where a >= b */
public static void subtract(int bytesPerDim, int dim, byte[] a, byte[] b, byte[] result) {
int start = dim * bytesPerDim;
int end = start + bytesPerDim;
int borrow = 0;
for(int i=end-1;i>=start;i--) {
int diff = (a[i]&0xff) - (b[i]&0xff) - borrow;
if (diff < 0) {
diff += 256;
borrow = 1;
} else {
borrow = 0;
}
result[i-start] = (byte) diff;
}
if (borrow != 0) {
throw new IllegalArgumentException("a < b?");
}
}
/** Returns positive int if a > b, negative int if a < b and 0 if a == b */
public static int compare(int bytesPerDim, byte[] a, int aIndex, byte[] b, int bIndex) {
for(int i=0;i<bytesPerDim;i++) {
int cmp = (a[aIndex*bytesPerDim+i]&0xff) - (b[bIndex*bytesPerDim+i]&0xff);
if (cmp != 0) {
return cmp;
}
}
return 0;
}
/** Returns true if N-dim rect A contains N-dim rect B */
public static boolean contains(int bytesPerDim,
byte[] minPackedA, byte[] maxPackedA,
byte[] minPackedB, byte[] maxPackedB) {
int dims = minPackedA.length / bytesPerDim;
for(int dim=0;dim<dims;dim++) {
if (compare(bytesPerDim, minPackedA, dim, minPackedB, dim) > 0) {
return false;
}
if (compare(bytesPerDim, maxPackedA, dim, maxPackedB, dim) < 0) {
return false;
}
}
return true;
}
static void intToBytes(int x, byte[] dest, int index) {
// Flip the sign bit, so negative ints sort before positive ints correctly:
x ^= 0x80000000;
for(int i=0;i<4;i++) {
dest[4*index+i] = (byte) (x >> 24-i*8);
}
}
static int bytesToInt(byte[] src, int index) {
int x = 0;
for(int i=0;i<4;i++) {
x |= (src[4*index+i] & 0xff) << (24-i*8);
}
// Re-flip the sign bit to restore the original value:
return x ^ 0x80000000;
}
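A quick check (illustrative, not part of the patch) of why the sign-bit flip makes unsigned byte-wise comparison agree with signed int order:
byte[] a = new byte[4], b = new byte[4];
BKDUtil.intToBytes(-1, a, 0);                 // encodes as 0x7f 0xff 0xff 0xff
BKDUtil.intToBytes(1, b, 0);                  // encodes as 0x80 0x00 0x00 0x01
assert BKDUtil.compare(4, a, 0, b, 0) < 0;    // -1 now sorts before 1 byte-wise
assert BKDUtil.bytesToInt(a, 0) == -1;        // and the encoding round-trips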
static void sortableBigIntBytes(byte[] bytes) {
bytes[0] ^= 0x80;
for(int i=1;i<bytes.length;i++) {
bytes[i] ^= 0;
}
}
static void bigIntToBytes(BigInteger bigInt, byte[] result, int dim, int numBytesPerDim) {
byte[] bigIntBytes = bigInt.toByteArray();
byte[] fullBigIntBytes;
if (bigIntBytes.length < numBytesPerDim) {
fullBigIntBytes = new byte[numBytesPerDim];
System.arraycopy(bigIntBytes, 0, fullBigIntBytes, numBytesPerDim-bigIntBytes.length, bigIntBytes.length);
if ((bigIntBytes[0] & 0x80) != 0) {
// sign extend
Arrays.fill(fullBigIntBytes, 0, numBytesPerDim-bigIntBytes.length, (byte) 0xff);
}
} else {
assert bigIntBytes.length == numBytesPerDim;
fullBigIntBytes = bigIntBytes;
}
sortableBigIntBytes(fullBigIntBytes);
System.arraycopy(fullBigIntBytes, 0, result, dim * numBytesPerDim, numBytesPerDim);
assert bytesToBigInt(result, dim, numBytesPerDim).equals(bigInt): "bigInt=" + bigInt + " converted=" + bytesToBigInt(result, dim, numBytesPerDim);
}
static BigInteger bytesToBigInt(byte[] bytes, int dim, int numBytesPerDim) {
byte[] bigIntBytes = new byte[numBytesPerDim];
System.arraycopy(bytes, dim*numBytesPerDim, bigIntBytes, 0, numBytesPerDim);
sortableBigIntBytes(bigIntBytes);
return new BigInteger(bigIntBytes);
}
}


@ -0,0 +1,743 @@
package org.apache.lucene.util.bkd;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.TrackingDirectoryWrapper;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InPlaceMergeSorter;
import org.apache.lucene.util.LongBitSet;
import org.apache.lucene.util.OfflineSorter.ByteSequencesWriter;
import org.apache.lucene.util.OfflineSorter;
import org.apache.lucene.util.RamUsageEstimator;
// TODO
// - the compression is somewhat stupid now (delta vInt for 1024 docIDs, no compression for the byte[] values even though they have high locality)
// - allow variable length byte[] (across docs and dims), but this is quite a bit more hairy
// - we could also index "auto-prefix terms" here, and use better compression, and maybe only use for the "fully contained" case so we'd
// only index docIDs
// - the index could be efficiently encoded as an FST, so we don't have wasteful
// (monotonic) long[] leafBlockFPs; or we could use MonotonicLongValues ... but then
// the index is already plenty small: 60M OSM points --> 1.1 MB with 128 points
// per leaf, and you can reduce that by putting more points per leaf
// - we could use threads while building; the higher nodes are very parallelizable
/** Recursively builds a block KD-tree to assign all incoming points in N-dim space to smaller
* and smaller N-dim rectangles (cells) until the number of points in a given
* rectangle is &lt;= <code>maxPointsInLeafNode</code>. The tree is
* fully balanced, which means the leaf nodes will have between 50% and 100% of
* the requested <code>maxPointsInLeafNode</code>. Values that fall exactly
* on a cell boundary may be in either cell.
*
* <p>The number of dimensions can be 1 to 15 ({@link #MAX_DIMS}), but every byte[] value is fixed length.
*
* <p>
* See <a href="https://www.cs.duke.edu/~pankaj/publications/papers/bkd-sstd.pdf">this paper</a> for details.
*
* <p>This consumes heap during writing: it allocates a <code>LongBitSet(numPoints)</code>,
* and then uses up to the specified {@code maxMBSortInHeap} heap space for writing.
*
* <p>
* <b>NOTE</b>: This can write at most Integer.MAX_VALUE * <code>maxPointsInLeafNode</code> total points.
*
* @lucene.experimental */
public final class BKDWriter implements Closeable {
static final String CODEC_NAME = "BKD";
static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START;
/** How many bytes each doc takes in the fixed-width offline format */
private final int bytesPerDoc;
public static final int DEFAULT_MAX_POINTS_IN_LEAF_NODE = 1024;
public static final float DEFAULT_MAX_MB_SORT_IN_HEAP = 16.0f;
/** Maximum number of dimensions */
public static final int MAX_DIMS = 15;
/** How many dimensions we are indexing */
final int numDims;
/** How many bytes each value in each dimension takes. */
final int bytesPerDim;
/** numDims * bytesPerDim */
final int packedBytesLength;
final TrackingDirectoryWrapper tempDir;
final String tempFileNamePrefix;
final byte[] scratchDiff;
final byte[] scratchPackedValue;
final byte[] scratch1;
final byte[] scratch2;
private OfflinePointWriter offlinePointWriter;
private HeapPointWriter heapPointWriter;
private IndexOutput tempInput;
private final int maxPointsInLeafNode;
private final int maxPointsSortInHeap;
private long pointCount;
public BKDWriter(Directory tempDir, String tempFileNamePrefix, int numDims, int bytesPerDim) throws IOException {
this(tempDir, tempFileNamePrefix, numDims, bytesPerDim, DEFAULT_MAX_POINTS_IN_LEAF_NODE, DEFAULT_MAX_MB_SORT_IN_HEAP);
}
public BKDWriter(Directory tempDir, String tempFileNamePrefix, int numDims, int bytesPerDim, int maxPointsInLeafNode, double maxMBSortInHeap) throws IOException {
verifyParams(numDims, maxPointsInLeafNode, maxMBSortInHeap);
// We use tracking dir to deal with removing files on exception, so each place that
// creates temp files doesn't need crazy try/finally/success logic:
this.tempDir = new TrackingDirectoryWrapper(tempDir);
this.tempFileNamePrefix = tempFileNamePrefix;
this.maxPointsInLeafNode = maxPointsInLeafNode;
this.numDims = numDims;
this.bytesPerDim = bytesPerDim;
packedBytesLength = numDims * bytesPerDim;
scratchDiff = new byte[bytesPerDim];
scratchPackedValue = new byte[packedBytesLength];
scratch1 = new byte[packedBytesLength];
scratch2 = new byte[packedBytesLength];
// dimensional values (numDims * bytesPerDim) + ord (long) + docID (int)
bytesPerDoc = packedBytesLength + RamUsageEstimator.NUM_BYTES_LONG + RamUsageEstimator.NUM_BYTES_INT;
// As we recurse, we compute temporary partitions of the data, halving the
// number of points at each recursion. Once there are few enough points,
// we can switch to sorting in heap instead of offline (on disk). At any
// time in the recursion, we hold the number of points at that level, plus
// all recursive halves (i.e. 16 + 8 + 4 + 2) so the memory usage is 2X
// what that level would consume, so we multiply by 0.5 to convert from
// bytes to points here. Each dimension has its own sorted partition, so
// we must divide by numDims as well.
maxPointsSortInHeap = (int) (0.5 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDims));
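// Worked example (illustrative numbers): numDims=2, bytesPerDim=4 gives
// bytesPerDoc = 8 + 8 + 4 = 20, so maxMBSortInHeap=16 yields
// maxPointsSortInHeap = (int) (0.5 * 16*1024*1024 / (20*2)) = 209715 points.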
// Finally, we must be able to hold at least the leaf node in heap during build:
if (maxPointsSortInHeap < maxPointsInLeafNode) {
throw new IllegalArgumentException("maxMBSortInHeap=" + maxMBSortInHeap + " only allows for maxPointsSortInHeap=" + maxPointsSortInHeap + ", but this is less than maxPointsInLeafNode=" + maxPointsInLeafNode + "; either increase maxMBSortInHeap or decrease maxPointsInLeafNode");
}
// We write first maxPointsSortInHeap in heap, then cutover to offline for additional points:
heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength);
}
public static void verifyParams(int numDims, int maxPointsInLeafNode, double maxMBSortInHeap) {
// We encode dim in a single byte in the splitPackedValues, but we only expose 4 bits for it now, in case we want to use
// the remaining 4 bits for another purpose later
if (numDims < 1 || numDims > MAX_DIMS) {
throw new IllegalArgumentException("numDims must be 1 .. " + MAX_DIMS + " (got: " + numDims + ")");
}
if (maxPointsInLeafNode <= 0) {
throw new IllegalArgumentException("maxPointsInLeafNode must be > 0; got " + maxPointsInLeafNode);
}
if (maxPointsInLeafNode > ArrayUtil.MAX_ARRAY_LENGTH) {
throw new IllegalArgumentException("maxPointsInLeafNode must be <= ArrayUtil.MAX_ARRAY_LENGTH (= " + ArrayUtil.MAX_ARRAY_LENGTH + "); got " + maxPointsInLeafNode);
}
if (maxMBSortInHeap < 0.0) {
throw new IllegalArgumentException("maxMBSortInHeap must be >= 0.0 (got: " + maxMBSortInHeap + ")");
}
}
/** If the current segment has too many points then we switch over to temp files / offline sort. */
private void switchToOffline() throws IOException {
// For each .add we just append to this input file, then in .finish we sort this input and recursively build the tree:
offlinePointWriter = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength);
PointReader reader = heapPointWriter.getReader(0);
for(int i=0;i<pointCount;i++) {
boolean hasNext = reader.next();
assert hasNext;
offlinePointWriter.append(reader.packedValue(), i, heapPointWriter.docIDs[i]);
}
heapPointWriter = null;
tempInput = offlinePointWriter.out;
}
public void add(byte[] packedValue, int docID) throws IOException {
if (packedValue.length != packedBytesLength) {
throw new IllegalArgumentException("packedValue should be length=" + packedBytesLength + " (got: " + packedValue.length + ")");
}
if (pointCount >= maxPointsSortInHeap) {
if (offlinePointWriter == null) {
switchToOffline();
}
offlinePointWriter.append(packedValue, pointCount, docID);
} else {
// Not too many points added yet, continue using heap:
heapPointWriter.append(packedValue, pointCount, docID);
}
pointCount++;
}
// TODO: if we fixed each partition step to just record the file offset at the "split point", we could probably handle variable length
// encoding and not have our own ByteSequencesReader/Writer
/** If dim=-1 we sort by docID, else by that dim. */
private void sortHeapPointWriter(final HeapPointWriter writer, int start, int length, int dim) {
assert pointCount < Integer.MAX_VALUE;
// All buffered points are still in heap; just do in-place sort:
new InPlaceMergeSorter() {
@Override
protected void swap(int i, int j) {
int docID = writer.docIDs[i];
writer.docIDs[i] = writer.docIDs[j];
writer.docIDs[j] = docID;
long ord = writer.ords[i];
writer.ords[i] = writer.ords[j];
writer.ords[j] = ord;
// scratch1 = values[i]
writer.readPackedValue(i, scratch1);
// scratch2 = values[j]
writer.readPackedValue(j, scratch2);
// values[i] = scratch2
writer.writePackedValue(i, scratch2);
// values[j] = scratch1
writer.writePackedValue(j, scratch1);
}
@Override
protected int compare(int i, int j) {
if (dim != -1) {
writer.readPackedValue(i, scratch1);
writer.readPackedValue(j, scratch2);
int cmp = BKDUtil.compare(bytesPerDim, scratch1, dim, scratch2, dim);
if (cmp != 0) {
return cmp;
}
}
// Tie-break
int cmp = Integer.compare(writer.docIDs[i], writer.docIDs[j]);
if (cmp != 0) {
return cmp;
}
return Long.compare(writer.ords[i], writer.ords[j]);
}
}.sort(start, start+length);
}
private PointWriter sort(int dim) throws IOException {
if (heapPointWriter != null) {
assert tempInput == null;
// We never spilled the incoming points to disk, so now we sort in heap:
HeapPointWriter sorted;
if (dim == 0) {
// First dim can re-use the current heap writer
sorted = heapPointWriter;
} else {
// Subsequent dims need a private copy
sorted = new HeapPointWriter((int) pointCount, (int) pointCount, packedBytesLength);
sorted.copyFrom(heapPointWriter);
}
sortHeapPointWriter(sorted, 0, (int) pointCount, dim);
sorted.close();
return sorted;
} else {
// Offline sort:
assert tempInput != null;
final ByteArrayDataInput reader = new ByteArrayDataInput();
Comparator<BytesRef> cmp = new Comparator<BytesRef>() {
private final ByteArrayDataInput readerB = new ByteArrayDataInput();
@Override
public int compare(BytesRef a, BytesRef b) {
reader.reset(a.bytes, a.offset, a.length);
reader.readBytes(scratch1, 0, scratch1.length);
final int docIDA = reader.readVInt();
final long ordA = reader.readVLong();
reader.reset(b.bytes, b.offset, b.length);
reader.readBytes(scratch2, 0, scratch2.length);
final int docIDB = reader.readVInt();
final long ordB = reader.readVLong();
int cmp = BKDUtil.compare(bytesPerDim, scratch1, dim, scratch2, dim);
if (cmp != 0) {
return cmp;
}
// Tie-break
cmp = Integer.compare(docIDA, docIDB);
if (cmp != 0) {
return cmp;
}
return Long.compare(ordA, ordB);
}
};
// TODO: this is a somewhat sneaky way to get the final OfflinePointWriter from OfflineSorter:
IndexOutput[] lastWriter = new IndexOutput[1];
OfflineSorter sorter = new OfflineSorter(tempDir, tempFileNamePrefix, cmp) {
/** We write/read a fixed-byte-width file that {@link OfflinePointReader} can read. */
@Override
protected ByteSequencesWriter getWriter(IndexOutput out) {
lastWriter[0] = out;
return new ByteSequencesWriter(out) {
@Override
public void write(byte[] bytes, int off, int len) throws IOException {
if (len != bytesPerDoc) {
throw new IllegalArgumentException("len=" + len + " bytesPerDoc=" + bytesPerDoc);
}
out.writeBytes(bytes, off, len);
}
};
}
/** We write/read a fixed-byte-width file that {@link OfflinePointReader} can read. */
@Override
protected ByteSequencesReader getReader(IndexInput in) throws IOException {
return new ByteSequencesReader(in) {
@Override
public boolean read(BytesRefBuilder ref) throws IOException {
ref.grow(bytesPerDoc);
try {
in.readBytes(ref.bytes(), 0, bytesPerDoc);
} catch (EOFException eofe) {
return false;
}
ref.setLength(bytesPerDoc);
return true;
}
};
}
};
sorter.sort(tempInput.getName());
assert lastWriter[0] != null;
return new OfflinePointWriter(tempDir, lastWriter[0], packedBytesLength, pointCount);
}
}
/** Writes the BKD tree to the provided {@link IndexOutput} and returns the file offset where the index was written. */
public long finish(IndexOutput out) throws IOException {
//System.out.println("\nBKDTreeWriter.finish pointCount=" + pointCount + " out=" + out + " heapWriter=" + heapWriter);
// TODO: specialize the 1D case? it's much faster at indexing time (no partitioning while recursing...)
if (offlinePointWriter != null) {
offlinePointWriter.close();
}
LongBitSet ordBitSet = new LongBitSet(pointCount);
long countPerLeaf = pointCount;
long innerNodeCount = 1;
while (countPerLeaf > maxPointsInLeafNode) {
countPerLeaf = (countPerLeaf+1)/2;
innerNodeCount *= 2;
}
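// For example (illustrative numbers): pointCount=1000 with maxPointsInLeafNode=128
// halves countPerLeaf 1000 -> 500 -> 250 -> 125 and ends with innerNodeCount=8,
// so numLeaves below becomes 8, each leaf holding ~125 points.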
//System.out.println("innerNodeCount=" + innerNodeCount);
if (1+2*innerNodeCount >= Integer.MAX_VALUE) {
throw new IllegalStateException("too many nodes; increase maxPointsInLeafNode (currently " + maxPointsInLeafNode + ") and reindex");
}
innerNodeCount--;
int numLeaves = (int) (innerNodeCount+1);
// NOTE: we could save the 1+ here, to use a bit less heap at search time, but then we'd need a somewhat costly check at each
// step of the recursion to recompute the split dim:
// Indexed by nodeID, but first (root) nodeID is 1. We do 1+ because the lead byte at each recursion says which dim we split on.
byte[] splitPackedValues = new byte[Math.toIntExact(numLeaves*(1+bytesPerDim))];
// +1 because leaf count is power of 2 (e.g. 8), and innerNodeCount is power of 2 minus 1 (e.g. 7)
long[] leafBlockFPs = new long[numLeaves];
// Make sure the math above "worked":
assert pointCount / numLeaves <= maxPointsInLeafNode: "pointCount=" + pointCount + " numLeaves=" + numLeaves + " maxPointsInLeafNode=" + maxPointsInLeafNode;
// Sort all docs once by each dimension:
PathSlice[] sortedPointWriters = new PathSlice[numDims];
byte[] minPacked = new byte[packedBytesLength];
byte[] maxPacked = new byte[packedBytesLength];
Arrays.fill(maxPacked, (byte) 0xff);
boolean success = false;
try {
for(int dim=0;dim<numDims;dim++) {
sortedPointWriters[dim] = new PathSlice(sort(dim), 0, pointCount);
}
if (tempInput != null) {
tempDir.deleteFile(tempInput.getName());
tempInput = null;
} else {
assert heapPointWriter != null;
heapPointWriter = null;
}
build(1, numLeaves, sortedPointWriters,
ordBitSet, out,
minPacked, maxPacked,
splitPackedValues,
leafBlockFPs);
for(PathSlice slice : sortedPointWriters) {
slice.writer.destroy();
}
// If no exception, we should have cleaned everything up:
assert tempDir.getCreatedFiles().isEmpty();
success = true;
} finally {
if (success == false) {
IOUtils.deleteFilesIgnoringExceptions(tempDir, tempDir.getCreatedFiles());
}
}
//System.out.println("Total nodes: " + innerNodeCount);
// Write index:
long indexFP = out.getFilePointer();
CodecUtil.writeHeader(out, CODEC_NAME, VERSION_CURRENT);
out.writeVInt(numDims);
out.writeVInt(maxPointsInLeafNode);
out.writeVInt(bytesPerDim);
out.writeVInt(numLeaves);
// NOTE: splitPackedValues[0] is unused, because nodeID is 1-based:
out.writeBytes(splitPackedValues, 0, splitPackedValues.length);
for (int i=0;i<leafBlockFPs.length;i++) {
out.writeVLong(leafBlockFPs[i]);
}
return indexFP;
}
@Override
public void close() throws IOException {
if (tempInput != null) {
// NOTE: this should only happen on exception, e.g. caller calls close w/o calling finish:
try {
tempInput.close();
} finally {
tempDir.deleteFile(tempInput.getName());
tempInput = null;
}
}
}
/** Sliced reference to points in an OfflineSorter.ByteSequencesWriter file. */
private static final class PathSlice {
final PointWriter writer;
final long start;
final long count;
public PathSlice(PointWriter writer, long start, long count) {
this.writer = writer;
this.start = start;
this.count = count;
}
@Override
public String toString() {
return "PathSlice(start=" + start + " count=" + count + " writer=" + writer + ")";
}
}
/** Marks bits for the ords (points) that belong in the right subtree (those docs that have values >= the splitValue). */
private byte[] markRightTree(long rightCount, int splitDim, PathSlice source, LongBitSet ordBitSet) throws IOException {
// Now we mark ords that fall into the right half, so we can partition on all other dims that are not the split dim:
assert ordBitSet.cardinality() == 0: "cardinality=" + ordBitSet.cardinality();
// Read the split value, then mark all ords in the right tree (larger than the split value):
try (PointReader reader = source.writer.getReader(source.start + source.count - rightCount)) {
boolean result = reader.next();
assert result;
System.arraycopy(reader.packedValue(), splitDim*bytesPerDim, scratch1, 0, bytesPerDim);
ordBitSet.set(reader.ord());
// Start at 1 because we already did the first value above (so we could keep the split value):
for(int i=1;i<rightCount;i++) {
result = reader.next();
assert result;
ordBitSet.set(reader.ord());
}
assert rightCount == ordBitSet.cardinality(): "rightCount=" + rightCount + " cardinality=" + ordBitSet.cardinality();
}
return scratch1;
}
/** Called only in assert */
private boolean valueInBounds(byte[] packedValue, byte[] minPackedValue, byte[] maxPackedValue) {
for(int dim=0;dim<numDims;dim++) {
if (BKDUtil.compare(bytesPerDim, packedValue, dim, minPackedValue, dim) < 0) {
return false;
}
if (BKDUtil.compare(bytesPerDim, packedValue, dim, maxPackedValue, dim) > 0) {
return false;
}
}
return true;
}
// TODO: make this protected when we want to subclass to play with different splitting criteria
private int split(byte[] minPackedValue, byte[] maxPackedValue) {
// Find which dim has the largest span so we can split on it:
int splitDim = -1;
for(int dim=0;dim<numDims;dim++) {
BKDUtil.subtract(bytesPerDim, dim, maxPackedValue, minPackedValue, scratchDiff);
if (splitDim == -1 || BKDUtil.compare(bytesPerDim, scratchDiff, 0, scratch1, 0) > 0) {
System.arraycopy(scratchDiff, 0, scratch1, 0, bytesPerDim);
splitDim = dim;
}
}
return splitDim;
}
/** Pulls a partition back into heap once the point count is low enough
* while recursing (also used when a leaf's points were spilled offline). */
private PathSlice switchToHeap(PathSlice source) throws IOException {
int count = Math.toIntExact(source.count);
try (
PointWriter writer = new HeapPointWriter(count, count, packedBytesLength);
PointReader reader = source.writer.getReader(source.start);
) {
for(int i=0;i<count;i++) {
boolean hasNext = reader.next();
assert hasNext;
writer.append(reader.packedValue(), reader.ord(), reader.docID());
}
return new PathSlice(writer, 0, count);
}
}
/** The array (sized numDims) of PathSlice describes the cell we have currently recursed to. */
private void build(int nodeID, int leafNodeOffset,
PathSlice[] slices,
LongBitSet ordBitSet,
IndexOutput out,
byte[] minPackedValue, byte[] maxPackedValue,
byte[] splitPackedValues,
long[] leafBlockFPs) throws IOException {
for(PathSlice slice : slices) {
assert slice.count == slices[0].count;
}
if (numDims == 1 && slices[0].writer instanceof OfflinePointWriter && slices[0].count <= maxPointsSortInHeap) {
// Special case for 1D, to cutover to heap once we recurse deeply enough:
slices[0] = switchToHeap(slices[0]);
}
if (nodeID >= leafNodeOffset) {
// Leaf node: write block
PathSlice source = slices[0];
if (source.writer instanceof HeapPointWriter == false) {
// Adversarial cases can cause this, e.g. very lopsided data, all equal points
source = switchToHeap(source);
}
// We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point:
HeapPointWriter heapSource = (HeapPointWriter) source.writer;
// Sort by docID in the leaf so we can delta-vInt encode:
sortHeapPointWriter(heapSource, Math.toIntExact(source.start), Math.toIntExact(source.count), -1);
int lastDocID = 0;
// Save the block file pointer:
leafBlockFPs[nodeID - leafNodeOffset] = out.getFilePointer();
out.writeVInt(Math.toIntExact(source.count));
// Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o
// loading the values:
for (int i=0;i<source.count;i++) {
int docID = heapSource.docIDs[Math.toIntExact(source.start + i)];
out.writeVInt(docID - lastDocID);
lastDocID = docID;
}
// TODO: we should delta compress / only write suffix bytes, like terms dict (the values will all be "close together" since we are at
// a leaf cell):
// Now write the full values:
for (int i=0;i<source.count;i++) {
// TODO: we could do bulk copying here, avoiding the intermediate copy:
heapSource.readPackedValue(Math.toIntExact(source.start + i), scratchPackedValue);
// Make sure this value does in fact fall within this leaf cell:
assert valueInBounds(scratchPackedValue, minPackedValue, maxPackedValue);
out.writeBytes(scratchPackedValue, 0, scratchPackedValue.length);
}
} else {
// Inner node: partition/recurse
int splitDim = split(minPackedValue, maxPackedValue);
PathSlice source = slices[splitDim];
assert nodeID < splitPackedValues.length: "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length;
// How many points will be in the left tree:
long rightCount = source.count / 2;
long leftCount = source.count - rightCount;
byte[] splitValue = markRightTree(rightCount, splitDim, source, ordBitSet);
int address = nodeID * (1+bytesPerDim);
splitPackedValues[address] = (byte) splitDim;
System.arraycopy(splitValue, 0, splitPackedValues, address + 1, bytesPerDim);
// Partition all PathSlice that are not the split dim into sorted left and right sets, so we can recurse:
PathSlice[] leftSlices = new PathSlice[numDims];
PathSlice[] rightSlices = new PathSlice[numDims];
byte[] minSplitPackedValue = new byte[packedBytesLength];
System.arraycopy(minPackedValue, 0, minSplitPackedValue, 0, packedBytesLength);
byte[] maxSplitPackedValue = new byte[packedBytesLength];
System.arraycopy(maxPackedValue, 0, maxSplitPackedValue, 0, packedBytesLength);
for(int dim=0;dim<numDims;dim++) {
if (dim == splitDim) {
// No need to partition on this dim since it's a simple slice of the incoming already sorted slice.
leftSlices[dim] = new PathSlice(source.writer, source.start, leftCount);
rightSlices[dim] = new PathSlice(source.writer, source.start + leftCount, rightCount);
System.arraycopy(splitValue, 0, minSplitPackedValue, dim*bytesPerDim, bytesPerDim);
System.arraycopy(splitValue, 0, maxSplitPackedValue, dim*bytesPerDim, bytesPerDim);
continue;
}
try (PointWriter leftPointWriter = getPointWriter(leftCount);
PointWriter rightPointWriter = getPointWriter(source.count - leftCount);
PointReader reader = slices[dim].writer.getReader(slices[dim].start);) {
// Partition this source according to how the splitDim split the values:
int nextRightCount = 0;
for (int i=0;i<source.count;i++) {
boolean result = reader.next();
assert result;
byte[] packedValue = reader.packedValue();
long ord = reader.ord();
int docID = reader.docID();
if (ordBitSet.get(ord)) {
rightPointWriter.append(packedValue, ord, docID);
nextRightCount++;
} else {
leftPointWriter.append(packedValue, ord, docID);
}
}
leftSlices[dim] = new PathSlice(leftPointWriter, 0, leftCount);
rightSlices[dim] = new PathSlice(rightPointWriter, 0, rightCount);
assert rightCount == nextRightCount: "rightCount=" + rightCount + " nextRightCount=" + nextRightCount;
}
}
ordBitSet.clear(0, pointCount);
// Recurse on left tree:
build(2*nodeID, leafNodeOffset, leftSlices,
ordBitSet, out,
minPackedValue, maxSplitPackedValue,
splitPackedValues, leafBlockFPs);
for(int dim=0;dim<numDims;dim++) {
// Don't destroy the dim we split on because we just re-used what our caller above gave us for that dim:
if (dim != splitDim) {
leftSlices[dim].writer.destroy();
}
}
// TODO: we could "tail recurse" here? have our parent discard its refs as we recurse right?
// Recurse on right tree:
build(2*nodeID+1, leafNodeOffset, rightSlices,
ordBitSet, out,
minSplitPackedValue, maxPackedValue,
splitPackedValues, leafBlockFPs);
for(int dim=0;dim<numDims;dim++) {
// Don't destroy the dim we split on because we just re-used what our caller above gave us for that dim:
if (dim != splitDim) {
rightSlices[dim].writer.destroy();
}
}
}
}
PointWriter getPointWriter(long count) throws IOException {
if (count <= maxPointsSortInHeap) {
int size = Math.toIntExact(count);
return new HeapPointWriter(size, size, packedBytesLength);
} else {
return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength);
}
}
}
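An end-to-end sketch of the new API (illustrative only: dir is any Directory, the "points.bkd" file name and the 1-dim/4-byte int encoding via the package-private BKDUtil are assumptions, and rangeVisitor is the hypothetical visitor sketched above for BKDReader):
List<Integer> hits = new ArrayList<>();
BKDWriter writer = new BKDWriter(dir, "bkd", 1, 4);   // 1 dimension, 4 bytes per dimension
byte[] scratch = new byte[4];
for (int docID = 0; docID < 100; docID++) {
  BKDUtil.intToBytes(docID, scratch, 0);              // value == docID, for simplicity
  writer.add(scratch, docID);
}
long indexFP;
try (IndexOutput out = dir.createOutput("points.bkd", IOContext.DEFAULT)) {
  indexFP = writer.finish(out);                       // where the index starts in the file
}
writer.close();
try (IndexInput in = dir.openInput("points.bkd", IOContext.DEFAULT)) {
  in.seek(indexFP);                                   // BKDReader requires the pre-seek
  BKDReader reader = new BKDReader(in);
  reader.intersect(rangeVisitor(10, 20, hits));       // hits now holds docIDs 10..20
}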


@ -0,0 +1,85 @@
package org.apache.lucene.util.bkd;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.util.List;
final class HeapPointReader implements PointReader {
private int curRead;
final List<byte[]> blocks;
final int valuesPerBlock;
final int packedBytesLength;
final long[] ords;
final int[] docIDs;
final int end;
final byte[] scratch;
HeapPointReader(List<byte[]> blocks, int valuesPerBlock, int packedBytesLength, long[] ords, int[] docIDs, int start, int end) {
this.blocks = blocks;
this.valuesPerBlock = valuesPerBlock;
this.ords = ords;
this.docIDs = docIDs;
curRead = start-1;
this.end = end;
this.packedBytesLength = packedBytesLength;
scratch = new byte[packedBytesLength];
}
void writePackedValue(int index, byte[] bytes) {
int block = index / valuesPerBlock;
int blockIndex = index % valuesPerBlock;
while (blocks.size() <= block) {
blocks.add(new byte[valuesPerBlock*packedBytesLength]);
}
System.arraycopy(bytes, 0, blocks.get(block), blockIndex * packedBytesLength, packedBytesLength);
}
void readPackedValue(int index, byte[] bytes) {
int block = index / valuesPerBlock;
int blockIndex = index % valuesPerBlock;
System.arraycopy(blocks.get(block), blockIndex * packedBytesLength, bytes, 0, packedBytesLength);
}
@Override
public boolean next() {
curRead++;
return curRead < end;
}
@Override
public byte[] packedValue() {
readPackedValue(curRead, scratch);
return scratch;
}
@Override
public int docID() {
return docIDs[curRead];
}
@Override
public long ord() {
return ords[curRead];
}
@Override
public void close() {
}
}


@ -0,0 +1,126 @@
package org.apache.lucene.util.bkd;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.RamUsageEstimator;
final class HeapPointWriter implements PointWriter {
int[] docIDs;
long[] ords;
private int nextWrite;
private boolean closed;
final int maxSize;
final int valuesPerBlock;
final int packedBytesLength;
// NOTE: can't use ByteBlockPool because we need random-write access when sorting in heap
final List<byte[]> blocks = new ArrayList<>();
public HeapPointWriter(int initSize, int maxSize, int packedBytesLength) {
docIDs = new int[initSize];
ords = new long[initSize];
this.maxSize = maxSize;
this.packedBytesLength = packedBytesLength;
// 4K per page, unless each value is > 4K:
valuesPerBlock = Math.max(1, 4096/packedBytesLength);
}
public void copyFrom(HeapPointWriter other) {
if (docIDs.length < other.nextWrite) {
throw new IllegalStateException("docIDs.length=" + docIDs.length + " other.nextWrite=" + other.nextWrite);
}
System.arraycopy(other.docIDs, 0, docIDs, 0, other.nextWrite);
System.arraycopy(other.ords, 0, ords, 0, other.nextWrite);
for(byte[] block : other.blocks) {
blocks.add(block.clone());
}
nextWrite = other.nextWrite;
}
void readPackedValue(int index, byte[] bytes) {
assert bytes.length == packedBytesLength;
int block = index / valuesPerBlock;
int blockIndex = index % valuesPerBlock;
System.arraycopy(blocks.get(block), blockIndex * packedBytesLength, bytes, 0, packedBytesLength);
}
void writePackedValue(int index, byte[] bytes) {
assert bytes.length == packedBytesLength;
int block = index / valuesPerBlock;
int blockIndex = index % valuesPerBlock;
//System.out.println("writePackedValue: index=" + index + " bytes.length=" + bytes.length + " block=" + block + " blockIndex=" + blockIndex + " valuesPerBlock=" + valuesPerBlock);
while (blocks.size() <= block) {
// If this is the last block, only allocate as large as necessary for maxSize:
int valuesInBlock = Math.min(valuesPerBlock, maxSize - (blocks.size() * valuesPerBlock));
blocks.add(new byte[valuesInBlock*packedBytesLength]);
}
System.arraycopy(bytes, 0, blocks.get(block), blockIndex * packedBytesLength, packedBytesLength);
}
private int[] growExact(int[] arr, int size) {
assert size > arr.length;
int[] newArr = new int[size];
System.arraycopy(arr, 0, newArr, 0, arr.length);
return newArr;
}
private long[] growExact(long[] arr, int size) {
assert size > arr.length;
long[] newArr = new long[size];
System.arraycopy(arr, 0, newArr, 0, arr.length);
return newArr;
}
@Override
public void append(byte[] packedValue, long ord, int docID) {
assert closed == false;
assert packedValue.length == packedBytesLength;
if (ords.length == nextWrite) {
int nextSize = Math.min(maxSize, ArrayUtil.oversize(nextWrite+1, RamUsageEstimator.NUM_BYTES_INT));
assert nextSize > nextWrite: "nextSize=" + nextSize + " vs nextWrite=" + nextWrite;
ords = growExact(ords, nextSize);
docIDs = growExact(docIDs, nextSize);
}
writePackedValue(nextWrite, packedValue);
ords[nextWrite] = ord;
docIDs[nextWrite] = docID;
nextWrite++;
}
@Override
public PointReader getReader(long start) {
return new HeapPointReader(blocks, valuesPerBlock, packedBytesLength, ords, docIDs, (int) start, nextWrite);
}
@Override
public void close() {
closed = true;
}
@Override
public void destroy() {
}
@Override
public String toString() {
return "HeapPointWriter(count=" + nextWrite + " alloc=" + ords.length + ")";
}
}


@ -0,0 +1,90 @@
package org.apache.lucene.util.bkd;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.EOFException;
import java.io.IOException;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.RamUsageEstimator;
/** Reads points from disk in a fixed-width format, previously written with {@link OfflinePointWriter}. */
final class OfflinePointReader implements PointReader {
long countLeft;
private final IndexInput in;
private final byte[] packedValue;
private long ord;
private int docID;
final int bytesPerDoc;
OfflinePointReader(Directory tempDir, String tempFileName, int packedBytesLength, long start, long length) throws IOException {
this(tempDir.openInput(tempFileName, IOContext.READONCE), packedBytesLength, start, length);
}
private OfflinePointReader(IndexInput in, int packedBytesLength, long start, long length) throws IOException {
this.in = in;
bytesPerDoc = packedBytesLength + RamUsageEstimator.NUM_BYTES_LONG + RamUsageEstimator.NUM_BYTES_INT;
long seekFP = start * bytesPerDoc;
in.seek(seekFP);
this.countLeft = length;
packedValue = new byte[packedBytesLength];
}
@Override
public boolean next() throws IOException {
if (countLeft >= 0) {
if (countLeft == 0) {
return false;
}
countLeft--;
}
try {
in.readBytes(packedValue, 0, packedValue.length);
} catch (EOFException eofe) {
assert countLeft == -1;
return false;
}
ord = in.readLong();
docID = in.readInt();
return true;
}
@Override
public byte[] packedValue() {
return packedValue;
}
@Override
public long ord() {
return ord;
}
@Override
public int docID() {
return docID;
}
@Override
public void close() throws IOException {
in.close();
}
}


@ -0,0 +1,86 @@
package org.apache.lucene.util.bkd;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.RamUsageEstimator;
/** Writes points to disk in a fixed-width format. */
final class OfflinePointWriter implements PointWriter {
final Directory tempDir;
final IndexOutput out;
final int packedBytesLength;
final int bytesPerDoc;
private long count;
private boolean closed;
public OfflinePointWriter(Directory tempDir, String tempFileNamePrefix, int packedBytesLength) throws IOException {
this.out = tempDir.createTempOutput(tempFileNamePrefix, "bkd", IOContext.DEFAULT);
this.tempDir = tempDir;
this.packedBytesLength = packedBytesLength;
bytesPerDoc = packedBytesLength + RamUsageEstimator.NUM_BYTES_LONG + RamUsageEstimator.NUM_BYTES_INT;
}
/** Initializes on an already written/closed file, just so consumers can use {@link #getReader} to read the file. */
public OfflinePointWriter(Directory tempDir, IndexOutput out, int packedBytesLength, long count) {
this.out = out;
this.tempDir = tempDir;
this.packedBytesLength = packedBytesLength;
bytesPerDoc = packedBytesLength + RamUsageEstimator.NUM_BYTES_LONG + RamUsageEstimator.NUM_BYTES_INT;
this.count = count;
closed = true;
}
@Override
public void append(byte[] packedValue, long ord, int docID) throws IOException {
assert packedValue.length == packedBytesLength;
out.writeBytes(packedValue, 0, packedValue.length);
out.writeLong(ord);
out.writeInt(docID);
count++;
}
@Override
public PointReader getReader(long start) throws IOException {
assert closed;
return new OfflinePointReader(tempDir, out.getName(), packedBytesLength, start, count-start);
}
@Override
public void close() throws IOException {
out.close();
closed = true;
}
@Override
public void destroy() throws IOException {
tempDir.deleteFile(out.getName());
}
@Override
public String toString() {
return "OfflinePointWriter(count=" + count + " tempFileName=" + out.getName() + ")";
}
}


@ -0,0 +1,40 @@
package org.apache.lucene.util.bkd;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.Closeable;
import java.io.IOException;
/** One-pass iterator through all points previously written with a
* {@link PointWriter}, abstracting away whether points are read
* from (offline) disk or simple arrays in heap. */
interface PointReader extends Closeable {
/** Returns false once iteration is done, else true. */
boolean next() throws IOException;
/** Returns the packed byte[] value */
byte[] packedValue();
/** Point ordinal */
long ord();
/** DocID for this point */
int docID();
}
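For illustration, a typical consumption loop over a PointReader, written as a hedged sketch (the writer variable is hypothetical; PointReader extends Closeable, so try-with-resources applies; the returned byte[] may be reused between calls, as in OfflinePointReader above):
try (PointReader reader = writer.getReader(0)) {
  while (reader.next()) {
    byte[] packedValue = reader.packedValue();  // may be overwritten by the next call to next()
    long ord = reader.ord();
    int docID = reader.docID();
    // ... consume one point ...
  }
}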

View File

@ -0,0 +1,36 @@
package org.apache.lucene.util.bkd;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.Closeable;
import java.io.IOException;
/** Appends many points, and then at the end provides a {@link PointReader} to iterate
* those points. This abstracts away whether we write to disk, or use simple arrays
* in heap. */
interface PointWriter extends Closeable {
/** Add a new point */
void append(byte[] packedValue, long ord, int docID) throws IOException;
/** Returns a {@link PointReader} iterator to step through all previously added points */
PointReader getReader(long startPoint) throws IOException;
/** Removes any temp files behind this writer */
void destroy() throws IOException;
}
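A hedged sketch of the intended write-then-read lifecycle, using the OfflinePointWriter above (tempDir, the prefix, and the points themselves are placeholders):
PointWriter writer = new OfflinePointWriter(tempDir, "prefix", packedBytesLength);
try {
  writer.append(packedValue, ord, docID);  // once per point
  writer.close();                          // OfflinePointWriter#getReader asserts the writer is closed
  try (PointReader reader = writer.getReader(0)) {
    while (reader.next()) {
      // ... consume points in insertion order ...
    }
  }
} finally {
  writer.destroy();                        // removes the temp file
}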

View File

@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Block KD-tree, implementing the generic spatial data structure described in
* <a href="https://www.cs.duke.edu/~pankaj/publications/papers/bkd-sstd.pdf">this paper</a>.
*/
package org.apache.lucene.util.bkd;

View File

@ -0,0 +1,705 @@
package org.apache.lucene.util.bkd;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MockDirectoryWrapper;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.LuceneTestCase.SuppressSysoutChecks;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.RamUsageTester;
import org.apache.lucene.util.TestUtil;
@SuppressSysoutChecks(bugUrl = "Stuff gets printed.")
public class TestBKD extends LuceneTestCase {
public void testBasicInts1D() throws Exception {
try (Directory dir = getDirectory(100)) {
BKDWriter w = new BKDWriter(dir, "tmp", 1, 4, 2, 1.0f);
byte[] scratch = new byte[4];
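// BKDUtil.intToBytes is assumed to write a sortable (big-endian, sign-flipped) encoding, so byte-wise comparison matches int order: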
for(int docID=0;docID<100;docID++) {
BKDUtil.intToBytes(docID, scratch, 0);
w.add(scratch, docID);
}
long indexFP;
try (IndexOutput out = dir.createOutput("bkd", IOContext.DEFAULT)) {
indexFP = w.finish(out);
}
try (IndexInput in = dir.openInput("bkd", IOContext.DEFAULT)) {
in.seek(indexFP);
BKDReader r = new BKDReader(in);
// Simple 1D range query:
final int queryMin = 42;
final int queryMax = 87;
final BitSet hits = new BitSet();
r.intersect(new BKDReader.IntersectVisitor() {
@Override
public void visit(int docID) {
hits.set(docID);
if (VERBOSE) {
System.out.println("visit docID=" + docID);
}
}
@Override
public void visit(int docID, byte[] packedValue) {
int x = BKDUtil.bytesToInt(packedValue, 0);
if (VERBOSE) {
System.out.println("visit docID=" + docID + " x=" + x);
}
if (x >= queryMin && x <= queryMax) {
hits.set(docID);
}
}
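// The Relation returned below steers BKDReader's recursion: QUERY_OUTSIDE_CELL prunes
// the subtree, CELL_INSIDE_QUERY visits every docID in the cell via visit(int), and
// QUERY_CROSSES_CELL recurses further, filtering each value via visit(int, byte[]).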
@Override
public BKDReader.Relation compare(byte[] minPacked, byte[] maxPacked) {
int min = BKDUtil.bytesToInt(minPacked, 0);
int max = BKDUtil.bytesToInt(maxPacked, 0);
assert max >= min;
if (VERBOSE) {
System.out.println("compare: min=" + min + " max=" + max + " vs queryMin=" + queryMin + " queryMax=" + queryMax);
}
if (max < queryMin || min > queryMax) {
return BKDReader.Relation.QUERY_OUTSIDE_CELL;
} else if (min >= queryMin && max <= queryMax) {
return BKDReader.Relation.CELL_INSIDE_QUERY;
} else {
return BKDReader.Relation.QUERY_CROSSES_CELL;
}
}
});
for(int docID=0;docID<100;docID++) {
boolean expected = docID >= queryMin && docID <= queryMax;
boolean actual = hits.get(docID);
assertEquals("docID=" + docID, expected, actual);
}
}
}
}
public void testRandomIntsNDims() throws Exception {
int numDocs = atLeast(1000);
try (Directory dir = getDirectory(numDocs)) {
int numDims = TestUtil.nextInt(random(), 1, 5);
int maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 100);
float maxMB = 0.1f + (3*random().nextFloat());
BKDWriter w = new BKDWriter(dir, "tmp", numDims, 4, maxPointsInLeafNode, maxMB);
if (VERBOSE) {
System.out.println("TEST: numDims=" + numDims + " numDocs=" + numDocs);
}
int[][] docs = new int[numDocs][];
byte[] scratch = new byte[4*numDims];
for(int docID=0;docID<numDocs;docID++) {
int[] values = new int[numDims];
if (VERBOSE) {
System.out.println(" docID=" + docID);
}
for(int dim=0;dim<numDims;dim++) {
values[dim] = random().nextInt();
BKDUtil.intToBytes(values[dim], scratch, dim);
if (VERBOSE) {
System.out.println(" " + dim + " -> " + values[dim]);
}
}
docs[docID] = values;
w.add(scratch, docID);
}
long indexFP;
try (IndexOutput out = dir.createOutput("bkd", IOContext.DEFAULT)) {
indexFP = w.finish(out);
}
try (IndexInput in = dir.openInput("bkd", IOContext.DEFAULT)) {
in.seek(indexFP);
BKDReader r = new BKDReader(in);
int iters = atLeast(100);
for(int iter=0;iter<iters;iter++) {
if (VERBOSE) {
System.out.println("\nTEST: iter=" + iter);
}
// Random N dims rect query:
int[] queryMin = new int[numDims];
int[] queryMax = new int[numDims];
for(int dim=0;dim<numDims;dim++) {
queryMin[dim] = random().nextInt();
queryMax[dim] = random().nextInt();
if (queryMin[dim] > queryMax[dim]) {
int x = queryMin[dim];
queryMin[dim] = queryMax[dim];
queryMax[dim] = x;
}
}
final BitSet hits = new BitSet();
r.intersect(new BKDReader.IntersectVisitor() {
@Override
public void visit(int docID) {
hits.set(docID);
//System.out.println("visit docID=" + docID);
}
@Override
public void visit(int docID, byte[] packedValue) {
//System.out.println("visit check docID=" + docID);
for(int dim=0;dim<numDims;dim++) {
int x = BKDUtil.bytesToInt(packedValue, dim);
if (x < queryMin[dim] || x > queryMax[dim]) {
//System.out.println(" no");
return;
}
}
//System.out.println(" yes");
hits.set(docID);
}
@Override
public BKDReader.Relation compare(byte[] minPacked, byte[] maxPacked) {
boolean crosses = false;
for(int dim=0;dim<numDims;dim++) {
int min = BKDUtil.bytesToInt(minPacked, dim);
int max = BKDUtil.bytesToInt(maxPacked, dim);
assert max >= min;
if (max < queryMin[dim] || min > queryMax[dim]) {
return BKDReader.Relation.QUERY_OUTSIDE_CELL;
} else if (min < queryMin[dim] || max > queryMax[dim]) {
crosses = true;
}
}
if (crosses) {
return BKDReader.Relation.QUERY_CROSSES_CELL;
} else {
return BKDReader.Relation.CELL_INSIDE_QUERY;
}
}
});
for(int docID=0;docID<numDocs;docID++) {
int[] docValues = docs[docID];
boolean expected = true;
for(int dim=0;dim<numDims;dim++) {
int x = docValues[dim];
if (x < queryMin[dim] || x > queryMax[dim]) {
expected = false;
break;
}
}
boolean actual = hits.get(docID);
assertEquals("docID=" + docID, expected, actual);
}
}
}
}
}
// Tests on N-dimensional points where each dimension is a BigInteger
public void testBigIntNDims() throws Exception {
int numDocs = atLeast(1000);
try (Directory dir = getDirectory(numDocs)) {
int numBytesPerDim = TestUtil.nextInt(random(), 2, 30);
int numDims = TestUtil.nextInt(random(), 1, 5);
int maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 100);
float maxMB = 0.1f + (3*random().nextFloat());
BKDWriter w = new BKDWriter(dir, "tmp", numDims, numBytesPerDim, maxPointsInLeafNode, maxMB);
BigInteger[][] docs = new BigInteger[numDocs][];
byte[] scratch = new byte[numBytesPerDim*numDims];
for(int docID=0;docID<numDocs;docID++) {
BigInteger[] values = new BigInteger[numDims];
if (VERBOSE) {
System.out.println(" docID=" + docID);
}
for(int dim=0;dim<numDims;dim++) {
values[dim] = randomBigInt(numBytesPerDim);
BKDUtil.bigIntToBytes(values[dim], scratch, dim, numBytesPerDim);
if (VERBOSE) {
System.out.println(" " + dim + " -> " + values[dim]);
}
}
docs[docID] = values;
w.add(scratch, docID);
}
long indexFP;
try (IndexOutput out = dir.createOutput("bkd", IOContext.DEFAULT)) {
indexFP = w.finish(out);
}
try (IndexInput in = dir.openInput("bkd", IOContext.DEFAULT)) {
in.seek(indexFP);
BKDReader r = new BKDReader(in);
int iters = atLeast(100);
for(int iter=0;iter<iters;iter++) {
if (VERBOSE) {
System.out.println("\nTEST: iter=" + iter);
}
// Random N dims rect query:
BigInteger[] queryMin = new BigInteger[numDims];
BigInteger[] queryMax = new BigInteger[numDims];
for(int dim=0;dim<numDims;dim++) {
queryMin[dim] = randomBigInt(numBytesPerDim);
queryMax[dim] = randomBigInt(numBytesPerDim);
if (queryMin[dim].compareTo(queryMax[dim]) > 0) {
BigInteger x = queryMin[dim];
queryMin[dim] = queryMax[dim];
queryMax[dim] = x;
}
}
final BitSet hits = new BitSet();
r.intersect(new BKDReader.IntersectVisitor() {
@Override
public void visit(int docID) {
hits.set(docID);
//System.out.println("visit docID=" + docID);
}
@Override
public void visit(int docID, byte[] packedValue) {
//System.out.println("visit check docID=" + docID);
for(int dim=0;dim<numDims;dim++) {
BigInteger x = BKDUtil.bytesToBigInt(packedValue, dim, numBytesPerDim);
if (x.compareTo(queryMin[dim]) < 0 || x.compareTo(queryMax[dim]) > 0) {
//System.out.println(" no");
return;
}
}
//System.out.println(" yes");
hits.set(docID);
}
@Override
public BKDReader.Relation compare(byte[] minPacked, byte[] maxPacked) {
boolean crosses = false;
for(int dim=0;dim<numDims;dim++) {
BigInteger min = BKDUtil.bytesToBigInt(minPacked, dim, numBytesPerDim);
BigInteger max = BKDUtil.bytesToBigInt(maxPacked, dim, numBytesPerDim);
assert max.compareTo(min) >= 0;
if (max.compareTo(queryMin[dim]) < 0 || min.compareTo(queryMax[dim]) > 0) {
return BKDReader.Relation.QUERY_OUTSIDE_CELL;
} else if (min.compareTo(queryMin[dim]) < 0 || max.compareTo(queryMax[dim]) > 0) {
crosses = true;
}
}
if (crosses) {
return BKDReader.Relation.QUERY_CROSSES_CELL;
} else {
return BKDReader.Relation.CELL_INSIDE_QUERY;
}
}
});
for(int docID=0;docID<numDocs;docID++) {
BigInteger[] docValues = docs[docID];
boolean expected = true;
for(int dim=0;dim<numDims;dim++) {
BigInteger x = docValues[dim];
if (x.compareTo(queryMin[dim]) < 0 || x.compareTo(queryMax[dim]) > 0) {
expected = false;
break;
}
}
boolean actual = hits.get(docID);
assertEquals("docID=" + docID, expected, actual);
}
}
}
}
}
/** Make sure we close open files, delete temp files, etc., on exception */
public void testWithExceptions() throws Exception {
int numDocs = atLeast(10000);
int numBytesPerDim = TestUtil.nextInt(random(), 2, 30);
int numDims = TestUtil.nextInt(random(), 1, 5);
byte[][][] docValues = new byte[numDocs][][];
for(int docID=0;docID<numDocs;docID++) {
byte[][] values = new byte[numDims][];
for(int dim=0;dim<numDims;dim++) {
values[dim] = new byte[numBytesPerDim];
random().nextBytes(values[dim]);
}
docValues[docID] = values;
}
double maxMBHeap = 0.05;
// Keep retrying until 1) we allow a big enough heap, and 2) we hit a random IOExc from MDW:
boolean done = false;
while (done == false) {
try (MockDirectoryWrapper dir = newMockFSDirectory(createTempDir())) {
try {
dir.setRandomIOExceptionRate(0.05);
dir.setRandomIOExceptionRateOnOpen(0.05);
if (dir instanceof MockDirectoryWrapper) {
dir.setEnableVirusScanner(false);
}
verify(dir, docValues, null, numDims, numBytesPerDim, 50, maxMBHeap);
} catch (IllegalArgumentException iae) {
// This just means we got a too-small maxMB for the maxPointsInLeafNode; just retry w/ more heap
assertTrue(iae.getMessage().contains("either increase maxMBSortInHeap or decrease maxPointsInLeafNode"));
System.out.println(" more heap");
maxMBHeap *= 1.25;
} catch (IOException ioe) {
if (ioe.getMessage().contains("a random IOException")) {
// BKDWriter should fully clean up after itself:
done = true;
} else {
throw ioe;
}
}
String[] files = dir.listAll();
assertTrue("files=" + Arrays.toString(files), files.length == 0 || Arrays.equals(files, new String[] {"extra0"}));
}
}
}
public void testRandomBinaryTiny() throws Exception {
doTestRandomBinary(10);
}
public void testRandomBinaryMedium() throws Exception {
doTestRandomBinary(10000);
}
@Nightly
public void testRandomBinaryBig() throws Exception {
doTestRandomBinary(200000);
}
public void testTooLittleHeap() throws Exception {
try (Directory dir = getDirectory(0)) {
new BKDWriter(dir, "bkd", 1, 16, 1000000, 0.001);
fail("did not hit exception");
} catch (IllegalArgumentException iae) {
// expected
assertTrue(iae.getMessage().contains("either increase maxMBSortInHeap or decrease maxPointsInLeafNode"));
}
}
private void doTestRandomBinary(int count) throws Exception {
int numDocs = TestUtil.nextInt(random(), count, count*2);
int numBytesPerDim = TestUtil.nextInt(random(), 2, 30);
int numDims = TestUtil.nextInt(random(), 1, 5);
byte[][][] docValues = new byte[numDocs][][];
for(int docID=0;docID<numDocs;docID++) {
byte[][] values = new byte[numDims][];
for(int dim=0;dim<numDims;dim++) {
values[dim] = new byte[numBytesPerDim];
random().nextBytes(values[dim]);
}
docValues[docID] = values;
}
verify(docValues, null, numDims, numBytesPerDim);
}
public void testAllEqual() throws Exception {
int numBytesPerDim = TestUtil.nextInt(random(), 2, 30);
int numDims = TestUtil.nextInt(random(), 1, 5);
int numDocs = atLeast(1000);
byte[][][] docValues = new byte[numDocs][][];
for(int docID=0;docID<numDocs;docID++) {
if (docID == 0) {
byte[][] values = new byte[numDims][];
for(int dim=0;dim<numDims;dim++) {
values[dim] = new byte[numBytesPerDim];
random().nextBytes(values[dim]);
}
docValues[docID] = values;
} else {
docValues[docID] = docValues[0];
}
}
verify(docValues, null, numDims, numBytesPerDim);
}
public void testOneDimEqual() throws Exception {
int numBytesPerDim = TestUtil.nextInt(random(), 2, 30);
int numDims = TestUtil.nextInt(random(), 1, 5);
int numDocs = atLeast(1000);
int theEqualDim = random().nextInt(numDims);
byte[][][] docValues = new byte[numDocs][][];
for(int docID=0;docID<numDocs;docID++) {
byte[][] values = new byte[numDims][];
for(int dim=0;dim<numDims;dim++) {
values[dim] = new byte[numBytesPerDim];
random().nextBytes(values[dim]);
}
docValues[docID] = values;
if (docID > 0) {
docValues[docID][theEqualDim] = docValues[0][theEqualDim];
}
}
verify(docValues, null, numDims, numBytesPerDim);
}
public void testMultiValued() throws Exception {
int numBytesPerDim = TestUtil.nextInt(random(), 2, 30);
int numDims = TestUtil.nextInt(random(), 1, 5);
int numDocs = atLeast(1000);
List<byte[][]> docValues = new ArrayList<>();
List<Integer> docIDs = new ArrayList<>();
for(int docID=0;docID<numDocs;docID++) {
int numValuesInDoc = TestUtil.nextInt(random(), 1, 5);
for(int ord=0;ord<numValuesInDoc;ord++) {
docIDs.add(docID);
byte[][] values = new byte[numDims][];
for(int dim=0;dim<numDims;dim++) {
values[dim] = new byte[numBytesPerDim];
random().nextBytes(values[dim]);
}
docValues.add(values);
}
}
byte[][][] docValuesArray = docValues.toArray(new byte[docValues.size()][][]);
int[] docIDsArray = new int[docIDs.size()];
for(int i=0;i<docIDsArray.length;i++) {
docIDsArray[i] = docIDs.get(i);
}
verify(docValuesArray, docIDsArray, numDims, numBytesPerDim);
}
/** docIDs can be null, for the single-valued case; else it maps each value's ord to its docID */
private void verify(byte[][][] docValues, int[] docIDs, int numDims, int numBytesPerDim) throws Exception {
try (Directory dir = getDirectory(docValues.length)) {
while (true) {
int maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 100);
double maxMB = 0.1 + (3*random().nextDouble());
try {
verify(dir, docValues, docIDs, numDims, numBytesPerDim, maxPointsInLeafNode, maxMB);
return;
} catch (IllegalArgumentException iae) {
// This just means we got a too-small maxMB for the maxPointsInLeafNode; just retry
assertTrue(iae.getMessage().contains("either increase maxMBSortInHeap or decrease maxPointsInLeafNode"));
}
}
}
}
private void verify(Directory dir, byte[][][] docValues, int[] docIDs, int numDims, int numBytesPerDim, int maxPointsInLeafNode, double maxMB) throws Exception {
int numValues = docValues.length;
if (VERBOSE) {
System.out.println("TEST: numValues=" + numValues + " numDims=" + numDims + " numBytesPerDim=" + numBytesPerDim + " maxPointsInLeafNode=" + maxPointsInLeafNode + " maxMB=" + maxMB);
}
long indexFP;
try (BKDWriter w = new BKDWriter(dir, "tmp", numDims, numBytesPerDim, maxPointsInLeafNode, maxMB)) {
byte[] scratch = new byte[numBytesPerDim*numDims];
for(int ord=0;ord<numValues;ord++) {
int docID;
if (docIDs == null) {
docID = ord;
} else {
docID = docIDs[ord];
}
if (VERBOSE) {
System.out.println(" ord=" + ord + " docID=" + docID);
}
for(int dim=0;dim<numDims;dim++) {
if (VERBOSE) {
System.out.println(" " + dim + " -> " + new BytesRef(docValues[ord][dim]));
}
System.arraycopy(docValues[ord][dim], 0, scratch, dim*numBytesPerDim, numBytesPerDim);
}
w.add(scratch, docID);
}
boolean success = false;
try (IndexOutput out = dir.createOutput("bkd", IOContext.DEFAULT)) {
indexFP = w.finish(out);
success = true;
} finally {
if (success == false) {
IOUtils.deleteFilesIgnoringExceptions(dir, "bkd");
}
}
}
try (IndexInput in = dir.openInput("bkd", IOContext.DEFAULT)) {
in.seek(indexFP);
BKDReader r = new BKDReader(in);
int iters = atLeast(100);
for(int iter=0;iter<iters;iter++) {
if (VERBOSE) {
System.out.println("\nTEST: iter=" + iter);
}
// Random N dims rect query:
byte[][] queryMin = new byte[numDims][];
byte[][] queryMax = new byte[numDims][];
for(int dim=0;dim<numDims;dim++) {
queryMin[dim] = new byte[numBytesPerDim];
random().nextBytes(queryMin[dim]);
queryMax[dim] = new byte[numBytesPerDim];
random().nextBytes(queryMax[dim]);
if (BKDUtil.compare(numBytesPerDim, queryMin[dim], 0, queryMax[dim], 0) > 0) {
byte[] x = queryMin[dim];
queryMin[dim] = queryMax[dim];
queryMax[dim] = x;
}
}
final BitSet hits = new BitSet();
r.intersect(new BKDReader.IntersectVisitor() {
@Override
public void visit(int docID) {
hits.set(docID);
//System.out.println("visit docID=" + docID);
}
@Override
public void visit(int docID, byte[] packedValue) {
//System.out.println("visit check docID=" + docID);
for(int dim=0;dim<numDims;dim++) {
if (BKDUtil.compare(numBytesPerDim, packedValue, dim, queryMin[dim], 0) < 0 ||
BKDUtil.compare(numBytesPerDim, packedValue, dim, queryMax[dim], 0) > 0) {
//System.out.println(" no");
return;
}
}
//System.out.println(" yes");
hits.set(docID);
}
@Override
public BKDReader.Relation compare(byte[] minPacked, byte[] maxPacked) {
boolean crosses = false;
for(int dim=0;dim<numDims;dim++) {
BigInteger min = BKDUtil.bytesToBigInt(minPacked, dim, numBytesPerDim);
BigInteger max = BKDUtil.bytesToBigInt(maxPacked, dim, numBytesPerDim);
assert max.compareTo(min) >= 0;
if (BKDUtil.compare(numBytesPerDim, maxPacked, dim, queryMin[dim], 0) < 0 ||
BKDUtil.compare(numBytesPerDim, minPacked, dim, queryMax[dim], 0) > 0) {
return BKDReader.Relation.QUERY_OUTSIDE_CELL;
} else if (BKDUtil.compare(numBytesPerDim, minPacked, dim, queryMin[dim], 0) < 0 ||
BKDUtil.compare(numBytesPerDim, maxPacked, dim, queryMax[dim], 0) > 0) {
crosses = true;
}
}
if (crosses) {
return BKDReader.Relation.QUERY_CROSSES_CELL;
} else {
return BKDReader.Relation.CELL_INSIDE_QUERY;
}
}
});
BitSet expected = new BitSet();
for(int ord=0;ord<numValues;ord++) {
boolean matches = true;
for(int dim=0;dim<numDims;dim++) {
byte[] x = docValues[ord][dim];
if (BKDUtil.compare(numBytesPerDim, x, 0, queryMin[dim], 0) < 0 ||
BKDUtil.compare(numBytesPerDim, x, 0, queryMax[dim], 0) > 0) {
matches = false;
break;
}
}
if (matches) {
int docID;
if (docIDs == null) {
docID = ord;
} else {
docID = docIDs[ord];
}
expected.set(docID);
}
}
int limit = Math.max(expected.length(), hits.length());
for(int docID=0;docID<limit;docID++) {
assertEquals("docID=" + docID, expected.get(docID), hits.get(docID));
}
}
}
dir.deleteFile("bkd");
}
private BigInteger randomBigInt(int numBytes) {
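// numBytes*8-1 random bits keeps the magnitude (and its negation) representable in numBytes bytes of two's complement: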
BigInteger x = new BigInteger(numBytes*8-1, random());
if (random().nextBoolean()) {
x = x.negate();
}
return x;
}
private Directory getDirectory(int numPoints) {
Directory dir;
if (numPoints > 100000) {
dir = newFSDirectory(createTempDir("TestBKDTree"));
} else {
dir = newDirectory();
}
System.out.println("DIR: " + dir);
if (dir instanceof MockDirectoryWrapper) {
((MockDirectoryWrapper) dir).setEnableVirusScanner(false);
}
return dir;
}
}

View File

@ -24,6 +24,7 @@ import java.util.Comparator;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.BytesRefIterator;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.OfflineSorter;
@ -105,7 +106,7 @@ public class ExternalRefSorter implements BytesRefSorter, Closeable {
*/
class ByteSequenceIterator implements BytesRefIterator {
private final OfflineSorter.ByteSequencesReader reader;
private BytesRef scratch = new BytesRef();
private BytesRefBuilder scratch = new BytesRefBuilder();
public ByteSequenceIterator(OfflineSorter.ByteSequencesReader reader) {
this.reader = reader;
@ -118,17 +119,15 @@ public class ExternalRefSorter implements BytesRefSorter, Closeable {
}
boolean success = false;
try {
byte[] next = reader.read();
if (next != null) {
scratch.bytes = next;
scratch.length = next.length;
scratch.offset = 0;
} else {
if (reader.read(scratch) == false) {
IOUtils.close(reader);
scratch = null;
}
success = true;
return scratch;
if (scratch == null) {
return null;
}
return scratch.get();
} finally {
if (!success) {
IOUtils.closeWhileHandlingException(reader);

View File

@ -440,7 +440,6 @@ public class MockDirectoryWrapper extends BaseDirectoryWrapper {
if (randomState.nextDouble() < randomIOExceptionRate) {
if (LuceneTestCase.VERBOSE) {
System.out.println(Thread.currentThread().getName() + ": MockDirectoryWrapper: now throw random exception" + (message == null ? "" : " (" + message + ")"));
new Throwable().printStackTrace(System.out);
}
throw new IOException("a random IOException" + (message == null ? "" : " (" + message + ")"));
}
@ -567,9 +566,6 @@ public class MockDirectoryWrapper extends BaseDirectoryWrapper {
}
}
if (crashed) {
throw new IOException("cannot createOutput after crash");
}
unSyncedFiles.add(name);
createdFiles.add(name);
@ -607,6 +603,39 @@ public class MockDirectoryWrapper extends BaseDirectoryWrapper {
}
}
@Override
public synchronized IndexOutput createTempOutput(String prefix, String suffix, IOContext context) throws IOException {
maybeThrowDeterministicException();
maybeThrowIOExceptionOnOpen("temp: prefix=" + prefix + " suffix=" + suffix);
maybeYield();
if (failOnCreateOutput) {
maybeThrowDeterministicException();
}
if (crashed) {
throw new IOException("cannot createTempOutput after crash");
}
init();
IndexOutput delegateOutput = in.createTempOutput(prefix, suffix, LuceneTestCase.newIOContext(randomState, context));
String name = delegateOutput.getName();
unSyncedFiles.add(name);
createdFiles.add(name);
final IndexOutput io = new MockIndexOutputWrapper(this, delegateOutput, name);
addFileHandle(io, name, Handle.Output);
openFilesForWrite.add(name);
// throttling REALLY slows down tests, so don't do it very often for SOMETIMES.
if (throttling == Throttling.ALWAYS ||
(throttling == Throttling.SOMETIMES && randomState.nextInt(200) == 0)) {
if (LuceneTestCase.VERBOSE) {
System.out.println("MockDirectoryWrapper: throttling indexOutput (" + name + ")");
}
return throttledOutput.newFromDelegate(io);
} else {
return io;
}
}
private static enum Handle {
Input, Output, Slice
}