Replace BytesRef usages in byte vectors API with byte[] (#12102)

The main classes involved are ByteVectorValues, KnnByteVectorField and KnnByteVectorQuery. It becomes quite natural to simplify things further and use byte[] in the following methods too: ByteVectorValues#vectorValue, KnnVectorReader#search, LeafReader#searchNearestVectors, HNSWGraphSearcher#search, VectorSimilarityFunction#compare, VectorUtil#cosine, VectorUtil#squareDistance, VectorUtil#dotProduct, VectorUtil#dotProductScore
This commit is contained in:
Luca Cavanna 2023-01-23 22:06:00 +01:00 committed by GitHub
parent f8ee852696
commit 4594400216
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
52 changed files with 262 additions and 380 deletions

View File

@ -41,7 +41,6 @@ import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@ -279,7 +278,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -42,7 +42,6 @@ import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@ -270,7 +269,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -41,7 +41,6 @@ import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@ -266,7 +265,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -41,7 +41,6 @@ import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@ -302,7 +301,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

View File

@ -25,18 +25,17 @@ import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues<BytesRef> {
implements RandomAccessVectorValues<byte[]> {
protected final int dimension;
protected final int size;
protected final IndexInput slice;
protected final BytesRef binaryValue;
protected final byte[] binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
@ -46,7 +45,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
this.slice = slice;
this.byteSize = byteSize;
byteBuffer = ByteBuffer.allocate(byteSize);
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
binaryValue = byteBuffer.array();
}
@Override
@ -60,7 +59,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public BytesRef vectorValue(int targetOrd) throws IOException {
public byte[] vectorValue(int targetOrd) throws IOException {
readValue(targetOrd);
return binaryValue;
}
@ -99,7 +98,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public BytesRef vectorValue() throws IOException {
public byte[] vectorValue() throws IOException {
slice.seek((long) doc * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
return binaryValue;
@ -125,7 +124,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
public RandomAccessVectorValues<byte[]> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@ -171,7 +170,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public BytesRef vectorValue() throws IOException {
public byte[] vectorValue() throws IOException {
slice.seek((long) (disi.index()) * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
return binaryValue;
@ -194,7 +193,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
public RandomAccessVectorValues<byte[]> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@ -241,7 +240,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public BytesRef vectorValue() throws IOException {
public byte[] vectorValue() throws IOException {
throw new UnsupportedOperationException();
}
@ -261,12 +260,12 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
public RandomAccessVectorValues<byte[]> copy() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public BytesRef vectorValue(int targetOrd) throws IOException {
public byte[] vectorValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -185,7 +185,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
// write vector
float[] vectorValue = vectors.vectorValue();
binaryVector.asFloatBuffer().put(vectorValue);
output.writeBytes(binaryVector.array(), 0, binaryVector.limit());
output.writeBytes(binaryVector.array(), binaryVector.limit());
docIds[count] = docV;
}

View File

@ -179,7 +179,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
// write vector
float[] vectorValue = vectors.vectorValue();
binaryVector.asFloatBuffer().put(vectorValue);
output.writeBytes(binaryVector.array(), 0, binaryVector.limit());
output.writeBytes(binaryVector.array(), binaryVector.limit());
docsWithField.add(docV);
}
return docsWithField;

View File

@ -187,7 +187,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
// write vector
float[] vectorValue = vectors.vectorValue();
binaryVector.asFloatBuffer().put(vectorValue);
output.writeBytes(binaryVector.array(), 0, binaryVector.limit());
output.writeBytes(binaryVector.array(), binaryVector.limit());
docsWithField.add(docV);
}
return docsWithField;

View File

@ -43,7 +43,6 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
@ -197,17 +196,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (Object v : fieldData.vectors) {
buffer.asFloatBuffer().put((float[]) v);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
}
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
for (Object v : fieldData.vectors) {
BytesRef vector = (BytesRef) v;
vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
byte[] vector = (byte[]) v;
vectorData.writeBytes(vector, vector.length);
}
}
@ -267,11 +265,10 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (int ordinal : ordMap) {
float[] vector = (float[]) fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
return vectorDataOffset;
}
@ -279,8 +276,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
for (int ordinal : ordMap) {
BytesRef vector = (BytesRef) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector, vector.length);
}
return vectorDataOffset;
}
@ -423,7 +420,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
docsWithField.cardinality(),
vectorDataInput,
byteSize);
HnswGraphBuilder<BytesRef> hnswGraphBuilder =
HnswGraphBuilder<byte[]> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
@ -596,9 +593,9 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
// write vector
BytesRef binaryValue = byteVectorValues.vectorValue();
byte[] binaryValue = byteVectorValues.vectorValue();
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
output.writeBytes(binaryValue, binaryValue.length);
docsWithField.add(docV);
}
return docsWithField;
@ -619,7 +616,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
// write vector
float[] vectorValue = floatVectorValues.vectorValue();
binaryVector.asFloatBuffer().put(vectorValue);
output.writeBytes(binaryVector.array(), 0, binaryVector.limit());
output.writeBytes(binaryVector.array(), binaryVector.limit());
docsWithField.add(docV);
}
return docsWithField;
@ -644,10 +641,10 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
throws IOException {
int dim = fieldInfo.getVectorDimension();
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream) {
@Override
public BytesRef copyValue(BytesRef value) {
return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim));
public byte[] copyValue(byte[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {

View File

@ -222,7 +222,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
ByteVectorValues values = getByteVectorValues(field);
if (target.length != values.dimension()) {
@ -250,7 +250,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
break;
}
BytesRef vector = values.vectorValue();
byte[] vector = values.vectorValue();
float score = vectorSimilarity.compare(vector, target);
topK.insertWithOverflow(new ScoreDoc(doc, score));
numVisited++;
@ -458,9 +458,9 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
@Override
public BytesRef vectorValue() {
public byte[] vectorValue() {
binaryValue.bytes = values[curOrd];
return binaryValue;
return binaryValue.bytes;
}
@Override

View File

@ -30,7 +30,6 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.RamUsageEstimator;
/**
@ -97,7 +96,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
@Override
public TopDocs search(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
};
@ -192,7 +191,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
@Override
public TopDocs search(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}

View File

@ -24,7 +24,6 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NamedSPILoader;
/**
@ -112,7 +111,7 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
@Override
public TopDocs search(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}

View File

@ -27,7 +27,6 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/** Reads vectors from an index. */
public abstract class KnnVectorsReader implements Closeable, Accountable {
@ -117,7 +116,7 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
*/
public abstract TopDocs search(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread
* that called {@link #getMergeInstance()}.

View File

@ -30,7 +30,6 @@ import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BytesRef;
/** Writes vectors to an index. */
public abstract class KnnVectorsWriter implements Accountable, Closeable {
@ -49,8 +48,8 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
switch (fieldInfo.getVectorEncoding()) {
case BYTE:
KnnFieldVectorsWriter<BytesRef> byteWriter =
(KnnFieldVectorsWriter<BytesRef>) addField(fieldInfo);
KnnFieldVectorsWriter<byte[]> byteWriter =
(KnnFieldVectorsWriter<byte[]>) addField(fieldInfo);
ByteVectorValues mergedBytes =
MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
for (int doc = mergedBytes.nextDoc();
@ -262,7 +261,7 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
}
@Override
public BytesRef vectorValue() throws IOException {
public byte[] vectorValue() throws IOException {
return current.values.vectorValue();
}

View File

@ -43,7 +43,6 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@ -310,7 +309,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

View File

@ -186,17 +186,16 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (Object v : fieldData.vectors) {
buffer.asFloatBuffer().put((float[]) v);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
}
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
for (Object v : fieldData.vectors) {
BytesRef vector = (BytesRef) v;
vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
byte[] vector = (byte[]) v;
vectorData.writeBytes(vector, vector.length);
}
}
@ -258,11 +257,10 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (int ordinal : ordMap) {
float[] vector = (float[]) fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
return vectorDataOffset;
}
@ -270,8 +268,8 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
for (int ordinal : ordMap) {
BytesRef vector = (BytesRef) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector, vector.length);
}
return vectorDataOffset;
}
@ -435,7 +433,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
docsWithField.cardinality(),
vectorDataInput,
byteSize);
HnswGraphBuilder<BytesRef> hnswGraphBuilder =
HnswGraphBuilder<byte[]> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
@ -646,9 +644,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
// write vector
BytesRef binaryValue = byteVectorValues.vectorValue();
byte[] binaryValue = byteVectorValues.vectorValue();
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
output.writeBytes(binaryValue, binaryValue.length);
docsWithField.add(docV);
}
return docsWithField;
@ -669,7 +667,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
// write vector
float[] value = floatVectorValues.vectorValue();
buffer.asFloatBuffer().put(value);
output.writeBytes(buffer.array(), 0, buffer.limit());
output.writeBytes(buffer.array(), buffer.limit());
docsWithField.add(docV);
}
return docsWithField;
@ -694,11 +692,10 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
throws IOException {
int dim = fieldInfo.getVectorDimension();
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream) {
@Override
public BytesRef copyValue(BytesRef value) {
return new BytesRef(
ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + dim));
public byte[] copyValue(byte[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {

View File

@ -25,18 +25,17 @@ import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues<BytesRef> {
implements RandomAccessVectorValues<byte[]> {
protected final int dimension;
protected final int size;
protected final IndexInput slice;
protected final BytesRef binaryValue;
protected final byte[] binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
@ -46,7 +45,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
this.slice = slice;
this.byteSize = byteSize;
byteBuffer = ByteBuffer.allocate(byteSize);
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
binaryValue = byteBuffer.array();
}
@Override
@ -60,7 +59,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public BytesRef vectorValue(int targetOrd) throws IOException {
public byte[] vectorValue(int targetOrd) throws IOException {
readValue(targetOrd);
return binaryValue;
}
@ -99,7 +98,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public BytesRef vectorValue() throws IOException {
public byte[] vectorValue() throws IOException {
slice.seek((long) doc * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
return binaryValue;
@ -125,7 +124,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
public RandomAccessVectorValues<byte[]> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@ -171,7 +170,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public BytesRef vectorValue() throws IOException {
public byte[] vectorValue() throws IOException {
slice.seek((long) (disi.index()) * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
return binaryValue;
@ -194,7 +193,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
public RandomAccessVectorValues<byte[]> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@ -241,7 +240,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public BytesRef vectorValue() throws IOException {
public byte[] vectorValue() throws IOException {
throw new UnsupportedOperationException();
}
@ -261,12 +260,12 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
public RandomAccessVectorValues<byte[]> copy() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public BytesRef vectorValue(int targetOrd) throws IOException {
public byte[] vectorValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -36,7 +36,6 @@ import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
/**
@ -273,7 +272,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
return fields.get(field).search(field, target, k, acceptDocs, visitedLimit);
}

View File

@ -22,7 +22,6 @@ import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
/**
* A field that contains a single byte numeric vector (or none) for each document. Vectors are dense
@ -38,7 +37,7 @@ import org.apache.lucene.util.BytesRef;
*/
public class KnnByteVectorField extends Field {
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
private static FieldType createType(byte[] v, VectorSimilarityFunction similarityFunction) {
if (v == null) {
throw new IllegalArgumentException("vector value must not be null");
}
@ -67,7 +66,7 @@ public class KnnByteVectorField extends Field {
* @param k The number of nearest neighbors to gather
* @return A new vector query
*/
public static Query newVectorQuery(String field, BytesRef queryVector, int k) {
public static Query newVectorQuery(String field, byte[] queryVector, int k) {
return new KnnByteVectorQuery(field, queryVector, k);
}
@ -99,7 +98,7 @@ public class KnnByteVectorField extends Field {
* dimension &gt; 1024.
*/
public KnnByteVectorField(
String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
}
@ -114,7 +113,7 @@ public class KnnByteVectorField extends Field {
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension &gt; 1024.
*/
public KnnByteVectorField(String name, BytesRef vector) {
public KnnByteVectorField(String name, byte[] vector) {
this(name, vector, VectorSimilarityFunction.EUCLIDEAN);
}
@ -128,7 +127,7 @@ public class KnnByteVectorField extends Field {
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension &gt; 1024.
*/
public KnnByteVectorField(String name, BytesRef vector, FieldType fieldType) {
public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) {
super(name, fieldType);
if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
@ -141,8 +140,8 @@ public class KnnByteVectorField extends Field {
}
/** Return the vector value of this field */
public BytesRef vectorValue() {
return (BytesRef) fieldsData;
public byte[] vectorValue() {
return (byte[]) fieldsData;
}
/**
@ -150,7 +149,7 @@ public class KnnByteVectorField extends Field {
*
* @param value the value to set; must not be null, and length must match the field type
*/
public void setVectorValue(BytesRef value) {
public void setVectorValue(byte[] value) {
if (value == null) {
throw new IllegalArgumentException("value must not be null");
}

View File

@ -19,7 +19,6 @@ package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
/**
* This class provides access to per-document floating point vector values indexed as {@link
@ -57,5 +56,5 @@ public abstract class ByteVectorValues extends DocIdSetIterator {
*
* @return the vector value
*/
public abstract BytesRef vectorValue() throws IOException;
public abstract byte[] vectorValue() throws IOException;
}

View File

@ -27,7 +27,6 @@ import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/** LeafReader implemented by codec APIs. */
public abstract class CodecReader extends LeafReader {
@ -257,7 +256,7 @@ public abstract class CodecReader extends LeafReader {
@Override
public final TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {

View File

@ -20,7 +20,6 @@ package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
abstract class DocValuesLeafReader extends LeafReader {
@Override
@ -66,7 +65,7 @@ abstract class DocValuesLeafReader extends LeafReader {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -508,7 +508,7 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
}
@Override
public BytesRef vectorValue() throws IOException {
public byte[] vectorValue() throws IOException {
return vectorValues.vectorValue();
}

View File

@ -364,7 +364,7 @@ public abstract class FilterLeafReader extends LeafReader {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}

View File

@ -961,7 +961,7 @@ final class IndexingChain implements Accountable {
int docID, PerField pf, VectorEncoding vectorEncoding, IndexableField field)
throws IOException {
switch (vectorEncoding) {
case BYTE -> ((KnnFieldVectorsWriter<BytesRef>) pf.knnFieldVectorsWriter)
case BYTE -> ((KnnFieldVectorsWriter<byte[]>) pf.knnFieldVectorsWriter)
.addValue(docID, ((KnnByteVectorField) field).vectorValue());
case FLOAT32 -> ((KnnFieldVectorsWriter<float[]>) pf.knnFieldVectorsWriter)
.addValue(docID, ((KnnVectorField) field).vectorValue());

View File

@ -21,7 +21,6 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/**
* {@code LeafReader} is an abstract class, providing an interface for accessing an index. Search of
@ -270,7 +269,7 @@ public abstract class LeafReader extends IndexReader {
* @lucene.experimental
*/
public abstract TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Get the {@link FieldInfos} describing all fields in this reader.

View File

@ -29,7 +29,6 @@ import java.util.TreeMap;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
/**
@ -428,7 +427,7 @@ public class ParallelLeafReader extends LeafReader {
@Override
public TopDocs searchNearestVectors(
String fieldName, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
String fieldName, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
ensureOpen();
LeafReader reader = fieldToReader.get(fieldName);

View File

@ -30,7 +30,6 @@ import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/**
* Wraps arbitrary readers for merging. Note that this can cause slow and memory-intensive merges.
@ -180,7 +179,7 @@ public final class SlowCodecReaderWrapper {
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}

View File

@ -34,6 +34,7 @@ import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
@ -284,7 +285,9 @@ public final class SortingCodecReader extends FilterCodecReader {
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
int newDocID = sortMap.oldToNew(doc);
docsWithField.set(newDocID);
binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.vectorValue());
binaryVectors[newDocID] =
new BytesRef(
ArrayUtil.copyOfSubArray(delegate.vectorValue(), 0, delegate.vectorValue().length));
}
}
@ -299,8 +302,8 @@ public final class SortingCodecReader extends FilterCodecReader {
}
@Override
public BytesRef vectorValue() throws IOException {
return binaryVectors[docId];
public byte[] vectorValue() throws IOException {
return binaryVectors[docId].bytes;
}
@Override
@ -505,8 +508,7 @@ public final class SortingCodecReader extends FilterCodecReader {
}
@Override
public TopDocs search(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}

View File

@ -21,8 +21,6 @@ import static org.apache.lucene.util.VectorUtil.dotProduct;
import static org.apache.lucene.util.VectorUtil.dotProductScore;
import static org.apache.lucene.util.VectorUtil.squareDistance;
import org.apache.lucene.util.BytesRef;
/**
* Vector similarity function; used in search to return top K most similar vectors to a target
* vector. This is a label describing the method used during indexing and searching of the vectors
@ -38,7 +36,7 @@ public enum VectorSimilarityFunction {
}
@Override
public float compare(BytesRef v1, BytesRef v2) {
public float compare(byte[] v1, byte[] v2) {
return 1 / (1 + squareDistance(v1, v2));
}
},
@ -57,7 +55,7 @@ public enum VectorSimilarityFunction {
}
@Override
public float compare(BytesRef v1, BytesRef v2) {
public float compare(byte[] v1, byte[] v2) {
return dotProductScore(v1, v2);
}
},
@ -75,7 +73,7 @@ public enum VectorSimilarityFunction {
}
@Override
public float compare(BytesRef v1, BytesRef v2) {
public float compare(byte[] v1, byte[] v2) {
return (1 + cosine(v1, v2)) / 2;
}
};
@ -99,5 +97,5 @@ public enum VectorSimilarityFunction {
* @param v2 another vector, of the same dimension
* @return the value of the similarity function applied to the two vectors
*/
public abstract float compare(BytesRef v1, BytesRef v2);
public abstract float compare(byte[] v1, byte[] v2);
}

View File

@ -17,18 +17,19 @@
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/**
* Uses {@link KnnVectorsReader#search(String, BytesRef, int, Bits, int)} to perform nearest
* neighbour search.
* Uses {@link KnnVectorsReader#search(String, byte[], int, Bits, int)} to perform nearest neighbour
* search.
*
* <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
* executes the filter for each leaf, then chooses a strategy dynamically:
@ -43,7 +44,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
private final BytesRef target;
private final byte[] target;
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
@ -54,7 +55,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
* @param k the number of documents to find
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnByteVectorQuery(String field, BytesRef target, int k) {
public KnnByteVectorQuery(String field, byte[] target, int k) {
this(field, target, k, null);
}
@ -68,7 +69,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
* @param filter a filter applied before the vector search
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnByteVectorQuery(String field, BytesRef target, int k, Query filter) {
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, k, filter);
this.target = target;
}
@ -91,14 +92,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
@Override
public String toString(String field) {
return getClass().getSimpleName()
+ ":"
+ this.field
+ "["
+ target.bytes[target.offset]
+ ",...]["
+ k
+ "]";
return getClass().getSimpleName() + ":" + this.field + "[" + target[0] + ",...][" + k + "]";
}
@Override
@ -106,18 +100,18 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
if (this == o) return true;
if (super.equals(o) == false) return false;
KnnByteVectorQuery that = (KnnByteVectorQuery) o;
return Objects.equals(target, that.target);
return Arrays.equals(target, that.target);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), target);
return Objects.hash(super.hashCode(), Arrays.hashCode(target));
}
/**
* @return the target query vector of the search. Each vector element is a byte.
*/
public BytesRef getTargetCopy() {
return BytesRef.deepCopyOf(target);
public byte[] getTargetCopy() {
return ArrayUtil.copyOfSubArray(target, 0, target.length);
}
}

View File

@ -22,7 +22,6 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
/**
* Computes the similarity score between a given query vector and different document vectors. This
@ -46,7 +45,7 @@ abstract class VectorScorer {
return new FloatVectorScorer(values, query, similarity);
}
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query)
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, byte[] query)
throws IOException {
ByteVectorValues values = context.reader().getByteVectorValues(fi.name);
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
@ -63,11 +62,11 @@ abstract class VectorScorer {
abstract boolean advanceExact(int doc) throws IOException;
private static class ByteVectorScorer extends VectorScorer {
private final BytesRef query;
private final byte[] query;
private final ByteVectorValues values;
protected ByteVectorScorer(
ByteVectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
ByteVectorValues values, byte[] query, VectorSimilarityFunction similarity) {
super(similarity);
this.values = values;
this.query = query;

View File

@ -122,16 +122,15 @@ public final class VectorUtil {
}
/** Returns the cosine similarity between the two vectors. */
public static float cosine(BytesRef a, BytesRef b) {
public static float cosine(byte[] a, byte[] b) {
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
int sum = 0;
int norm1 = 0;
int norm2 = 0;
int aOffset = a.offset, bOffset = b.offset;
for (int i = 0; i < a.length; i++) {
byte elem1 = a.bytes[aOffset++];
byte elem2 = b.bytes[bOffset++];
byte elem1 = a[i];
byte elem2 = b[i];
sum += elem1 * elem2;
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
@ -182,12 +181,11 @@ public final class VectorUtil {
}
/** Returns the sum of squared differences of the two vectors. */
public static float squareDistance(BytesRef a, BytesRef b) {
public static float squareDistance(byte[] a, byte[] b) {
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
int squareSum = 0;
int aOffset = a.offset, bOffset = b.offset;
for (int i = 0; i < a.length; i++) {
int diff = a.bytes[aOffset++] - b.bytes[bOffset++];
int diff = a[i] - b[i];
squareSum += diff * diff;
}
return squareSum;
@ -251,12 +249,11 @@ public final class VectorUtil {
* @param b bytes containing another vector, of the same dimension
* @return the value of the dot product of the two vectors
*/
public static float dotProduct(BytesRef a, BytesRef b) {
public static float dotProduct(byte[] a, byte[] b) {
assert a.length == b.length;
int total = 0;
int aOffset = a.offset, bOffset = b.offset;
for (int i = 0; i < a.length; i++) {
total += a.bytes[aOffset++] * b.bytes[bOffset++];
total += a[i] * b[i];
}
return total;
}
@ -268,7 +265,7 @@ public final class VectorUtil {
* @param b bytes containing another vector, of the same dimension
* @return the value of the similarity function applied to the two vectors
*/
public static float dotProductScore(BytesRef a, BytesRef b) {
public static float dotProductScore(byte[] a, byte[] b) {
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
float denom = (float) (a.length * (1 << 15));
return 0.5f + dotProduct(a, b) / denom;

View File

@ -26,7 +26,6 @@ import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
@ -273,7 +272,7 @@ public final class HnswGraphBuilder<T> {
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
throws IOException {
return switch (vectorEncoding) {
case BYTE -> isDiverse((BytesRef) vectors.vectorValue(candidate), neighbors, score);
case BYTE -> isDiverse((byte[]) vectors.vectorValue(candidate), neighbors, score);
case FLOAT32 -> isDiverse((float[]) vectors.vectorValue(candidate), neighbors, score);
};
}
@ -291,12 +290,12 @@ public final class HnswGraphBuilder<T> {
return true;
}
private boolean isDiverse(BytesRef candidate, NeighborArray neighbors, float score)
private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score)
throws IOException {
for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity =
similarityFunction.compare(
candidate, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
candidate, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
if (neighborSimilarity >= score) {
return false;
}
@ -322,7 +321,7 @@ public final class HnswGraphBuilder<T> {
int candidateNode = neighbors.node[candidateIndex];
return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(
candidateIndex, (BytesRef) vectors.vectorValue(candidateNode), neighbors);
candidateIndex, (byte[]) vectors.vectorValue(candidateNode), neighbors);
case FLOAT32 -> isWorstNonDiverse(
candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors);
};
@ -344,12 +343,12 @@ public final class HnswGraphBuilder<T> {
}
private boolean isWorstNonDiverse(
int candidateIndex, BytesRef candidateVector, NeighborArray neighbors) throws IOException {
int candidateIndex, byte[] candidateVector, NeighborArray neighbors) throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;

View File

@ -24,7 +24,6 @@ import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.SparseFixedBitSet;
@ -135,9 +134,9 @@ public class HnswGraphSearcher<T> {
* @return a priority queue holding the closest neighbors found
*/
public static NeighborQueue search(
BytesRef query,
byte[] query,
int topK,
RandomAccessVectorValues<BytesRef> vectors,
RandomAccessVectorValues<byte[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
@ -151,7 +150,7 @@ public class HnswGraphSearcher<T> {
+ " differs from field dimension: "
+ vectors.dimension());
}
HnswGraphSearcher<BytesRef> graphSearcher =
HnswGraphSearcher<byte[]> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
@ -281,7 +280,7 @@ public class HnswGraphSearcher<T> {
private float compare(T query, RandomAccessVectorValues<T> vectors, int ord) throws IOException {
if (vectorEncoding == VectorEncoding.BYTE) {
return similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(ord));
return similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(ord));
} else {
return similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(ord));
}

View File

@ -637,14 +637,17 @@ public class TestField extends LuceneTestCase {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
BytesRef br = newBytesRef(new byte[5]);
Field field = new KnnByteVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
byte[] b = new byte[5];
KnnByteVectorField field =
new KnnByteVectorField("binary", b, VectorSimilarityFunction.EUCLIDEAN);
assertNull(field.binaryValue());
assertArrayEquals(b, field.vectorValue());
expectThrows(
IllegalArgumentException.class,
() -> new KnnVectorField("bogus", new float[] {1}, (FieldType) field.fieldType()));
float[] vector = new float[] {1, 2};
Field field2 = new KnnVectorField("float", vector);
assertEquals(br, field.binaryValue());
assertNull(field2.binaryValue());
doc.add(field);
doc.add(field2);
w.addDocument(doc);
@ -653,7 +656,7 @@ public class TestField extends LuceneTestCase {
assertEquals(1, binary.size());
assertNotEquals(NO_MORE_DOCS, binary.nextDoc());
assertNotNull(binary.vectorValue());
assertEquals(br, binary.vectorValue());
assertArrayEquals(b, binary.vectorValue());
assertEquals(NO_MORE_DOCS, binary.nextDoc());
VectorValues floatValues = r.leaves().get(0).reader().getVectorValues("float");

View File

@ -33,7 +33,6 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.Version;
@ -125,7 +124,7 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}

View File

@ -38,7 +38,6 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
@ -64,8 +63,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
abstract float[] randomVector(int dim);
abstract VectorEncoding vectorEncoding();
abstract Field getKnnVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction);

View File

@ -19,29 +19,27 @@ package org.apache.lucene.search;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@Override
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
return new KnnByteVectorQuery(field, new BytesRef(floatToBytes(query)), k, queryFilter);
return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
}
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
return new ThrowingKnnVectorQuery(field, new BytesRef(floatToBytes(vec)), k, query);
return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query);
}
@Override
float[] randomVector(int dim) {
BytesRef bytesRef = TestVectorUtil.randomVectorBytes(dim);
float[] v = new float[bytesRef.length];
byte[] b = TestVectorUtil.randomVectorBytes(dim);
float[] v = new float[b.length];
int vi = 0;
for (int i = bytesRef.offset; i < v.length; i++) {
v[vi++] = bytesRef.bytes[i];
for (int i = 0; i < v.length; i++) {
v[vi++] = b[i];
}
return v;
}
@ -49,13 +47,12 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@Override
Field getKnnVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
return new KnnByteVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction);
}
@Override
Field getKnnVectorField(String name, float[] vector) {
return new KnnByteVectorField(
name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN);
return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN);
}
private static byte[] floatToBytes(float[] query) {
@ -75,22 +72,14 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
public void testGetTarget() {
byte[] queryVectorBytes = floatToBytes(new float[] {0, 1});
BytesRef targetQueryVector = new BytesRef(queryVectorBytes);
KnnByteVectorQuery q1 = new KnnByteVectorQuery("f1", targetQueryVector, 10);
assertEquals(targetQueryVector, q1.getTargetCopy());
assertFalse(targetQueryVector == q1.getTargetCopy());
assertFalse(targetQueryVector.bytes == q1.getTargetCopy().bytes);
}
@Override
VectorEncoding vectorEncoding() {
return VectorEncoding.BYTE;
KnnByteVectorQuery q1 = new KnnByteVectorQuery("f1", queryVectorBytes, 10);
assertArrayEquals(queryVectorBytes, q1.getTargetCopy());
assertNotSame(queryVectorBytes, q1.getTargetCopy());
}
private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
public ThrowingKnnVectorQuery(String field, BytesRef target, int k, Query filter) {
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, target, k, filter);
}

View File

@ -28,7 +28,6 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.TestVectorUtil;
@ -74,11 +73,6 @@ public class TestKnnVectorQuery extends BaseKnnVectorQueryTestCase {
assertNotEquals(queryVector, q1.getTargetCopy());
}
@Override
VectorEncoding vectorEncoding() {
return VectorEncoding.FLOAT32;
}
public void testScoreNegativeDotProduct() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {

View File

@ -33,7 +33,6 @@ import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
public class TestVectorScorer extends LuceneTestCase {
@ -49,7 +48,7 @@ public class TestVectorScorer extends LuceneTestCase {
final VectorScorer vectorScorer;
switch (encoding) {
case BYTE:
vectorScorer = VectorScorer.create(context, fieldInfo, new BytesRef(new byte[] {1, 2}));
vectorScorer = VectorScorer.create(context, fieldInfo, new byte[] {1, 2});
break;
case FLOAT32:
vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
@ -76,9 +75,9 @@ public class TestVectorScorer extends LuceneTestCase {
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
if (encoding == VectorEncoding.BYTE) {
BytesRef v = new BytesRef(new byte[contents[i].length]);
byte[] v = new byte[contents[i].length];
for (int j = 0; j < v.length; j++) {
v.bytes[j] = (byte) contents[i][j];
v[j] = (byte) contents[i][j];
}
doc.add(new KnnByteVectorField(field, v, EUCLIDEAN));
} else {

View File

@ -131,19 +131,19 @@ public class TestVectorUtil extends LuceneTestCase {
return u;
}
private static BytesRef negative(BytesRef v) {
BytesRef u = new BytesRef(new byte[v.length]);
private static byte[] negative(byte[] v) {
byte[] u = new byte[v.length];
for (int i = 0; i < v.length; i++) {
// what is (byte) -(-128)? 127?
u.bytes[i] = (byte) -v.bytes[i];
u[i] = (byte) -v[i];
}
return u;
}
private static float l2(BytesRef v) {
private static float l2(byte[] v) {
float l2 = 0;
for (int i = v.offset; i < v.offset + v.length; i++) {
l2 += v.bytes[i] * v.bytes[i];
for (int i = 0; i < v.length; i++) {
l2 += v[i] * v[i];
}
return l2;
}
@ -161,7 +161,7 @@ public class TestVectorUtil extends LuceneTestCase {
return v;
}
private static BytesRef randomVectorBytes() {
private static byte[] randomVectorBytes() {
BytesRef v = TestUtil.randomBinaryTerm(random(), TestUtil.nextInt(random(), 1, 100));
// clip at -127 to avoid overflow
for (int i = v.offset; i < v.offset + v.length; i++) {
@ -169,10 +169,11 @@ public class TestVectorUtil extends LuceneTestCase {
v.bytes[i] = -127;
}
}
return v;
assert v.offset == 0;
return v.bytes;
}
public static BytesRef randomVectorBytes(int dim) {
public static byte[] randomVectorBytes(int dim) {
BytesRef v = TestUtil.randomBinaryTerm(random(), dim);
// clip at -127 to avoid overflow
for (int i = v.offset; i < v.offset + v.length; i++) {
@ -180,22 +181,22 @@ public class TestVectorUtil extends LuceneTestCase {
v.bytes[i] = -127;
}
}
return v;
return v.bytes;
}
public void testBasicDotProductBytes() {
BytesRef a = new BytesRef(new byte[] {1, 2, 3});
BytesRef b = new BytesRef(new byte[] {-10, 0, 5});
byte[] a = new byte[] {1, 2, 3};
byte[] b = new byte[] {-10, 0, 5};
assertEquals(5, VectorUtil.dotProduct(a, b), 0);
float denom = a.length * (1 << 15);
assertEquals(0.5 + 5 / denom, VectorUtil.dotProductScore(a, b), DELTA);
// dot product 0 maps to dotProductScore 0.5
BytesRef zero = new BytesRef(new byte[] {0, 0, 0});
byte[] zero = new byte[] {0, 0, 0};
assertEquals(0.5, VectorUtil.dotProductScore(a, zero), DELTA);
BytesRef min = new BytesRef(new byte[] {-128, -128});
BytesRef max = new BytesRef(new byte[] {127, 127});
byte[] min = new byte[] {-128, -128};
byte[] max = new byte[] {127, 127};
// minimum dot product score is not quite zero because 127 < 128
assertEquals(0.0039, VectorUtil.dotProductScore(min, max), DELTA);
@ -205,55 +206,48 @@ public class TestVectorUtil extends LuceneTestCase {
public void testSelfDotProductBytes() {
// the dot product of a vector with itself is equal to the sum of the squares of its components
BytesRef v = randomVectorBytes();
byte[] v = randomVectorBytes();
assertEquals(l2(v), VectorUtil.dotProduct(v, v), DELTA);
}
public void testOrthogonalDotProductBytes() {
// the dot product of two perpendicular vectors is 0
byte[] v = new byte[4];
v[0] = (byte) random().nextInt(100);
v[1] = (byte) random().nextInt(100);
v[2] = v[1];
v[3] = (byte) -v[0];
// also test computing using BytesRef with nonzero offset
assertEquals(0, VectorUtil.dotProduct(new BytesRef(v, 0, 2), new BytesRef(v, 2, 2)), DELTA);
byte[] a = new byte[2];
a[0] = (byte) random().nextInt(100);
a[1] = (byte) random().nextInt(100);
byte[] b = new byte[2];
b[0] = a[1];
b[1] = (byte) -a[0];
assertEquals(0, VectorUtil.dotProduct(a, b), DELTA);
}
public void testSelfSquareDistanceBytes() {
// the l2 distance of a vector with itself is zero
BytesRef v = randomVectorBytes();
byte[] v = randomVectorBytes();
assertEquals(0, VectorUtil.squareDistance(v, v), DELTA);
}
public void testBasicSquareDistanceBytes() {
assertEquals(
12,
VectorUtil.squareDistance(
new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-1, 0, 5})),
0);
assertEquals(12, VectorUtil.squareDistance(new byte[] {1, 2, 3}, new byte[] {-1, 0, 5}), 0);
}
public void testRandomSquareDistanceBytes() {
// the square distance of a vector with its inverse is equal to four times the sum of squares of
// its components
BytesRef v = randomVectorBytes();
BytesRef u = negative(v);
byte[] v = randomVectorBytes();
byte[] u = negative(v);
assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA);
}
public void testBasicCosineBytes() {
assertEquals(
0.11952f,
VectorUtil.cosine(new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-10, 0, 5})),
DELTA);
assertEquals(0.11952f, VectorUtil.cosine(new byte[] {1, 2, 3}, new byte[] {-10, 0, 5}), DELTA);
}
public void testSelfCosineBytes() {
// the dot product of a vector with itself is always equal to 1
BytesRef v = randomVectorBytes();
byte[] v = randomVectorBytes();
// ensure the vector is non-zero so that cosine is defined
v.bytes[0] = (byte) (random().nextInt(126) + 1);
v[0] = (byte) (random().nextInt(126) + 1);
assertEquals(1.0f, VectorUtil.cosine(v, v), DELTA);
}

View File

@ -61,7 +61,6 @@ import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
@ -293,9 +292,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
NeighborQueue nn =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(BytesRef) getTargetVector(),
(byte[]) getTargetVector(),
10,
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
(RandomAccessVectorValues<byte[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
@ -346,9 +345,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
NeighborQueue nn =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(BytesRef) getTargetVector(),
(byte[]) getTargetVector(),
10,
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
(RandomAccessVectorValues<byte[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
@ -405,9 +404,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
acceptOrds,
Integer.MAX_VALUE);
case BYTE -> HnswGraphSearcher.search(
(BytesRef) getTargetVector(),
(byte[]) getTargetVector(),
numAccepted,
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
(RandomAccessVectorValues<byte[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
@ -446,9 +445,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
createRandomAcceptOrds(0, nDoc),
visitedLimit);
case BYTE -> HnswGraphSearcher.search(
(BytesRef) getTargetVector(),
(byte[]) getTargetVector(),
topK,
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
(RandomAccessVectorValues<byte[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
@ -663,9 +662,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
actual =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(BytesRef) query,
(byte[]) query,
100,
(RandomAccessVectorValues<BytesRef>) vectors,
(RandomAccessVectorValues<byte[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
@ -689,9 +688,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
if (getVectorEncoding() == VectorEncoding.BYTE) {
assert query instanceof BytesRef;
assert query instanceof byte[];
expected.add(
j, similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(j)));
j, similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(j)));
} else {
assert query instanceof float[];
expected.add(
@ -789,17 +788,17 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
/** Returns vectors evenly distributed around the upper unit semicircle. */
static class CircularByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues<BytesRef> {
implements RandomAccessVectorValues<byte[]> {
private final int size;
private final float[] value;
private final BytesRef bValue;
private final byte[] bValue;
int doc = -1;
CircularByteVectorValues(int size) {
this.size = size;
value = new float[2];
bValue = new BytesRef(new byte[2]);
bValue = new byte[2];
}
@Override
@ -818,7 +817,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
@Override
public BytesRef vectorValue() {
public byte[] vectorValue() {
return vectorValue(doc);
}
@ -843,10 +842,10 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
@Override
public BytesRef vectorValue(int ord) {
public byte[] vectorValue(int ord) {
unitVector2d(ord / (double) size, value);
for (int i = 0; i < value.length; i++) {
bValue.bytes[i] = (byte) (value[i] * 127);
bValue[i] = (byte) (value[i] * 127);
}
return bValue;
}
@ -881,21 +880,17 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
break;
}
switch (getVectorEncoding()) {
case BYTE:
assertArrayEquals(
"vectors do not match for doc=" + uDoc,
((BytesRef) u.vectorValue()).bytes,
((BytesRef) v.vectorValue()).bytes);
break;
case FLOAT32:
assertArrayEquals(
"vectors do not match for doc=" + uDoc,
(float[]) u.vectorValue(),
(float[]) v.vectorValue(),
1e-4f);
break;
default:
throw new IllegalArgumentException("unknown vector encoding: " + getVectorEncoding());
case BYTE -> assertArrayEquals(
"vectors do not match for doc=" + uDoc,
(byte[]) u.vectorValue(),
(byte[]) v.vectorValue());
case FLOAT32 -> assertArrayEquals(
"vectors do not match for doc=" + uDoc,
(float[]) u.vectorValue(),
(float[]) v.vectorValue(),
1e-4f);
default -> throw new IllegalArgumentException(
"unknown vector encoding: " + getVectorEncoding());
}
}
}

View File

@ -71,9 +71,7 @@ import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.apache.lucene.util.SuppressForbidden;
@ -516,12 +514,10 @@ public class KnnGraphTester {
private static class VectorReaderByte extends VectorReader {
private final byte[] scratch;
private final BytesRef bytesRef;
VectorReaderByte(FileChannel input, int dim, int bufferSize) {
super(input, dim, bufferSize);
scratch = new byte[dim];
bytesRef = new BytesRef(scratch);
}
@Override
@ -534,10 +530,10 @@ public class KnnGraphTester {
return target;
}
BytesRef nextBytes() throws IOException {
byte[] nextBytes() throws IOException {
readNext();
bytes.get(scratch);
return bytesRef;
return scratch;
}
}
@ -750,40 +746,7 @@ public class KnnGraphTester {
System.exit(1);
}
static class NeighborArraySorter extends IntroSorter {
private final int[] node;
private final float[] score;
NeighborArraySorter(NeighborArray neighbors) {
node = neighbors.node;
score = neighbors.score;
}
int pivot;
@Override
protected void swap(int i, int j) {
int tmpNode = node[i];
float tmpScore = score[i];
node[i] = node[j];
score[i] = score[j];
node[j] = tmpNode;
score[j] = tmpScore;
}
@Override
protected void setPivot(int i) {
pivot = i;
}
@Override
protected int comparePivot(int j) {
return Float.compare(score[pivot], score[j]);
}
}
private static class BitSetQuery extends Query {
private final FixedBitSet docs;
private final int cardinality;

View File

@ -19,20 +19,16 @@ package org.apache.lucene.util.hnsw;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
class MockByteVectorValues extends AbstractMockVectorValues<byte[]> {
private final byte[] scratch;
static MockByteVectorValues fromValues(byte[][] byteValues) {
int dimension = byteValues[0].length;
BytesRef[] values = new BytesRef[byteValues.length];
for (int i = 0; i < byteValues.length; i++) {
values[i] = byteValues[i] == null ? null : new BytesRef(byteValues[i]);
}
BytesRef[] denseValues = new BytesRef[values.length];
static MockByteVectorValues fromValues(byte[][] values) {
int dimension = values[0].length;
int maxDoc = values.length;
byte[][] denseValues = new byte[maxDoc][];
int count = 0;
for (int i = 0; i < byteValues.length; i++) {
for (int i = 0; i < maxDoc; i++) {
if (values[i] != null) {
denseValues[count++] = values[i];
}
@ -40,7 +36,7 @@ class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
return new MockByteVectorValues(values, dimension, denseValues, count);
}
MockByteVectorValues(BytesRef[] values, int dimension, BytesRef[] denseValues, int numVectors) {
MockByteVectorValues(byte[][] values, int dimension, byte[][] denseValues, int numVectors) {
super(values, dimension, denseValues, numVectors);
scratch = new byte[dimension];
}
@ -55,7 +51,7 @@ class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
}
@Override
public BytesRef vectorValue() {
public byte[] vectorValue() {
if (LuceneTestCase.random().nextBoolean()) {
return values[pos];
} else {
@ -63,8 +59,8 @@ class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
// This should help us catch cases of aliasing where the same ByteVectorValues source is used
// twice in a
// single computation.
System.arraycopy(values[pos].bytes, values[pos].offset, scratch, 0, dimension);
return new BytesRef(scratch);
System.arraycopy(values[pos], 0, scratch, 0, dimension);
return scratch;
}
}
}

View File

@ -30,11 +30,10 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.junit.Before;
/** Tests HNSW KNN graphs */
public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {
@Before
public void setup() {
@ -47,17 +46,17 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
}
@Override
Query knnQuery(String field, BytesRef vector, int k) {
Query knnQuery(String field, byte[] vector, int k) {
return new KnnByteVectorQuery(field, vector, k);
}
@Override
BytesRef randomVector(int dim) {
return new BytesRef(randomVector8(random(), dim));
byte[] randomVector(int dim) {
return randomVector8(random(), dim);
}
@Override
AbstractMockVectorValues<BytesRef> vectorValues(int size, int dimension) {
AbstractMockVectorValues<byte[]> vectorValues(int size, int dimension) {
return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, random()));
}
@ -66,7 +65,7 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
}
@Override
AbstractMockVectorValues<BytesRef> vectorValues(float[][] values) {
AbstractMockVectorValues<byte[]> vectorValues(float[][] values) {
byte[][] bValues = new byte[values.length][];
// The case when all floats fit within a byte already.
boolean scaleSimple = fitsInByte(values[0][0]);
@ -87,32 +86,30 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
}
@Override
AbstractMockVectorValues<BytesRef> vectorValues(LeafReader reader, String fieldName)
AbstractMockVectorValues<byte[]> vectorValues(LeafReader reader, String fieldName)
throws IOException {
ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);
byte[][] vectors = new byte[reader.maxDoc()][];
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
vectors[vectorValues.docID()] =
ArrayUtil.copyOfSubArray(
vectorValues.vectorValue().bytes,
vectorValues.vectorValue().offset,
vectorValues.vectorValue().offset + vectorValues.vectorValue().length);
vectorValues.vectorValue(), 0, vectorValues.vectorValue().length);
}
return MockByteVectorValues.fromValues(vectors);
}
@Override
Field knnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
Field knnVectorField(String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
return new KnnByteVectorField(name, vector, similarityFunction);
}
@Override
RandomAccessVectorValues<BytesRef> circularVectorValues(int nDoc) {
RandomAccessVectorValues<byte[]> circularVectorValues(int nDoc) {
return new CircularByteVectorValues(nDoc);
}
@Override
BytesRef getTargetVector() {
return new BytesRef(new byte[] {1, 0});
byte[] getTargetVector() {
return new byte[] {1, 0};
}
}

View File

@ -42,7 +42,6 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
/**
@ -179,7 +178,7 @@ public class TermVectorLeafReader extends LeafReader {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}

View File

@ -1408,7 +1408,7 @@ public class MemoryIndex {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}

View File

@ -34,7 +34,6 @@ import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/** Wraps the default KnnVectorsFormat and provides additional assertions. */
public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
@ -153,7 +152,7 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldInfo fi = fis.fieldInfo(field);
assert fi != null

View File

@ -81,8 +81,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
@Override
protected void addRandomFields(Document doc) {
switch (vectorEncoding) {
case BYTE -> doc.add(
new KnnByteVectorField("v2", new BytesRef(randomVector8(30)), similarityFunction));
case BYTE -> doc.add(new KnnByteVectorField("v2", randomVector8(30), similarityFunction));
case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
}
}
@ -632,9 +631,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
switch (fieldVectorEncodings[field]) {
case BYTE -> {
byte[] b = randomVector8(fieldDims[field]);
doc.add(
new KnnByteVectorField(
fieldName, new BytesRef(b), fieldSimilarityFunctions[field]));
doc.add(new KnnByteVectorField(fieldName, b, fieldSimilarityFunctions[field]));
fieldTotals[field] += b[0];
}
case FLOAT32 -> {
@ -660,7 +657,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
if (byteVectorValues != null) {
docCount += byteVectorValues.size();
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
checksum += byteVectorValues.vectorValue().bytes[0];
checksum += byteVectorValues.vectorValue()[0];
}
}
}
@ -766,10 +763,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
String fieldName = "field";
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, iwc)) {
add(iw, fieldName, 1, 1, new BytesRef(new byte[] {-1, 0}));
add(iw, fieldName, 4, 4, new BytesRef(new byte[] {0, 1}));
add(iw, fieldName, 3, 3, (BytesRef) null);
add(iw, fieldName, 2, 2, new BytesRef(new byte[] {1, 0}));
add(iw, fieldName, 1, 1, new byte[] {-1, 0});
add(iw, fieldName, 4, 4, new byte[] {0, 1});
add(iw, fieldName, 3, 3, (byte[]) null);
add(iw, fieldName, 2, 2, new byte[] {1, 0});
iw.forceMerge(1);
try (IndexReader reader = DirectoryReader.open(iw)) {
LeafReader leaf = getOnlyLeafReader(reader);
@ -779,11 +776,11 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertEquals(2, vectorValues.dimension());
assertEquals(3, vectorValues.size());
assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(-1, vectorValues.vectorValue().bytes[0], 0);
assertEquals(-1, vectorValues.vectorValue()[0], 0);
assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(1, vectorValues.vectorValue().bytes[0], 0);
assertEquals(1, vectorValues.vectorValue()[0], 0);
assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(0, vectorValues.vectorValue().bytes[0], 0);
assertEquals(0, vectorValues.vectorValue()[0], 0);
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
}
}
@ -928,8 +925,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(100);
int dimension = atLeast(10);
BytesRef scratch = new BytesRef(dimension);
scratch.length = dimension;
byte[] scratch = new byte[dimension];
int numValues = 0;
BytesRef[] values = new BytesRef[numDoc];
for (int i = 0; i < numDoc; i++) {
@ -940,10 +936,11 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
if (random().nextBoolean() && values[i] != null) {
// sometimes use a shared scratch array
System.arraycopy(values[i].bytes, 0, scratch.bytes, 0, dimension);
System.arraycopy(values[i].bytes, 0, scratch, 0, dimension);
add(iw, fieldName, i, scratch, similarityFunction);
} else {
add(iw, fieldName, i, values[i], similarityFunction);
BytesRef value = values[i];
add(iw, fieldName, i, value == null ? null : value.bytes, similarityFunction);
}
if (random().nextInt(10) == 2) {
// sometimes delete a random document
@ -971,12 +968,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
StoredFields storedFields = ctx.reader().storedFields();
int docId;
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
BytesRef v = vectorValues.vectorValue();
byte[] v = vectorValues.vectorValue();
assertEquals(dimension, v.length);
String idString = storedFields.document(docId).getField("id").stringValue();
int id = Integer.parseInt(idString);
if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) {
assertEquals(idString, 0, values[id].compareTo(v));
assertEquals(idString, 0, values[id].compareTo(new BytesRef(v)));
++valueCount;
} else {
++numDeletes;
@ -1141,12 +1138,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
private void add(
IndexWriter iw, String field, int id, BytesRef vector, VectorSimilarityFunction similarity)
IndexWriter iw, String field, int id, byte[] vector, VectorSimilarityFunction similarity)
throws IOException {
add(iw, field, id, random().nextInt(100), vector, similarity);
}
private void add(IndexWriter iw, String field, int id, int sortKey, BytesRef vector)
private void add(IndexWriter iw, String field, int id, int sortKey, byte[] vector)
throws IOException {
add(iw, field, id, sortKey, vector, VectorSimilarityFunction.EUCLIDEAN);
}
@ -1156,7 +1153,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
String field,
int id,
int sortKey,
BytesRef vector,
byte[] vector,
VectorSimilarityFunction similarityFunction)
throws IOException {
Document doc = new Document();
@ -1319,7 +1316,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
case BYTE -> {
byte[] b = randomVector8(dim);
fieldValuesCheckSum += b[0];
doc.add(new KnnByteVectorField("knn_vector", new BytesRef(b), similarityFunction));
doc.add(new KnnByteVectorField("knn_vector", b, similarityFunction));
}
case FLOAT32 -> {
float[] v = randomVector(dim);
@ -1349,7 +1346,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
docCount += byteVectorValues.size();
StoredFields storedFields = ctx.reader().storedFields();
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
checksum += byteVectorValues.vectorValue().bytes[0];
checksum += byteVectorValues.vectorValue()[0];
Document doc = storedFields.document(byteVectorValues.docID(), Set.of("id"));
sumDocIds += Integer.parseInt(doc.get("id"));
}

View File

@ -42,7 +42,6 @@ import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/**
* This is a hack to make index sorting fast, with a {@link LeafReader} that always returns merge
@ -236,7 +235,7 @@ class MergeReaderWrapper extends LeafReader {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}

View File

@ -55,7 +55,6 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
import org.junit.Assert;
@ -243,7 +242,7 @@ public class QueryUtils {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}