LUCENE-8885: Optimise BKD reader by exploiting cardinality information stored on leaves (#746)

The commit adds the method InstersectVisitor#visit(DocIdSetIterator, byte[]).
This commit is contained in:
Ignacio Vera 2019-07-01 06:15:03 +02:00 committed by GitHub
parent d6345439dc
commit db68634c67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 36 deletions

View File

@ -153,7 +153,10 @@ Optimizations
* LUCENE-8868: New storing strategy for BKD tree leaves with low cardinality.
It stores the distinct values once with the cardinality value reducing the
storage cost.
storage cost. (Ignacio Vera)
* LUCENE-8885: Optimise BKD reader by exploiting cardinality information stored
on leaves. (Ignacio Vera)
Test Framework

View File

@ -208,6 +208,16 @@ public abstract class PointValues {
* docID order. */
void visit(int docID, byte[] packedValue) throws IOException;
/** Similar to {@link IntersectVisitor#visit(int, byte[])} but in this case the packedValue
* can have more than one docID associated to it. The provided iterator should not escape the
* scope of this method so that implementations of PointValues are free to reuse it,*/
default void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
int docID;
while ((docID = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
visit(docID, packedValue);
}
}
/** Called for non-leaf cells to test how the cell relates to the query, to
* determine how to further recurse down the tree. */
Relation compare(byte[] minPackedValue, byte[] maxPackedValue);

View File

@ -22,6 +22,7 @@ import java.util.Arrays;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Accountable;
@ -332,7 +333,7 @@ public final class BKDReader extends PointValues implements Accountable {
/** Used to track all state for a single call to {@link #intersect}. */
public static final class IntersectState {
final IndexInput in;
final int[] scratchDocIDs;
final BKDReaderDocIDSetIterator scratchIterator;
final byte[] scratchDataPackedValue, scratchMinIndexPackedValue, scratchMaxIndexPackedValue;
final int[] commonPrefixLengths;
@ -348,7 +349,7 @@ public final class BKDReader extends PointValues implements Accountable {
this.in = in;
this.visitor = visitor;
this.commonPrefixLengths = new int[numDims];
this.scratchDocIDs = new int[maxPointsInLeafNode];
this.scratchIterator = new BKDReaderDocIDSetIterator(maxPointsInLeafNode);
this.scratchDataPackedValue = new byte[packedBytesLength];
this.scratchMinIndexPackedValue = new byte[packedIndexBytesLength];
this.scratchMaxIndexPackedValue = new byte[packedIndexBytesLength];
@ -411,10 +412,10 @@ public final class BKDReader extends PointValues implements Accountable {
public void visitLeafBlockValues(IndexTree index, IntersectState state) throws IOException {
// Leaf node; scan and filter all points in this block:
int count = readDocIDs(state.in, index.getLeafBlockFP(), state.scratchDocIDs);
int count = readDocIDs(state.in, index.getLeafBlockFP(), state.scratchIterator);
// Again, this time reading values and checking with the visitor
visitDocValues(state.commonPrefixLengths, state.scratchDataPackedValue, state.scratchMinIndexPackedValue, state.scratchMaxIndexPackedValue, state.in, state.scratchDocIDs, count, state.visitor);
visitDocValues(state.commonPrefixLengths, state.scratchDataPackedValue, state.scratchMinIndexPackedValue, state.scratchMaxIndexPackedValue, state.in, state.scratchIterator, count, state.visitor);
}
private void visitDocIDs(IndexInput in, long blockFP, IntersectVisitor visitor) throws IOException {
@ -428,28 +429,28 @@ public final class BKDReader extends PointValues implements Accountable {
DocIdsWriter.readInts(in, count, visitor);
}
int readDocIDs(IndexInput in, long blockFP, int[] docIDs) throws IOException {
int readDocIDs(IndexInput in, long blockFP, BKDReaderDocIDSetIterator iterator) throws IOException {
in.seek(blockFP);
// How many points are stored in this leaf cell:
int count = in.readVInt();
DocIdsWriter.readInts(in, count, docIDs);
DocIdsWriter.readInts(in, count, iterator.docIDs);
return count;
}
void visitDocValues(int[] commonPrefixLengths, byte[] scratchDataPackedValue, byte[] scratchMinIndexPackedValue, byte[] scratchMaxIndexPackedValue,
IndexInput in, int[] docIDs, int count, IntersectVisitor visitor) throws IOException {
IndexInput in, BKDReaderDocIDSetIterator scratchIterator, int count, IntersectVisitor visitor) throws IOException {
if (version >= BKDWriter.VERSION_LOW_CARDINALITY_LEAVES) {
visitDocValuesWithCardinality(commonPrefixLengths, scratchDataPackedValue, scratchMinIndexPackedValue, scratchMaxIndexPackedValue, in, docIDs, count, visitor);
visitDocValuesWithCardinality(commonPrefixLengths, scratchDataPackedValue, scratchMinIndexPackedValue, scratchMaxIndexPackedValue, in, scratchIterator, count, visitor);
} else {
visitDocValuesNoCardinality(commonPrefixLengths, scratchDataPackedValue, scratchMinIndexPackedValue, scratchMaxIndexPackedValue, in, docIDs, count, visitor);
visitDocValuesNoCardinality(commonPrefixLengths, scratchDataPackedValue, scratchMinIndexPackedValue, scratchMaxIndexPackedValue, in, scratchIterator, count, visitor);
}
}
void visitDocValuesNoCardinality(int[] commonPrefixLengths, byte[] scratchDataPackedValue, byte[] scratchMinIndexPackedValue, byte[] scratchMaxIndexPackedValue,
IndexInput in, int[] docIDs, int count, IntersectVisitor visitor) throws IOException {
IndexInput in, BKDReaderDocIDSetIterator scratchIterator, int count, IntersectVisitor visitor) throws IOException {
readCommonPrefixes(commonPrefixLengths, scratchDataPackedValue, in);
if (numIndexDims != 1 && version >= BKDWriter.VERSION_LEAF_STORES_BOUNDS) {
@ -474,7 +475,7 @@ public final class BKDReader extends PointValues implements Accountable {
if (r == Relation.CELL_INSIDE_QUERY) {
for (int i = 0; i < count; ++i) {
visitor.visit(docIDs[i]);
visitor.visit(scratchIterator.docIDs[i]);
}
return;
}
@ -486,21 +487,21 @@ public final class BKDReader extends PointValues implements Accountable {
int compressedDim = readCompressedDim(in);
if (compressedDim == -1) {
visitUniqueRawDocValues(scratchDataPackedValue, docIDs, count, visitor);
visitUniqueRawDocValues(scratchDataPackedValue, scratchIterator, count, visitor);
} else {
visitCompressedDocValues(commonPrefixLengths, scratchDataPackedValue, in, docIDs, count, visitor, compressedDim);
visitCompressedDocValues(commonPrefixLengths, scratchDataPackedValue, in, scratchIterator, count, visitor, compressedDim);
}
}
void visitDocValuesWithCardinality(int[] commonPrefixLengths, byte[] scratchDataPackedValue, byte[] scratchMinIndexPackedValue, byte[] scratchMaxIndexPackedValue,
IndexInput in, int[] docIDs, int count, IntersectVisitor visitor) throws IOException {
IndexInput in, BKDReaderDocIDSetIterator scratchIterator, int count, IntersectVisitor visitor) throws IOException {
readCommonPrefixes(commonPrefixLengths, scratchDataPackedValue, in);
int compressedDim = readCompressedDim(in);
if (compressedDim == -1) {
// all values are the same
visitor.grow(count);
visitUniqueRawDocValues(scratchDataPackedValue, docIDs, count, visitor);
visitUniqueRawDocValues(scratchDataPackedValue, scratchIterator, count, visitor);
} else {
if (numIndexDims != 1) {
byte[] minPackedValue = scratchMinIndexPackedValue;
@ -524,7 +525,7 @@ public final class BKDReader extends PointValues implements Accountable {
if (r == Relation.CELL_INSIDE_QUERY) {
for (int i = 0; i < count; ++i) {
visitor.visit(docIDs[i]);
visitor.visit(scratchIterator.docIDs[i]);
}
return;
}
@ -533,10 +534,10 @@ public final class BKDReader extends PointValues implements Accountable {
}
if (compressedDim == -2) {
// low cardinality values
visitSparseRawDocValues(commonPrefixLengths, scratchDataPackedValue, in, docIDs, count, visitor);
visitSparseRawDocValues(commonPrefixLengths, scratchDataPackedValue, in, scratchIterator, count, visitor);
} else {
// high cardinality
visitCompressedDocValues(commonPrefixLengths, scratchDataPackedValue, in, docIDs, count, visitor, compressedDim);
visitCompressedDocValues(commonPrefixLengths, scratchDataPackedValue, in, scratchIterator, count, visitor, compressedDim);
}
}
}
@ -550,7 +551,7 @@ public final class BKDReader extends PointValues implements Accountable {
}
// read cardinality and point
private void visitSparseRawDocValues(int[] commonPrefixLengths, byte[] scratchPackedValue, IndexInput in, int[] docIDs, int count, IntersectVisitor visitor) throws IOException {
private void visitSparseRawDocValues(int[] commonPrefixLengths, byte[] scratchPackedValue, IndexInput in, BKDReaderDocIDSetIterator scratchIterator, int count, IntersectVisitor visitor) throws IOException {
int i;
for (i = 0; i < count;) {
int length = in.readVInt();
@ -558,9 +559,8 @@ public final class BKDReader extends PointValues implements Accountable {
int prefix = commonPrefixLengths[dim];
in.readBytes(scratchPackedValue, dim*bytesPerDim + prefix, bytesPerDim - prefix);
}
for (int j = i; j < i + length; j++) {
visitor.visit(docIDs[j], scratchPackedValue);
}
scratchIterator.reset(i, length);
visitor.visit(scratchIterator, scratchPackedValue);
i += length;
}
if (i != count) {
@ -569,13 +569,12 @@ public final class BKDReader extends PointValues implements Accountable {
}
// point is under commonPrefix
private void visitUniqueRawDocValues(byte[] scratchPackedValue, int[] docIDs, int count, IntersectVisitor visitor) throws IOException {
for (int i = 0; i < count; i++) {
visitor.visit(docIDs[i], scratchPackedValue);
}
private void visitUniqueRawDocValues(byte[] scratchPackedValue, BKDReaderDocIDSetIterator scratchIterator, int count, IntersectVisitor visitor) throws IOException {
scratchIterator.reset(0, count);
visitor.visit(scratchIterator, scratchPackedValue);
}
private void visitCompressedDocValues(int[] commonPrefixLengths, byte[] scratchPackedValue, IndexInput in, int[] docIDs, int count, IntersectVisitor visitor, int compressedDim) throws IOException {
private void visitCompressedDocValues(int[] commonPrefixLengths, byte[] scratchPackedValue, IndexInput in, BKDReaderDocIDSetIterator scratchIterator, int count, IntersectVisitor visitor, int compressedDim) throws IOException {
// the byte at `compressedByteOffset` is compressed using run-length compression,
// other suffix bytes are stored verbatim
final int compressedByteOffset = compressedDim * bytesPerDim + commonPrefixLengths[compressedDim];
@ -589,7 +588,7 @@ public final class BKDReader extends PointValues implements Accountable {
int prefix = commonPrefixLengths[dim];
in.readBytes(scratchPackedValue, dim*bytesPerDim + prefix, bytesPerDim - prefix);
}
visitor.visit(docIDs[i+j], scratchPackedValue);
visitor.visit(scratchIterator.docIDs[i+j], scratchPackedValue);
}
i += runLen;
}
@ -641,10 +640,10 @@ public final class BKDReader extends PointValues implements Accountable {
// In the unbalanced case it's possible the left most node only has one child:
if (state.index.nodeExists()) {
// Leaf node; scan and filter all points in this block:
int count = readDocIDs(state.in, state.index.getLeafBlockFP(), state.scratchDocIDs);
int count = readDocIDs(state.in, state.index.getLeafBlockFP(), state.scratchIterator);
// Again, this time reading values and checking with the visitor
visitDocValues(state.commonPrefixLengths, state.scratchDataPackedValue, state.scratchMinIndexPackedValue, state.scratchMaxIndexPackedValue, state.in, state.scratchDocIDs, count, state.visitor);
visitDocValues(state.commonPrefixLengths, state.scratchDataPackedValue, state.scratchMinIndexPackedValue, state.scratchMaxIndexPackedValue, state.in, state.scratchIterator, count, state.visitor);
}
} else {
@ -780,4 +779,53 @@ public final class BKDReader extends PointValues implements Accountable {
public boolean isLeafNode(int nodeID) {
return nodeID >= leafNodeOffset;
}
/**
* Reusable {@link DocIdSetIterator} to handle low cardinality leaves. */
protected static class BKDReaderDocIDSetIterator extends DocIdSetIterator {
private int idx;
private int length;
private int offset;
private int docID;
final int[] docIDs;
public BKDReaderDocIDSetIterator(int maxPointsInLeafNode) {
this.docIDs = new int[maxPointsInLeafNode];
}
@Override
public int docID() {
return docID;
}
private void reset(int offset, int length) {
this.offset = offset;
this.length = length;
assert offset + length <= docIDs.length;
this.docID = -1;
this.idx = 0;
}
@Override
public int nextDoc() throws IOException {
if (idx == length) {
docID = DocIdSetIterator.NO_MORE_DOCS;
} else {
docID = docIDs[offset + idx];
idx++;
}
return docID;
}
@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
}
@Override
public long cost() {
return length;
}
}
}

View File

@ -291,10 +291,10 @@ public class BKDWriter implements Closeable {
return false;
}
//System.out.println(" new block @ fp=" + state.in.getFilePointer());
docsInBlock = bkd.readDocIDs(state.in, state.in.getFilePointer(), state.scratchDocIDs);
docsInBlock = bkd.readDocIDs(state.in, state.in.getFilePointer(), state.scratchIterator);
assert docsInBlock > 0;
docBlockUpto = 0;
bkd.visitDocValues(state.commonPrefixLengths, state.scratchDataPackedValue, state.scratchMinIndexPackedValue, state.scratchMaxIndexPackedValue, state.in, state.scratchDocIDs, docsInBlock, new IntersectVisitor() {
bkd.visitDocValues(state.commonPrefixLengths, state.scratchDataPackedValue, state.scratchMinIndexPackedValue, state.scratchMaxIndexPackedValue, state.in, state.scratchIterator, docsInBlock, new IntersectVisitor() {
int i = 0;
@Override
@ -304,7 +304,7 @@ public class BKDWriter implements Closeable {
@Override
public void visit(int docID, byte[] packedValue) {
assert docID == state.scratchDocIDs[i];
assert docID == state.scratchIterator.docIDs[i];
System.arraycopy(packedValue, 0, packedValues, i * bkd.packedBytesLength, bkd.packedBytesLength);
i++;
}
@ -320,7 +320,7 @@ public class BKDWriter implements Closeable {
}
final int index = docBlockUpto++;
int oldDocID = state.scratchDocIDs[index];
int oldDocID = state.scratchIterator.docIDs[index];
int mappedDocID;
if (docMap == null) {

View File

@ -31,6 +31,7 @@ import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.mockfile.ExtrasFS;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.CorruptingIndexOutput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FilterDirectory;
@ -844,6 +845,28 @@ public class TestBKD extends LuceneTestCase {
hits.set(docID);
}
@Override
public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
if (random().nextBoolean()) {
// check the default method is correct
IntersectVisitor.super.visit(iterator, packedValue);
} else {
assertEquals(iterator.docID(), -1);
int cost = Math.toIntExact(iterator.cost());
int numberOfPoints = 0;
int docID;
while ((docID = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
assertEquals(iterator.docID(), docID);
visit(docID, packedValue);
numberOfPoints++;
}
assertEquals(cost, numberOfPoints);
assertEquals(iterator.docID(), DocIdSetIterator.NO_MORE_DOCS);
assertEquals(iterator.nextDoc(), DocIdSetIterator.NO_MORE_DOCS);
assertEquals(iterator.docID(), DocIdSetIterator.NO_MORE_DOCS);
}
}
@Override
public Relation compare(byte[] minPacked, byte[] maxPacked) {
boolean crosses = false;