Hide the internal data structure of HeapPointWriter (#12762)

This commit is contained in:
Ignacio Vera 2023-11-24 13:58:39 +01:00 committed by GitHub
parent f460d612b5
commit 74085cd1b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 167 additions and 182 deletions

View File

@ -363,6 +363,8 @@ Other
overflows and slices that are too large. Some bits of code are simplified. Documentation is updated and expanded. overflows and slices that are too large. Some bits of code are simplified. Documentation is updated and expanded.
(Stefan Vodita) (Stefan Vodita)
* GITHUB#12762: Refactor BKD HeapPointWriter to hide the internal data structure. (Ignacio Vera)
======================== Lucene 9.8.0 ======================= ======================== Lucene 9.8.0 =======================
API Changes API Changes

View File

@ -19,8 +19,6 @@ package org.apache.lucene.util.bkd;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.ArrayUtil.ByteArrayComparator;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntroSelector; import org.apache.lucene.util.IntroSelector;
import org.apache.lucene.util.IntroSorter; import org.apache.lucene.util.IntroSorter;
@ -181,12 +179,8 @@ public final class BKDRadixSelector {
break; break;
} else { } else {
// Check common prefix and adjust histogram // Check common prefix and adjust histogram
final int startIndex = final int startIndex = Math.min(dimCommonPrefix, config.bytesPerDim);
(dimCommonPrefix > config.bytesPerDim) ? config.bytesPerDim : dimCommonPrefix; final int endIndex = Math.min(commonPrefixPosition, config.bytesPerDim);
final int endIndex =
(commonPrefixPosition > config.bytesPerDim)
? config.bytesPerDim
: commonPrefixPosition;
packedValueDocID = pointValue.packedValueDocIDBytes(); packedValueDocID = pointValue.packedValueDocIDBytes();
int j = int j =
Arrays.mismatch( Arrays.mismatch(
@ -427,24 +421,13 @@ public final class BKDRadixSelector {
@Override @Override
protected int byteAt(int i, int k) { protected int byteAt(int i, int k) {
assert k >= 0 : "negative prefix " + k; assert k >= 0 : "negative prefix " + k;
if (k < dimCmpBytes) { return points.byteAt(i, k < dimCmpBytes ? dimOffset + k : dataOffset + k);
// dim bytes
return points.block[i * config.bytesPerDoc + dimOffset + k] & 0xff;
} else {
// data bytes
return points.block[i * config.bytesPerDoc + dataOffset + k] & 0xff;
}
} }
@Override @Override
protected Selector getFallbackSelector(int d) { protected Selector getFallbackSelector(int d) {
final int skypedBytes = d + commonPrefixLength; final int skypedBytes = d + commonPrefixLength;
final int dimStart = dim * config.bytesPerDim; final int dimStart = dim * config.bytesPerDim;
// data length is composed by the data dimensions plus the docID
final int dataLength =
(config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES;
final ByteArrayComparator dimComparator =
ArrayUtil.getUnsignedComparator(config.bytesPerDim);
return new IntroSelector() { return new IntroSelector() {
@Override @Override
@ -455,61 +438,31 @@ public final class BKDRadixSelector {
@Override @Override
protected void setPivot(int i) { protected void setPivot(int i) {
if (skypedBytes < config.bytesPerDim) { if (skypedBytes < config.bytesPerDim) {
System.arraycopy( points.copyDim(i, dimStart, scratch, 0);
points.block,
i * config.bytesPerDoc + dim * config.bytesPerDim,
scratch,
0,
config.bytesPerDim);
} }
System.arraycopy( points.copyDataDimsAndDoc(i, scratch, config.bytesPerDim);
points.block,
i * config.bytesPerDoc + config.packedIndexBytesLength,
scratch,
config.bytesPerDim,
dataLength);
} }
@Override @Override
protected int compare(int i, int j) { protected int compare(int i, int j) {
if (skypedBytes < config.bytesPerDim) { if (skypedBytes < config.bytesPerDim) {
int iOffset = i * config.bytesPerDoc; int cmp = points.compareDim(i, j, dimStart);
int jOffset = j * config.bytesPerDoc;
int cmp =
dimComparator.compare(
points.block, iOffset + dimStart, points.block, jOffset + dimStart);
if (cmp != 0) { if (cmp != 0) {
return cmp; return cmp;
} }
} }
int iOffset = i * config.bytesPerDoc + config.packedIndexBytesLength; return points.compareDataDimsAndDoc(i, j);
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return Arrays.compareUnsigned(
points.block,
iOffset,
iOffset + dataLength,
points.block,
jOffset,
jOffset + dataLength);
} }
@Override @Override
protected int comparePivot(int j) { protected int comparePivot(int j) {
if (skypedBytes < config.bytesPerDim) { if (skypedBytes < config.bytesPerDim) {
int jOffset = j * config.bytesPerDoc; int cmp = points.compareDim(j, scratch, 0, dimStart);
int cmp = dimComparator.compare(scratch, 0, points.block, jOffset + dimStart);
if (cmp != 0) { if (cmp != 0) {
return cmp; return cmp;
} }
} }
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength; return points.compareDataDimsAndDoc(j, scratch, config.bytesPerDim);
return Arrays.compareUnsigned(
scratch,
config.bytesPerDim,
config.bytesPerDim + dataLength,
points.block,
jOffset,
jOffset + dataLength);
} }
}; };
} }
@ -538,13 +491,7 @@ public final class BKDRadixSelector {
@Override @Override
protected int byteAt(int i, int k) { protected int byteAt(int i, int k) {
assert k >= 0 : "negative prefix " + k; assert k >= 0 : "negative prefix " + k;
if (k < dimCmpBytes) { return points.byteAt(i, k < dimCmpBytes ? dimOffset + k : dataOffset + k);
// dim bytes
return points.block[i * config.bytesPerDoc + dimOffset + k] & 0xff;
} else {
// data bytes
return points.block[i * config.bytesPerDoc + dataOffset + k] & 0xff;
}
} }
@Override @Override
@ -556,11 +503,6 @@ public final class BKDRadixSelector {
protected Sorter getFallbackSorter(int k) { protected Sorter getFallbackSorter(int k) {
final int skypedBytes = k + commonPrefixLength; final int skypedBytes = k + commonPrefixLength;
final int dimStart = dim * config.bytesPerDim; final int dimStart = dim * config.bytesPerDim;
// data length is composed by the data dimensions plus the docID
final int dataLength =
(config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES;
final ByteArrayComparator dimComparator =
ArrayUtil.getUnsignedComparator(config.bytesPerDim);
return new IntroSorter() { return new IntroSorter() {
@Override @Override
@ -571,61 +513,31 @@ public final class BKDRadixSelector {
@Override @Override
protected void setPivot(int i) { protected void setPivot(int i) {
if (skypedBytes < config.bytesPerDim) { if (skypedBytes < config.bytesPerDim) {
System.arraycopy( points.copyDim(i, dimStart, scratch, 0);
points.block,
i * config.bytesPerDoc + dim * config.bytesPerDim,
scratch,
0,
config.bytesPerDim);
} }
System.arraycopy( points.copyDataDimsAndDoc(i, scratch, config.bytesPerDim);
points.block,
i * config.bytesPerDoc + config.packedIndexBytesLength,
scratch,
config.bytesPerDim,
dataLength);
} }
@Override @Override
protected int compare(int i, int j) { protected int compare(int i, int j) {
if (skypedBytes < config.bytesPerDim) { if (skypedBytes < config.bytesPerDim) {
int iOffset = i * config.bytesPerDoc; final int cmp = points.compareDim(i, j, dimStart);
int jOffset = j * config.bytesPerDoc;
int cmp =
dimComparator.compare(
points.block, iOffset + dimStart, points.block, jOffset + dimStart);
if (cmp != 0) { if (cmp != 0) {
return cmp; return cmp;
} }
} }
int iOffset = i * config.bytesPerDoc + config.packedIndexBytesLength; return points.compareDataDimsAndDoc(i, j);
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return Arrays.compareUnsigned(
points.block,
iOffset,
iOffset + dataLength,
points.block,
jOffset,
jOffset + dataLength);
} }
@Override @Override
protected int comparePivot(int j) { protected int comparePivot(int j) {
if (skypedBytes < config.bytesPerDim) { if (skypedBytes < config.bytesPerDim) {
int jOffset = j * config.bytesPerDoc; int cmp = points.compareDim(j, scratch, 0, dimStart);
int cmp = dimComparator.compare(scratch, 0, points.block, jOffset + dimStart);
if (cmp != 0) { if (cmp != 0) {
return cmp; return cmp;
} }
} }
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength; return points.compareDataDimsAndDoc(j, scratch, config.bytesPerDim);
return Arrays.compareUnsigned(
scratch,
config.bytesPerDim,
config.bytesPerDim + dataLength,
points.block,
jOffset,
jOffset + dataLength);
} }
}; };
} }

View File

@ -16,8 +16,7 @@
*/ */
package org.apache.lucene.util.bkd; package org.apache.lucene.util.bkd;
import org.apache.lucene.util.BitUtil; import java.util.function.IntFunction;
import org.apache.lucene.util.BytesRef;
/** /**
* Utility class to read buffered points from in-heap arrays. * Utility class to read buffered points from in-heap arrays.
@ -26,22 +25,13 @@ import org.apache.lucene.util.BytesRef;
*/ */
public final class HeapPointReader implements PointReader { public final class HeapPointReader implements PointReader {
private int curRead; private int curRead;
final byte[] block; private final int end;
final BKDConfig config; private final IntFunction<PointValue> points;
final int end;
private final HeapPointValue pointValue;
public HeapPointReader(BKDConfig config, byte[] block, int start, int end) { HeapPointReader(IntFunction<PointValue> points, int start, int end) {
this.block = block;
curRead = start - 1; curRead = start - 1;
this.end = end; this.end = end;
this.config = config; this.points = points;
if (start < end) {
this.pointValue = new HeapPointValue(config, block);
} else {
// no values
this.pointValue = null;
}
} }
@Override @Override
@ -52,46 +42,9 @@ public final class HeapPointReader implements PointReader {
@Override @Override
public PointValue pointValue() { public PointValue pointValue() {
pointValue.setOffset(curRead * config.bytesPerDoc); return points.apply(curRead);
return pointValue;
} }
@Override @Override
public void close() {} public void close() {}
/** Reusable implementation for a point value on-heap */
static class HeapPointValue implements PointValue {
final BytesRef packedValue;
final BytesRef packedValueDocID;
final int packedValueLength;
HeapPointValue(BKDConfig config, byte[] value) {
this.packedValueLength = config.packedBytesLength;
this.packedValue = new BytesRef(value, 0, packedValueLength);
this.packedValueDocID = new BytesRef(value, 0, config.bytesPerDoc);
}
/** Sets a new value by changing the offset. */
public void setOffset(int offset) {
packedValue.offset = offset;
packedValueDocID.offset = offset;
}
@Override
public BytesRef packedValue() {
return packedValue;
}
@Override
public int docID() {
int position = packedValueDocID.offset + packedValueLength;
return (int) BitUtil.VH_BE_INT.get(packedValueDocID.bytes, position);
}
@Override
public BytesRef packedValueDocIDBytes() {
return packedValueDocID;
}
}
} }

View File

@ -17,6 +17,7 @@
package org.apache.lucene.util.bkd; package org.apache.lucene.util.bkd;
import java.util.Arrays; import java.util.Arrays;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
@ -26,22 +27,26 @@ import org.apache.lucene.util.BytesRef;
* @lucene.internal * @lucene.internal
*/ */
public final class HeapPointWriter implements PointWriter { public final class HeapPointWriter implements PointWriter {
public final byte[] block; private final byte[] block;
final int size; final int size;
final BKDConfig config; private final BKDConfig config;
private final byte[] scratch; private final byte[] scratch;
private final ArrayUtil.ByteArrayComparator dimComparator;
// length is composed by the data dimensions plus the docID
private final int dataDimsAndDocLength;
private int nextWrite; private int nextWrite;
private boolean closed; private boolean closed;
private final HeapPointValue pointValue;
private HeapPointReader.HeapPointValue pointValue;
public HeapPointWriter(BKDConfig config, int size) { public HeapPointWriter(BKDConfig config, int size) {
this.config = config; this.config = config;
this.block = new byte[config.bytesPerDoc * size]; this.block = new byte[config.bytesPerDoc * size];
this.size = size; this.size = size;
this.dimComparator = ArrayUtil.getUnsignedComparator(config.bytesPerDim);
this.dataDimsAndDocLength = config.bytesPerDoc - config.packedIndexBytesLength;
this.scratch = new byte[config.bytesPerDoc]; this.scratch = new byte[config.bytesPerDoc];
if (size > 0) { if (size > 0) {
pointValue = new HeapPointReader.HeapPointValue(config, block); pointValue = new HeapPointValue(config, block);
} else { } else {
// no values // no values
pointValue = null; pointValue = null;
@ -65,10 +70,9 @@ public final class HeapPointWriter implements PointWriter {
+ packedValue.length + packedValue.length
+ "]"; + "]";
assert nextWrite < size : "nextWrite=" + (nextWrite + 1) + " vs size=" + size; assert nextWrite < size : "nextWrite=" + (nextWrite + 1) + " vs size=" + size;
System.arraycopy( final int position = nextWrite * config.bytesPerDoc;
packedValue, 0, block, nextWrite * config.bytesPerDoc, config.packedBytesLength); System.arraycopy(packedValue, 0, block, position, config.packedBytesLength);
int position = nextWrite * config.bytesPerDoc + config.packedBytesLength; BitUtil.VH_BE_INT.set(block, position + config.packedBytesLength, docID);
BitUtil.VH_BE_INT.set(block, position, docID);
nextWrite++; nextWrite++;
} }
@ -76,27 +80,23 @@ public final class HeapPointWriter implements PointWriter {
public void append(PointValue pointValue) { public void append(PointValue pointValue) {
assert closed == false : "point writer is already closed"; assert closed == false : "point writer is already closed";
assert nextWrite < size : "nextWrite=" + (nextWrite + 1) + " vs size=" + size; assert nextWrite < size : "nextWrite=" + (nextWrite + 1) + " vs size=" + size;
BytesRef packedValueDocID = pointValue.packedValueDocIDBytes(); final BytesRef packedValueDocID = pointValue.packedValueDocIDBytes();
assert packedValueDocID.length == config.bytesPerDoc assert packedValueDocID.length == config.bytesPerDoc
: "[packedValue] must have length [" : "[packedValue] must have length ["
+ (config.bytesPerDoc) + (config.bytesPerDoc)
+ "] but was [" + "] but was ["
+ packedValueDocID.length + packedValueDocID.length
+ "]"; + "]";
final int position = nextWrite * config.bytesPerDoc;
System.arraycopy( System.arraycopy(
packedValueDocID.bytes, packedValueDocID.bytes, packedValueDocID.offset, block, position, config.bytesPerDoc);
packedValueDocID.offset,
block,
nextWrite * config.bytesPerDoc,
config.bytesPerDoc);
nextWrite++; nextWrite++;
} }
public void swap(int i, int j) { /** Swaps the point at position {@code i} with the point at position {@code j} */
void swap(int i, int j) {
int indexI = i * config.bytesPerDoc; final int indexI = i * config.bytesPerDoc;
int indexJ = j * config.bytesPerDoc; final int indexJ = j * config.bytesPerDoc;
// scratch1 = values[i] // scratch1 = values[i]
System.arraycopy(block, indexI, scratch, 0, config.bytesPerDoc); System.arraycopy(block, indexI, scratch, 0, config.bytesPerDoc);
// values[i] = values[j] // values[i] = values[j]
@ -105,19 +105,100 @@ public final class HeapPointWriter implements PointWriter {
System.arraycopy(scratch, 0, block, indexJ, config.bytesPerDoc); System.arraycopy(scratch, 0, block, indexJ, config.bytesPerDoc);
} }
/** Return the byte at position {@code k} of the point at position {@code i} */
int byteAt(int i, int k) {
return block[i * config.bytesPerDoc + k] & 0xff;
}
/**
* Copy the dimension that starts at byte offset {@code dim} of the point at position {@code i}
* into the provided {@code bytes} at the given offset
*/
void copyDim(int i, int dim, byte[] bytes, int offset) {
System.arraycopy(block, i * config.bytesPerDoc + dim, bytes, offset, config.bytesPerDim);
}
/**
* Copy the data dimensions and doc value of the point at position {@code i} in the provided
* {@code bytes} at the given offset
*/
void copyDataDimsAndDoc(int i, byte[] bytes, int offset) {
System.arraycopy(
block,
i * config.bytesPerDoc + config.packedIndexBytesLength,
bytes,
offset,
dataDimsAndDocLength);
}
/**
* Compares the dimension {@code dim} value of the point at position {@code i} with the point at
* position {@code j}
*/
int compareDim(int i, int j, int dim) {
final int iOffset = i * config.bytesPerDoc + dim;
final int jOffset = j * config.bytesPerDoc + dim;
return compareDim(block, iOffset, block, jOffset);
}
/**
* Compares the dimension {@code dim} value of the point at position {@code j} with the provided
* value
*/
int compareDim(int j, byte[] dimValue, int offset, int dim) {
final int jOffset = j * config.bytesPerDoc + dim;
return compareDim(dimValue, offset, block, jOffset);
}
private int compareDim(byte[] blockI, int offsetI, byte[] blockJ, int offsetJ) {
return dimComparator.compare(blockI, offsetI, blockJ, offsetJ);
}
/**
* Compares the data dimensions and doc values of the point at position {@code i} with the point
* at position {@code j}
*/
int compareDataDimsAndDoc(int i, int j) {
final int iOffset = i * config.bytesPerDoc + config.packedIndexBytesLength;
final int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return compareDataDimsAndDoc(block, iOffset, block, jOffset);
}
/**
* Compares the data dimensions and doc values of the point at position {@code j} with the
* provided value
*/
int compareDataDimsAndDoc(int j, byte[] dataDimsAndDocs, int offset) {
final int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return compareDataDimsAndDoc(dataDimsAndDocs, offset, block, jOffset);
}
private int compareDataDimsAndDoc(byte[] blockI, int offsetI, byte[] blockJ, int offsetJ) {
return Arrays.compareUnsigned(
blockI,
offsetI,
offsetI + dataDimsAndDocLength,
blockJ,
offsetJ,
offsetJ + dataDimsAndDocLength);
}
/** Computes the cardinality of the points between {@code from} and {@code to} */
public int computeCardinality(int from, int to, int[] commonPrefixLengths) { public int computeCardinality(int from, int to, int[] commonPrefixLengths) {
int leafCardinality = 1; int leafCardinality = 1;
for (int i = from + 1; i < to; i++) { for (int i = from + 1; i < to; i++) {
final int pointOffset = (i - 1) * config.bytesPerDoc;
final int nextPointOffset = pointOffset + config.bytesPerDoc;
for (int dim = 0; dim < config.numDims; dim++) { for (int dim = 0; dim < config.numDims; dim++) {
final int start = dim * config.bytesPerDim + commonPrefixLengths[dim]; final int start = dim * config.bytesPerDim + commonPrefixLengths[dim];
final int end = dim * config.bytesPerDim + config.bytesPerDim; final int end = dim * config.bytesPerDim + config.bytesPerDim;
if (Arrays.mismatch( if (Arrays.mismatch(
block, block,
i * config.bytesPerDoc + start, nextPointOffset + start,
i * config.bytesPerDoc + end, nextPointOffset + end,
block, block,
(i - 1) * config.bytesPerDoc + start, pointOffset + start,
(i - 1) * config.bytesPerDoc + end) pointOffset + end)
!= -1) { != -1) {
leafCardinality++; leafCardinality++;
break; break;
@ -139,7 +220,8 @@ public final class HeapPointWriter implements PointWriter {
: "start=" + start + " length=" + length + " docIDs.length=" + size; : "start=" + start + " length=" + length + " docIDs.length=" + size;
assert start + length <= nextWrite assert start + length <= nextWrite
: "start=" + start + " length=" + length + " nextWrite=" + nextWrite; : "start=" + start + " length=" + length + " nextWrite=" + nextWrite;
return new HeapPointReader(config, block, (int) start, Math.toIntExact(start + length)); return new HeapPointReader(
this::getPackedValueSlice, (int) start, Math.toIntExact(start + length));
} }
@Override @Override
@ -154,4 +236,40 @@ public final class HeapPointWriter implements PointWriter {
public String toString() { public String toString() {
return "HeapPointWriter(count=" + nextWrite + " size=" + size + ")"; return "HeapPointWriter(count=" + nextWrite + " size=" + size + ")";
} }
/** Reusable implementation for a point value on-heap */
private static class HeapPointValue implements PointValue {
private final BytesRef packedValue;
private final BytesRef packedValueDocID;
private final int packedValueLength;
HeapPointValue(BKDConfig config, byte[] value) {
this.packedValueLength = config.packedBytesLength;
this.packedValue = new BytesRef(value, 0, packedValueLength);
this.packedValueDocID = new BytesRef(value, 0, config.bytesPerDoc);
}
/** Sets a new value by changing the offset. */
void setOffset(int offset) {
packedValue.offset = offset;
packedValueDocID.offset = offset;
}
@Override
public BytesRef packedValue() {
return packedValue;
}
@Override
public int docID() {
int position = packedValueDocID.offset + packedValueLength;
return (int) BitUtil.VH_BE_INT.get(packedValueDocID.bytes, position);
}
@Override
public BytesRef packedValueDocIDBytes() {
return packedValueDocID;
}
}
} }