Move vector search from IndexInput to RandomAccessInput (#13938)

This commit is contained in:
Anh Dung Bui 2024-11-08 17:31:28 +09:00
parent 494b16063e
commit b97aadb925
14 changed files with 119 additions and 50 deletions

View File

@ -63,8 +63,7 @@ final class EndiannessReverserIndexInput extends FilterIndexInput {
public void readFloats(float[] dst, int offset, int length) throws IOException {
in.readFloats(dst, offset, length);
for (int i = 0; i < length; ++i) {
dst[offset + i] =
Float.intBitsToFloat(Integer.reverseBytes(Float.floatToRawIntBits(dst[offset + i])));
dst[offset + i] = revertFloat(dst[offset + i]);
}
}
@ -106,6 +105,14 @@ final class EndiannessReverserIndexInput extends FilterIndexInput {
return in.readByte(pos);
}
/**
 * Reads {@code length} floats at position {@code pos} through the wrapped input into {@code
 * floats}, starting at index {@code offset}, and then reverses the byte order of every float
 * that was just read.
 */
@Override
public void readFloats(long pos, float[] floats, int offset, int length) throws IOException {
in.readFloats(pos, floats, offset, length);
// Only the slots filled by this call, [offset, offset + length), are byte-swapped.
for (int i = 0; i < length; ++i) {
floats[offset + i] = revertFloat(floats[offset + i]);
}
}
@Override
public short readShort(long pos) throws IOException {
return Short.reverseBytes(in.readShort(pos));
@ -120,5 +127,19 @@ final class EndiannessReverserIndexInput extends FilterIndexInput {
public long readLong(long pos) throws IOException {
// Fetch the long at pos from the wrapped input and hand it back with its bytes reversed.
return Long.reverseBytes(in.readLong(pos));
}
/**
 * Returns a copy of this input produced by {@code super.clone()}.
 *
 * <p>NOTE(review): if the superclass is {@code Object}, this is a shallow copy and the wrapped
 * {@code in} is shared between the original and the clone -- confirm that is intended.
 */
@Override
public Object clone() {
try {
return super.clone();
} catch (CloneNotSupportedException e) {
// Treated as impossible; surfaced as an Error that keeps the original exception as cause.
throw new Error(
"This cannot happen: Failing to clone EndiannessReverserRandomAccessInput", e);
}
}
}
/**
 * Returns the float whose 32-bit pattern is the byte-reversed bit pattern of {@code value}.
 *
 * <p>The raw bits are taken with {@link Float#floatToRawIntBits(float)}, byte-swapped with
 * {@link Integer#reverseBytes(int)}, and reinterpreted with {@link Float#intBitsToFloat(int)}.
 */
private static float revertFloat(float value) {
return Float.intBitsToFloat(Integer.reverseBytes(Float.floatToRawIntBits(value)));
}
}

View File

@ -16,14 +16,14 @@
*/
package org.apache.lucene.codecs.lucene95;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
/**
* Implementors can return the IndexInput from which their values are read. For use by vector
* Implementors can return the RandomAccessInput from which their values are read. For use by vector
* quantizers.
*/
public interface HasIndexSlice {
/** Returns an IndexInput from which to read this instance's values. */
IndexInput getSlice();
/** Returns a RandomAccessInput from which to read this instance's values. */
RandomAccessInput getSlice();
}

View File

@ -37,7 +37,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement
protected final int dimension;
protected final int size;
protected final IndexInput slice;
protected final RandomAccessInput slice;
protected int lastOrd = -1;
protected final byte[] binaryValue;
protected final ByteBuffer byteBuffer;
@ -48,7 +48,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement
OffHeapByteVectorValues(
int dimension,
int size,
IndexInput slice,
RandomAccessInput slice,
int byteSize,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction similarityFunction) {
@ -82,13 +82,13 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement
}
@Override
public IndexInput getSlice() {
public RandomAccessInput getSlice() {
return slice;
}
private void readValue(int targetOrd) throws IOException {
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
slice.readBytes(
(long) targetOrd * byteSize, byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
public static OffHeapByteVectorValues load(
@ -104,7 +104,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement
if (configuration.isEmpty() || vectorEncoding != VectorEncoding.BYTE) {
return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction);
}
IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
RandomAccessInput bytesSlice = vectorData.randomAccessSlice(vectorDataOffset, vectorDataLength);
if (configuration.isDense()) {
return new DenseOffHeapVectorValues(
dimension,
@ -133,7 +133,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement
public DenseOffHeapVectorValues(
int dimension,
int size,
IndexInput slice,
RandomAccessInput slice,
int byteSize,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction vectorSimilarityFunction) {
@ -143,7 +143,12 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(
dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
dimension,
size,
(RandomAccessInput) slice.clone(),
byteSize,
flatVectorsScorer,
similarityFunction);
}
@Override
@ -186,7 +191,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement
public SparseOffHeapVectorValues(
OrdToDocDISIReaderConfiguration configuration,
IndexInput dataIn,
IndexInput slice,
RandomAccessInput slice,
int dimension,
int byteSize,
FlatVectorsScorer flatVectorsScorer,
@ -220,7 +225,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement
return new SparseOffHeapVectorValues(
configuration,
dataIn,
slice.clone(),
(RandomAccessInput) slice.clone(),
dimension,
byteSize,
flatVectorsScorer,

View File

@ -36,7 +36,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme
protected final int dimension;
protected final int size;
protected final IndexInput slice;
protected final RandomAccessInput slice;
protected final int byteSize;
protected int lastOrd = -1;
protected final float[] value;
@ -46,7 +46,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme
OffHeapFloatVectorValues(
int dimension,
int size,
IndexInput slice,
RandomAccessInput slice,
int byteSize,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction similarityFunction) {
@ -70,7 +70,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme
}
@Override
public IndexInput getSlice() {
public RandomAccessInput getSlice() {
return slice;
}
@ -79,8 +79,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme
if (lastOrd == targetOrd) {
return value;
}
slice.seek((long) targetOrd * byteSize);
slice.readFloats(value, 0, value.length);
slice.readFloats((long) targetOrd * byteSize, value, 0, value.length);
lastOrd = targetOrd;
return value;
}
@ -98,7 +97,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme
if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.FLOAT32) {
return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction);
}
IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
RandomAccessInput bytesSlice = vectorData.randomAccessSlice(vectorDataOffset, vectorDataLength);
int byteSize = dimension * Float.BYTES;
if (configuration.docsWithFieldOffset == -1) {
return new DenseOffHeapVectorValues(
@ -129,7 +128,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme
public DenseOffHeapVectorValues(
int dimension,
int size,
IndexInput slice,
RandomAccessInput slice,
int byteSize,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction similarityFunction) {
@ -139,7 +138,12 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(
dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
dimension,
size,
(RandomAccessInput) slice.clone(),
byteSize,
flatVectorsScorer,
similarityFunction);
}
@Override
@ -187,7 +191,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme
public SparseOffHeapVectorValues(
OrdToDocDISIReaderConfiguration configuration,
IndexInput dataIn,
IndexInput slice,
RandomAccessInput slice,
int dimension,
int byteSize,
FlatVectorsScorer flatVectorsScorer,
@ -215,7 +219,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme
return new SparseOffHeapVectorValues(
configuration,
dataIn,
slice.clone(),
(RandomAccessInput) slice.clone(),
dimension,
byteSize,
flatVectorsScorer,

View File

@ -213,8 +213,8 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
public float score(int vectorOrdinal) throws IOException {
// get compressed vector, in Lucene99, vector values are stored and have a single value for
// offset correction
values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES));
values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
long pos = (long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES);
values.getSlice().readBytes(pos, compressedVector, 0, compressedVector.length);
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;

View File

@ -26,6 +26,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
@ -46,7 +47,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
protected final FlatVectorsScorer vectorsScorer;
protected final boolean compress;
protected final IndexInput slice;
protected final RandomAccessInput slice;
protected final byte[] binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
@ -93,7 +94,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
VectorSimilarityFunction similarityFunction,
FlatVectorsScorer vectorsScorer,
boolean compress,
IndexInput slice) {
RandomAccessInput slice) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
@ -131,9 +132,9 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
if (lastOrd == targetOrd) {
return binaryValue;
}
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes);
slice.readFloats(scoreCorrectionConstant, 0, 1);
long pos = (long) targetOrd * byteSize;
slice.readBytes(pos, byteBuffer.array(), byteBuffer.arrayOffset(), numBytes);
slice.readFloats(pos + numBytes, scoreCorrectionConstant, 0, 1);
decompressBytes(binaryValue, numBytes);
lastOrd = targetOrd;
return binaryValue;
@ -144,13 +145,12 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
if (lastOrd == targetOrd) {
return scoreCorrectionConstant[0];
}
slice.seek(((long) targetOrd * byteSize) + numBytes);
slice.readFloats(scoreCorrectionConstant, 0, 1);
slice.readFloats(((long) targetOrd * byteSize) + numBytes, scoreCorrectionConstant, 0, 1);
return scoreCorrectionConstant[0];
}
@Override
public IndexInput getSlice() {
public RandomAccessInput getSlice() {
return slice;
}
@ -174,9 +174,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
if (configuration.isEmpty()) {
return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer);
}
IndexInput bytesSlice =
vectorData.slice(
"quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
RandomAccessInput bytesSlice =
vectorData.randomAccessSlice(quantizedVectorDataOffset, quantizedVectorDataLength);
if (configuration.isDense()) {
return new DenseOffHeapVectorValues(
dimension,
@ -213,7 +212,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
boolean compress,
VectorSimilarityFunction similarityFunction,
FlatVectorsScorer vectorsScorer,
IndexInput slice) {
RandomAccessInput slice) {
super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice);
}
@ -226,7 +225,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
compress,
similarityFunction,
vectorsScorer,
slice.clone());
(RandomAccessInput) slice.clone());
}
@Override
@ -275,7 +274,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
IndexInput dataIn,
VectorSimilarityFunction similarityFunction,
FlatVectorsScorer vectorsScorer,
IndexInput slice)
RandomAccessInput slice)
throws IOException {
super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice);
this.configuration = configuration;
@ -300,7 +299,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
dataIn,
similarityFunction,
vectorsScorer,
slice.clone());
(RandomAccessInput) slice.clone());
}
@Override

View File

@ -258,6 +258,12 @@ public abstract class BufferedIndexInput extends IndexInput implements RandomAcc
return buffer.get((int) index);
}
/**
 * Positional variant of {@code readFloats}: seeks to {@code pos} and then delegates to the
 * sequential {@code readFloats(dst, offset, len)}.
 *
 * <p>NOTE(review): because this is implemented with {@code seek(pos)}, the call moves this
 * input's current position -- confirm no caller relies on the position being preserved.
 */
@Override
public void readFloats(long pos, float[] dst, int offset, int len) throws IOException {
seek(pos);
readFloats(dst, offset, len);
}
@Override
public void readBytes(long pos, byte[] bytes, int offset, int len) throws IOException {
if (len <= bufferSize) {

View File

@ -261,6 +261,12 @@ public final class ByteBuffersDataInput extends DataInput
}
}
/**
 * Positional variant of {@code readFloats}: seeks to {@code pos} and then delegates to the
 * sequential {@code readFloats(floats, offset, length)}.
 *
 * <p>NOTE(review): because this is implemented with {@code seek(pos)}, the call moves this
 * input's current position -- confirm no caller relies on the position being preserved.
 */
@Override
public void readFloats(long pos, float[] floats, int offset, int length) throws IOException {
seek(pos);
readFloats(floats, offset, length);
}
@Override
public short readShort(long pos) {
long absPos = offset + pos;

View File

@ -175,6 +175,12 @@ public final class ByteBuffersIndexInput extends IndexInput implements RandomAcc
in.readBytes(pos, bytes, offset, length);
}
/**
 * Calls {@code ensureOpen()} and then forwards the positional float read to the wrapped {@code
 * in}, passing all arguments through unchanged.
 */
@Override
public void readFloats(long pos, float[] floats, int offset, int length) throws IOException {
ensureOpen();
in.readFloats(pos, floats, offset, length);
}
@Override
public short readShort(long pos) throws IOException {
ensureOpen();

View File

@ -184,6 +184,13 @@ public abstract class IndexInput extends DataInput implements Closeable {
slice.readBytes(bytes, offset, length);
}
/**
 * Positional float read implemented on top of the sequential API of {@code slice}: seeks the
 * slice to {@code pos} and then reads {@code length} floats into {@code floats} starting at
 * {@code offset}. The seek moves the position of {@code slice}.
 */
@Override
public void readFloats(long pos, float[] floats, int offset, int length)
throws IOException {
slice.seek(pos);
slice.readFloats(floats, offset, length);
}
@Override
public short readShort(long pos) throws IOException {
slice.seek(pos);

View File

@ -23,7 +23,7 @@ import org.apache.lucene.util.BitUtil; // javadocs
* Random Access Index API. Unlike {@link IndexInput}, this has no concept of file position, all
* reads are absolute. However, like IndexInput, it is only intended for use by a single thread.
*/
public interface RandomAccessInput {
public interface RandomAccessInput extends Cloneable {
/** The number of bytes in the file. */
long length();
@ -47,6 +47,14 @@ public interface RandomAccessInput {
}
}
/**
 * Reads a specified number of floats starting at a given position into an array at the specified
 * offset.
 *
 * @param pos the position in the input at which to start reading
 * @param floats the array to read the floats into
 * @param offset the index in {@code floats} at which the first float is stored
 * @param length the number of floats to read
 * @throws IOException if the read fails
 * @see DataInput#readFloats
 */
void readFloats(long pos, float[] floats, int offset, int length) throws IOException;
/**
* Reads a short (LE byte order) at the given position in the file
*
@ -77,4 +85,6 @@ public interface RandomAccessInput {
* @see IndexInput#prefetch
*/
default void prefetch(long offset, long length) throws IOException {}
/**
 * Returns a clone of this input. The declared return type is {@code Object}, so callers that
 * need a {@code RandomAccessInput} must cast the result.
 */
Object clone();
}

View File

@ -20,7 +20,7 @@ import java.io.IOException;
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
/**
* A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for
@ -52,7 +52,7 @@ public abstract class QuantizedByteVectorValues extends ByteVectorValues impleme
}
@Override
public IndexInput getSlice() {
public RandomAccessInput getSlice() {
return null;
}
}

View File

@ -43,6 +43,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
@ -133,8 +134,12 @@ public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
}
@Override
public IndexInput getSlice() {
return in;
public RandomAccessInput getSlice() {
try {
return in.randomAccessSlice(0, in.length());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override

View File

@ -22,7 +22,7 @@ import java.util.Random;
import java.util.function.IntUnaryOperator;
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
/** A reader of vector values that samples a subset of the vectors. */
@ -53,7 +53,7 @@ public class SampleReader extends FloatVectorValues implements HasIndexSlice {
}
@Override
public IndexInput getSlice() {
public RandomAccessInput getSlice() {
return ((HasIndexSlice) origin).getSlice();
}