mirror of https://github.com/apache/lucene.git
Replace BytesRef usages in byte vectors API with byte[] (#12102)
The main classes involved are ByteVectorValues, KnnByteVectorField, and KnnByteVectorQuery. It then becomes natural to simplify further and use byte[] in the following methods as well: ByteVectorValues#vectorValue, KnnVectorsReader#search, LeafReader#searchNearestVectors, HnswGraphSearcher#search, VectorSimilarityFunction#compare, VectorUtil#cosine, VectorUtil#squareDistance, VectorUtil#dotProduct, and VectorUtil#dotProductScore.
This commit is contained in:
parent f8ee852696
commit 4594400216
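To illustrate the user-facing shape of the API after this change, here is a minimal indexing-and-search sketch. It is not part of the commit; the directory, the field name "v", and the vector values are made up for the example, and a plain byte[] is passed everywhere a BytesRef used to be required.

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class ByteVectorRoundTrip {
  public static void main(String[] args) throws Exception {
    try (Directory dir = new ByteBuffersDirectory();
        IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
      Document doc = new Document();
      // The indexed vector is a plain byte[] now, no BytesRef wrapping.
      doc.add(
          new KnnByteVectorField("v", new byte[] {1, 2, 3, 4}, VectorSimilarityFunction.EUCLIDEAN));
      writer.addDocument(doc);
      writer.commit();

      try (DirectoryReader reader = DirectoryReader.open(dir)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        // The query target is also a byte[].
        TopDocs hits = searcher.search(new KnnByteVectorQuery("v", new byte[] {1, 2, 3, 4}, 1), 1);
        System.out.println(hits.totalHits);
      }
    }
  }
}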
Lucene90HnswVectorsReader
@@ -41,7 +41,6 @@ import org.apache.lucene.store.ChecksumIndexInput;
-import org.apache.lucene.util.BytesRef;
@@ -279,7 +278,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
-public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)

Lucene91HnswVectorsReader
@@ -42,7 +42,6 @@ import org.apache.lucene.store.ChecksumIndexInput;
-import org.apache.lucene.util.BytesRef;
@@ -270,7 +269,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
-public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)

Lucene92HnswVectorsReader
@@ -41,7 +41,6 @@ import org.apache.lucene.store.ChecksumIndexInput;
-import org.apache.lucene.util.BytesRef;
@@ -266,7 +265,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
-public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)

Lucene94HnswVectorsReader
@@ -41,7 +41,6 @@ import org.apache.lucene.store.ChecksumIndexInput;
-import org.apache.lucene.util.BytesRef;
@@ -302,7 +301,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
-public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)

OffHeapByteVectorValues (first of the two identically named classes touched by this commit)
@@ -25,18 +25,17 @@ import org.apache.lucene.index.VectorEncoding;
-import org.apache.lucene.util.BytesRef;
 abstract class OffHeapByteVectorValues extends ByteVectorValues
-    implements RandomAccessVectorValues<BytesRef> {
+    implements RandomAccessVectorValues<byte[]> {
-protected final BytesRef binaryValue;
+protected final byte[] binaryValue;
@@ -46,7 +45,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
+binaryValue = byteBuffer.array();
@@ -60,7 +59,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public BytesRef vectorValue(int targetOrd) throws IOException {
+public byte[] vectorValue(int targetOrd) throws IOException {
@@ -99,7 +98,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public BytesRef vectorValue() throws IOException {
+public byte[] vectorValue() throws IOException {
@@ -125,7 +124,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public RandomAccessVectorValues<BytesRef> copy() throws IOException {
+public RandomAccessVectorValues<byte[]> copy() throws IOException {
@@ -171,7 +170,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public BytesRef vectorValue() throws IOException {
+public byte[] vectorValue() throws IOException {
@@ -194,7 +193,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public RandomAccessVectorValues<BytesRef> copy() throws IOException {
+public RandomAccessVectorValues<byte[]> copy() throws IOException {
@@ -241,7 +240,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public BytesRef vectorValue() throws IOException {
+public byte[] vectorValue() throws IOException {
@@ -261,12 +260,12 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public RandomAccessVectorValues<BytesRef> copy() throws IOException {
+public RandomAccessVectorValues<byte[]> copy() throws IOException {
-public BytesRef vectorValue(int targetOrd) throws IOException {
+public byte[] vectorValue(int targetOrd) throws IOException {

Lucene90HnswVectorsWriter
@@ -185,7 +185,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
-output.writeBytes(binaryVector.array(), 0, binaryVector.limit());
+output.writeBytes(binaryVector.array(), binaryVector.limit());

Lucene91HnswVectorsWriter
@@ -179,7 +179,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
-output.writeBytes(binaryVector.array(), 0, binaryVector.limit());
+output.writeBytes(binaryVector.array(), binaryVector.limit());

Lucene92HnswVectorsWriter
@@ -187,7 +187,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
-output.writeBytes(binaryVector.array(), 0, binaryVector.limit());
+output.writeBytes(binaryVector.array(), binaryVector.limit());

Lucene94HnswVectorsWriter
@@ -43,7 +43,6 @@ import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.util.BytesRef;
@@ -197,17 +196,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
 private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
-final BytesRef binaryValue = new BytesRef(buffer.array());
-vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
+vectorData.writeBytes(buffer.array(), buffer.array().length);
 private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
-BytesRef vector = (BytesRef) v;
-vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
+byte[] vector = (byte[]) v;
+vectorData.writeBytes(vector, vector.length);
@@ -267,11 +265,10 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
-final BytesRef binaryValue = new BytesRef(buffer.array());
-vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
+vectorData.writeBytes(buffer.array(), buffer.array().length);
@@ -279,8 +276,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
 private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
-BytesRef vector = (BytesRef) fieldData.vectors.get(ordinal);
-vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
+byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
+vectorData.writeBytes(vector, vector.length);
@@ -423,7 +420,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
-HnswGraphBuilder<BytesRef> hnswGraphBuilder =
+HnswGraphBuilder<byte[]> hnswGraphBuilder =
@@ -596,9 +593,9 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
-BytesRef binaryValue = byteVectorValues.vectorValue();
+byte[] binaryValue = byteVectorValues.vectorValue();
-output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
+output.writeBytes(binaryValue, binaryValue.length);
@@ -619,7 +616,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
-output.writeBytes(binaryVector.array(), 0, binaryVector.limit());
+output.writeBytes(binaryVector.array(), binaryVector.limit());
@@ -644,10 +641,10 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
-case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
+case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream) {
-public BytesRef copyValue(BytesRef value) {
-return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim));
+public byte[] copyValue(byte[] value) {
+return ArrayUtil.copyOfSubArray(value, 0, dim);

SimpleTextKnnVectorsReader
@@ -222,7 +222,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
-public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
@@ -250,7 +250,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
-BytesRef vector = values.vectorValue();
+byte[] vector = values.vectorValue();
@@ -458,9 +458,9 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
-public BytesRef vectorValue() {
+public byte[] vectorValue() {
 binaryValue.bytes = values[curOrd];
-return binaryValue;
+return binaryValue.bytes;

BufferingKnnVectorsWriter
@@ -30,7 +30,6 @@ import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.util.BytesRef;
@@ -97,7 +96,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
 public TopDocs search(
-String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
@@ -192,7 +191,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
 public TopDocs search(
-String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {

KnnVectorsFormat
@@ -24,7 +24,6 @@ import org.apache.lucene.index.SegmentWriteState;
-import org.apache.lucene.util.BytesRef;
@@ -112,7 +111,7 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
 public TopDocs search(
-String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {

KnnVectorsReader
@@ -27,7 +27,6 @@ import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.util.BytesRef;
@@ -117,7 +116,7 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
 public abstract TopDocs search(
-String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
+String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;

KnnVectorsWriter
@@ -30,7 +30,6 @@ import org.apache.lucene.index.VectorEncoding;
-import org.apache.lucene.util.BytesRef;
@@ -49,8 +48,8 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
 case BYTE:
-KnnFieldVectorsWriter<BytesRef> byteWriter =
-(KnnFieldVectorsWriter<BytesRef>) addField(fieldInfo);
+KnnFieldVectorsWriter<byte[]> byteWriter =
+(KnnFieldVectorsWriter<byte[]>) addField(fieldInfo);
@@ -262,7 +261,7 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
-public BytesRef vectorValue() throws IOException {
+public byte[] vectorValue() throws IOException {

Lucene95HnswVectorsReader
@@ -43,7 +43,6 @@ import org.apache.lucene.store.IndexInput;
-import org.apache.lucene.util.BytesRef;
@@ -310,7 +309,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
-public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)

Lucene95HnswVectorsWriter
@@ -186,17 +186,16 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
 private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
-final BytesRef binaryValue = new BytesRef(buffer.array());
-vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
+vectorData.writeBytes(buffer.array(), buffer.array().length);
 private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
-BytesRef vector = (BytesRef) v;
-vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
+byte[] vector = (byte[]) v;
+vectorData.writeBytes(vector, vector.length);
@@ -258,11 +257,10 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
-final BytesRef binaryValue = new BytesRef(buffer.array());
-vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
+vectorData.writeBytes(buffer.array(), buffer.array().length);
@@ -270,8 +268,8 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
 private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
-BytesRef vector = (BytesRef) fieldData.vectors.get(ordinal);
-vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
+byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
+vectorData.writeBytes(vector, vector.length);
@@ -435,7 +433,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
-HnswGraphBuilder<BytesRef> hnswGraphBuilder =
+HnswGraphBuilder<byte[]> hnswGraphBuilder =
@@ -646,9 +644,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
-BytesRef binaryValue = byteVectorValues.vectorValue();
+byte[] binaryValue = byteVectorValues.vectorValue();
-output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
+output.writeBytes(binaryValue, binaryValue.length);
@@ -669,7 +667,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
-output.writeBytes(buffer.array(), 0, buffer.limit());
+output.writeBytes(buffer.array(), buffer.limit());
@@ -694,11 +692,10 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
-case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
+case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream) {
-public BytesRef copyValue(BytesRef value) {
-return new BytesRef(
-ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + dim));
+public byte[] copyValue(byte[] value) {
+return ArrayUtil.copyOfSubArray(value, 0, dim);

OffHeapByteVectorValues (second occurrence)
@@ -25,18 +25,17 @@ import org.apache.lucene.index.VectorEncoding;
-import org.apache.lucene.util.BytesRef;
 abstract class OffHeapByteVectorValues extends ByteVectorValues
-    implements RandomAccessVectorValues<BytesRef> {
+    implements RandomAccessVectorValues<byte[]> {
-protected final BytesRef binaryValue;
+protected final byte[] binaryValue;
@@ -46,7 +45,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
+binaryValue = byteBuffer.array();
@@ -60,7 +59,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public BytesRef vectorValue(int targetOrd) throws IOException {
+public byte[] vectorValue(int targetOrd) throws IOException {
@@ -99,7 +98,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public BytesRef vectorValue() throws IOException {
+public byte[] vectorValue() throws IOException {
@@ -125,7 +124,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public RandomAccessVectorValues<BytesRef> copy() throws IOException {
+public RandomAccessVectorValues<byte[]> copy() throws IOException {
@@ -171,7 +170,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public BytesRef vectorValue() throws IOException {
+public byte[] vectorValue() throws IOException {
@@ -194,7 +193,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public RandomAccessVectorValues<BytesRef> copy() throws IOException {
+public RandomAccessVectorValues<byte[]> copy() throws IOException {
@@ -241,7 +240,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public BytesRef vectorValue() throws IOException {
+public byte[] vectorValue() throws IOException {
@@ -261,12 +260,12 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
-public RandomAccessVectorValues<BytesRef> copy() throws IOException {
+public RandomAccessVectorValues<byte[]> copy() throws IOException {
-public BytesRef vectorValue(int targetOrd) throws IOException {
+public byte[] vectorValue(int targetOrd) throws IOException {

PerFieldKnnVectorsFormat
@@ -36,7 +36,6 @@ import org.apache.lucene.index.Sorter;
-import org.apache.lucene.util.BytesRef;
@@ -273,7 +272,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
-public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)

KnnByteVectorField
@@ -22,7 +22,6 @@ import org.apache.lucene.index.VectorEncoding;
-import org.apache.lucene.util.BytesRef;
@@ -38,7 +37,7 @@ public class KnnByteVectorField extends Field {
-private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
+private static FieldType createType(byte[] v, VectorSimilarityFunction similarityFunction) {
@@ -67,7 +66,7 @@ public class KnnByteVectorField extends Field {
-public static Query newVectorQuery(String field, BytesRef queryVector, int k) {
+public static Query newVectorQuery(String field, byte[] queryVector, int k) {
@@ -99,7 +98,7 @@ public class KnnByteVectorField extends Field {
 public KnnByteVectorField(
-String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
+String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
@@ -114,7 +113,7 @@ public class KnnByteVectorField extends Field {
-public KnnByteVectorField(String name, BytesRef vector) {
+public KnnByteVectorField(String name, byte[] vector) {
@@ -128,7 +127,7 @@ public class KnnByteVectorField extends Field {
-public KnnByteVectorField(String name, BytesRef vector, FieldType fieldType) {
+public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) {
@@ -141,8 +140,8 @@ public class KnnByteVectorField extends Field {
 /** Return the vector value of this field */
-public BytesRef vectorValue() {
-return (BytesRef) fieldsData;
+public byte[] vectorValue() {
+return (byte[]) fieldsData;
@@ -150,7 +149,7 @@ public class KnnByteVectorField extends Field {
-public void setVectorValue(BytesRef value) {
+public void setVectorValue(byte[] value) {

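KnnByteVectorField#setVectorValue also takes a byte[] now, which keeps the usual Lucene pattern of reusing one field instance across documents. A small sketch under assumed names (field "v" and the vector values are made up for the example):

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class ReusedByteVectorField {
  public static void main(String[] args) throws Exception {
    byte[][] vectors = {{1, 2, 3, 4}, {5, 6, 7, 8}};
    try (Directory dir = new ByteBuffersDirectory();
        IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
      // One field instance, one document skeleton, new vector value per added document.
      KnnByteVectorField field =
          new KnnByteVectorField("v", vectors[0], VectorSimilarityFunction.EUCLIDEAN);
      Document doc = new Document();
      doc.add(field);
      for (byte[] vector : vectors) {
        field.setVectorValue(vector); // previously required wrapping in a BytesRef
        writer.addDocument(doc);
      }
    }
  }
}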
ByteVectorValues
@@ -19,7 +19,6 @@ package org.apache.lucene.index;
-import org.apache.lucene.util.BytesRef;
@@ -57,5 +56,5 @@ public abstract class ByteVectorValues extends DocIdSetIterator {
-public abstract BytesRef vectorValue() throws IOException;
+public abstract byte[] vectorValue() throws IOException;

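With ByteVectorValues#vectorValue returning a byte[], consuming indexed byte vectors is a plain doc-iterator loop. A sketch that assumes an existing index at "index-dir" containing a byte-vector field "v" (both names are illustrative):

import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

public class DumpByteVectors {
  public static void main(String[] args) throws IOException {
    try (Directory dir = FSDirectory.open(Paths.get("index-dir"));
        DirectoryReader reader = DirectoryReader.open(dir)) {
      for (LeafReaderContext ctx : reader.leaves()) {
        ByteVectorValues values = ctx.reader().getByteVectorValues("v");
        if (values == null) {
          continue; // this segment has no byte vectors for the field
        }
        for (int doc = values.nextDoc();
            doc != DocIdSetIterator.NO_MORE_DOCS;
            doc = values.nextDoc()) {
          // Returned as byte[] instead of BytesRef; the array may be reused between documents.
          byte[] vector = values.vectorValue();
          System.out.println(doc + " " + Arrays.toString(vector));
        }
      }
    }
  }
}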
CodecReader
@@ -27,7 +27,6 @@ import org.apache.lucene.codecs.StoredFieldsReader;
-import org.apache.lucene.util.BytesRef;
@@ -257,7 +256,7 @@ public abstract class CodecReader extends LeafReader {
 public final TopDocs searchNearestVectors(
-String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {

DocValuesLeafReader
@@ -20,7 +20,6 @@ package org.apache.lucene.index;
-import org.apache.lucene.util.BytesRef;
@@ -66,7 +65,7 @@ abstract class DocValuesLeafReader extends LeafReader {
 public TopDocs searchNearestVectors(
-String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {

ExitableDirectoryReader
@@ -508,7 +508,7 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
-public BytesRef vectorValue() throws IOException {
+public byte[] vectorValue() throws IOException {

FilterLeafReader
@@ -364,7 +364,7 @@ public abstract class FilterLeafReader extends LeafReader {
 public TopDocs searchNearestVectors(
-String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {

IndexingChain
@@ -961,7 +961,7 @@ final class IndexingChain implements Accountable {
-case BYTE -> ((KnnFieldVectorsWriter<BytesRef>) pf.knnFieldVectorsWriter)
+case BYTE -> ((KnnFieldVectorsWriter<byte[]>) pf.knnFieldVectorsWriter)

LeafReader
@@ -21,7 +21,6 @@ import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.util.BytesRef;
@@ -270,7 +269,7 @@ public abstract class LeafReader extends IndexReader {
 public abstract TopDocs searchNearestVectors(
-String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
+String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;

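LeafReader#searchNearestVectors now takes the target vector as a byte[] as well. A per-leaf search sketch, again with an assumed index directory and field name; a null acceptDocs means every document is eligible, and the last argument caps how many vectors may be visited:

import java.io.IOException;
import java.nio.file.Paths;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

public class PerLeafByteVectorSearch {
  public static void main(String[] args) throws IOException {
    byte[] target = {1, 2, 3, 4};
    try (Directory dir = FSDirectory.open(Paths.get("index-dir"));
        DirectoryReader reader = DirectoryReader.open(dir)) {
      for (LeafReaderContext ctx : reader.leaves()) {
        // Search each segment independently; top-level queries merge these per-leaf results.
        TopDocs perLeaf =
            ctx.reader().searchNearestVectors("v", target, 10, null, Integer.MAX_VALUE);
        System.out.println(perLeaf.totalHits + " hits in leaf " + ctx.ord);
      }
    }
  }
}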
ParallelLeafReader
@@ -29,7 +29,6 @@ import java.util.TreeMap;
-import org.apache.lucene.util.BytesRef;
@@ -428,7 +427,7 @@ public class ParallelLeafReader extends LeafReader {
 public TopDocs searchNearestVectors(
-String fieldName, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+String fieldName, byte[] target, int k, Bits acceptDocs, int visitedLimit)

SlowCodecReaderWrapper
@@ -30,7 +30,6 @@ import org.apache.lucene.codecs.StoredFieldsReader;
-import org.apache.lucene.util.BytesRef;
@@ -180,7 +179,7 @@ public final class SlowCodecReaderWrapper {
-public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)

SortingCodecReader
@@ -34,6 +34,7 @@ import org.apache.lucene.codecs.TermVectorsReader;
+import org.apache.lucene.util.ArrayUtil;
@@ -284,7 +285,9 @@ public final class SortingCodecReader extends FilterCodecReader {
-binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.vectorValue());
+binaryVectors[newDocID] =
+new BytesRef(
+ArrayUtil.copyOfSubArray(delegate.vectorValue(), 0, delegate.vectorValue().length));
@@ -299,8 +302,8 @@ public final class SortingCodecReader extends FilterCodecReader {
-public BytesRef vectorValue() throws IOException {
-return binaryVectors[docId];
+public byte[] vectorValue() throws IOException {
+return binaryVectors[docId].bytes;
@@ -505,8 +508,7 @@ public final class SortingCodecReader extends FilterCodecReader {
-public TopDocs search(
-String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {

VectorSimilarityFunction
@@ -21,8 +21,6 @@ import static org.apache.lucene.util.VectorUtil.dotProduct;
-import org.apache.lucene.util.BytesRef;
-
@@ -38,7 +36,7 @@ public enum VectorSimilarityFunction {
-public float compare(BytesRef v1, BytesRef v2) {
+public float compare(byte[] v1, byte[] v2) {
 return 1 / (1 + squareDistance(v1, v2));
@@ -57,7 +55,7 @@ public enum VectorSimilarityFunction {
-public float compare(BytesRef v1, BytesRef v2) {
+public float compare(byte[] v1, byte[] v2) {
 return dotProductScore(v1, v2);
@@ -75,7 +73,7 @@ public enum VectorSimilarityFunction {
-public float compare(BytesRef v1, BytesRef v2) {
+public float compare(byte[] v1, byte[] v2) {
 return (1 + cosine(v1, v2)) / 2;
@@ -99,5 +97,5 @@ public enum VectorSimilarityFunction {
-public abstract float compare(BytesRef v1, BytesRef v2);
+public abstract float compare(byte[] v1, byte[] v2);

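Each VectorSimilarityFunction constant now scores two byte[] vectors of the same dimension directly through compare(byte[], byte[]). A short sketch with made-up vectors:

import org.apache.lucene.index.VectorSimilarityFunction;

public class CompareByteVectors {
  public static void main(String[] args) {
    byte[] a = {1, 2, 3, 4};
    byte[] b = {2, 2, 3, 5};
    // Print the score each similarity function assigns to the same pair of byte vectors.
    for (VectorSimilarityFunction function : VectorSimilarityFunction.values()) {
      System.out.println(function + " = " + function.compare(a, b));
    }
  }
}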
KnnByteVectorQuery
@@ -17,18 +17,19 @@ package org.apache.lucene.search;
+import java.util.Arrays;
+import org.apache.lucene.util.ArrayUtil;
-import org.apache.lucene.util.BytesRef;
-* Uses {@link KnnVectorsReader#search(String, BytesRef, int, Bits, int)} to perform nearest
-* neighbour search.
+* Uses {@link KnnVectorsReader#search(String, byte[], int, Bits, int)} to perform nearest neighbour
+* search.
@@ -43,7 +44,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
-private final BytesRef target;
+private final byte[] target;
@@ -54,7 +55,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
-public KnnByteVectorQuery(String field, BytesRef target, int k) {
+public KnnByteVectorQuery(String field, byte[] target, int k) {
@@ -68,7 +69,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
-public KnnByteVectorQuery(String field, BytesRef target, int k, Query filter) {
+public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
@@ -91,14 +92,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
 public String toString(String field) {
-return getClass().getSimpleName()
-+ ":"
-+ this.field
-+ "["
-+ target.bytes[target.offset]
-+ ",...]["
-+ k
-+ "]";
+return getClass().getSimpleName() + ":" + this.field + "[" + target[0] + ",...][" + k + "]";
@@ -106,18 +100,18 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
-return Objects.equals(target, that.target);
+return Arrays.equals(target, that.target);
-return Objects.hash(super.hashCode(), target);
+return Objects.hash(super.hashCode(), Arrays.hashCode(target));
-public BytesRef getTargetCopy() {
-return BytesRef.deepCopyOf(target);
+public byte[] getTargetCopy() {
+return ArrayUtil.copyOfSubArray(target, 0, target.length);

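Because the target is now a byte[], KnnByteVectorQuery compares and hashes it by content (Arrays.equals / Arrays.hashCode), and getTargetCopy() hands back a copied array. A sketch with assumed field names ("v", "color"):

import org.apache.lucene.index.Term;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;

public class ByteVectorQuerySketch {
  public static void main(String[] args) {
    // kNN search restricted to documents matching a filter query.
    Query filter = new TermQuery(new Term("color", "blue"));
    KnnByteVectorQuery q1 = new KnnByteVectorQuery("v", new byte[] {8, 1, -3, 4}, 10, filter);
    KnnByteVectorQuery q2 = new KnnByteVectorQuery("v", new byte[] {8, 1, -3, 4}, 10, filter);
    System.out.println(q1.equals(q2)); // true: equality is defined over the array contents

    byte[] copy = q1.getTargetCopy();
    copy[0] = 0; // mutating the copy does not affect the query's target
    System.out.println(q1.toString("v"));
  }
}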
VectorScorer
@@ -22,7 +22,6 @@ import org.apache.lucene.index.FieldInfo;
-import org.apache.lucene.util.BytesRef;
@@ -46,7 +45,7 @@ abstract class VectorScorer {
-static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query)
+static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, byte[] query)
@@ -63,11 +62,11 @@ abstract class VectorScorer {
-private final BytesRef query;
+private final byte[] query;
 protected ByteVectorScorer(
-ByteVectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
+ByteVectorValues values, byte[] query, VectorSimilarityFunction similarity) {

VectorUtil
@@ -122,16 +122,15 @@ public final class VectorUtil {
 /** Returns the cosine similarity between the two vectors. */
-public static float cosine(BytesRef a, BytesRef b) {
+public static float cosine(byte[] a, byte[] b) {
-int aOffset = a.offset, bOffset = b.offset;
 for (int i = 0; i < a.length; i++) {
-byte elem1 = a.bytes[aOffset++];
-byte elem2 = b.bytes[bOffset++];
+byte elem1 = a[i];
+byte elem2 = b[i];
@@ -182,12 +181,11 @@ public final class VectorUtil {
 /** Returns the sum of squared differences of the two vectors. */
-public static float squareDistance(BytesRef a, BytesRef b) {
+public static float squareDistance(byte[] a, byte[] b) {
-int aOffset = a.offset, bOffset = b.offset;
-int diff = a.bytes[aOffset++] - b.bytes[bOffset++];
+int diff = a[i] - b[i];
@@ -251,12 +249,11 @@ public final class VectorUtil {
-public static float dotProduct(BytesRef a, BytesRef b) {
+public static float dotProduct(byte[] a, byte[] b) {
-int aOffset = a.offset, bOffset = b.offset;
-total += a.bytes[aOffset++] * b.bytes[bOffset++];
+total += a[i] * b[i];
@@ -268,7 +265,7 @@ public final class VectorUtil {
-public static float dotProductScore(BytesRef a, BytesRef b) {
+public static float dotProductScore(byte[] a, byte[] b) {

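The VectorUtil byte-vector helpers now take two byte[] of equal length and index them directly, with no offset bookkeeping. A tiny worked example:

import org.apache.lucene.util.VectorUtil;

public class ByteVectorMath {
  public static void main(String[] args) {
    byte[] a = {1, 2, 3, 4};
    byte[] b = {4, 3, 2, 1};
    System.out.println(VectorUtil.dotProduct(a, b));      // 1*4 + 2*3 + 3*2 + 4*1 = 20
    System.out.println(VectorUtil.squareDistance(a, b));  // 9 + 1 + 1 + 9 = 20
    System.out.println(VectorUtil.cosine(a, b));
    System.out.println(VectorUtil.dotProductScore(a, b)); // 0.5 + 20 / (4 * 2^15)
  }
}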
HnswGraphBuilder
@@ -26,7 +26,6 @@ import java.util.SplittableRandom;
-import org.apache.lucene.util.BytesRef;
@@ -273,7 +272,7 @@ public final class HnswGraphBuilder<T> {
-case BYTE -> isDiverse((BytesRef) vectors.vectorValue(candidate), neighbors, score);
+case BYTE -> isDiverse((byte[]) vectors.vectorValue(candidate), neighbors, score);
@@ -291,12 +290,12 @@ public final class HnswGraphBuilder<T> {
-private boolean isDiverse(BytesRef candidate, NeighborArray neighbors, float score)
+private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score)
-candidate, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
+candidate, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
@@ -322,7 +321,7 @@ public final class HnswGraphBuilder<T> {
-candidateIndex, (BytesRef) vectors.vectorValue(candidateNode), neighbors);
+candidateIndex, (byte[]) vectors.vectorValue(candidateNode), neighbors);
@@ -344,12 +343,12 @@ public final class HnswGraphBuilder<T> {
 private boolean isWorstNonDiverse(
-int candidateIndex, BytesRef candidateVector, NeighborArray neighbors) throws IOException {
+int candidateIndex, byte[] candidateVector, NeighborArray neighbors) throws IOException {
-candidateVector, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
+candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));

HnswGraphSearcher
@@ -24,7 +24,6 @@ import org.apache.lucene.index.VectorEncoding;
-import org.apache.lucene.util.BytesRef;
@@ -135,9 +134,9 @@ public class HnswGraphSearcher<T> {
 public static NeighborQueue search(
-BytesRef query,
+byte[] query,
 int topK,
-RandomAccessVectorValues<BytesRef> vectors,
+RandomAccessVectorValues<byte[]> vectors,
@@ -151,7 +150,7 @@ public class HnswGraphSearcher<T> {
-HnswGraphSearcher<BytesRef> graphSearcher =
+HnswGraphSearcher<byte[]> graphSearcher =
@@ -281,7 +280,7 @@ public class HnswGraphSearcher<T> {
-return similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(ord));
+return similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(ord));

TestField
@@ -637,14 +637,17 @@ public class TestField extends LuceneTestCase {
-BytesRef br = newBytesRef(new byte[5]);
-Field field = new KnnByteVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
+byte[] b = new byte[5];
+KnnByteVectorField field =
+new KnnByteVectorField("binary", b, VectorSimilarityFunction.EUCLIDEAN);
+assertNull(field.binaryValue());
+assertArrayEquals(b, field.vectorValue());
-assertEquals(br, field.binaryValue());
@@ -653,7 +656,7 @@ public class TestField extends LuceneTestCase {
-assertEquals(br, binary.vectorValue());
+assertArrayEquals(b, binary.vectorValue());

TestSegmentToThreadMapping
@@ -33,7 +33,6 @@ import org.apache.lucene.store.Directory;
-import org.apache.lucene.util.BytesRef;
@@ -125,7 +124,7 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
 public TopDocs searchNearestVectors(
-String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {

@@ -38,7 +38,6 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;

@@ -64,8 +63,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
abstract float[] randomVector(int dim);

abstract VectorEncoding vectorEncoding();

abstract Field getKnnVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction);
@@ -19,29 +19,27 @@ package org.apache.lucene.search;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;

public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@Override
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
return new KnnByteVectorQuery(field, new BytesRef(floatToBytes(query)), k, queryFilter);
return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
}

@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
return new ThrowingKnnVectorQuery(field, new BytesRef(floatToBytes(vec)), k, query);
return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query);
}

@Override
float[] randomVector(int dim) {
BytesRef bytesRef = TestVectorUtil.randomVectorBytes(dim);
float[] v = new float[bytesRef.length];
byte[] b = TestVectorUtil.randomVectorBytes(dim);
float[] v = new float[b.length];
int vi = 0;
for (int i = bytesRef.offset; i < v.length; i++) {
v[vi++] = bytesRef.bytes[i];
for (int i = 0; i < v.length; i++) {
v[vi++] = b[i];
}
return v;
}

@@ -49,13 +47,12 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@Override
Field getKnnVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
return new KnnByteVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction);
}

@Override
Field getKnnVectorField(String name, float[] vector) {
return new KnnByteVectorField(
name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN);
return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN);
}

private static byte[] floatToBytes(float[] query) {

@@ -75,22 +72,14 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
public void testGetTarget() {
byte[] queryVectorBytes = floatToBytes(new float[] {0, 1});
BytesRef targetQueryVector = new BytesRef(queryVectorBytes);
KnnByteVectorQuery q1 = new KnnByteVectorQuery("f1", targetQueryVector, 10);

assertEquals(targetQueryVector, q1.getTargetCopy());
assertFalse(targetQueryVector == q1.getTargetCopy());
assertFalse(targetQueryVector.bytes == q1.getTargetCopy().bytes);
}

@Override
VectorEncoding vectorEncoding() {
return VectorEncoding.BYTE;
KnnByteVectorQuery q1 = new KnnByteVectorQuery("f1", queryVectorBytes, 10);
assertArrayEquals(queryVectorBytes, q1.getTargetCopy());
assertNotSame(queryVectorBytes, q1.getTargetCopy());
}

private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {

public ThrowingKnnVectorQuery(String field, BytesRef target, int k, Query filter) {
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, target, k, filter);
}
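As the test above shows, KnnByteVectorQuery now takes its target as byte[] and getTargetCopy() returns a defensive copy of that array. A hedged usage sketch (the field name and k value are illustrative):

// Sketch: running a byte-vector kNN query against an existing index.
import java.io.IOException;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;

public class ByteVectorQuerySketch {
  static TopDocs searchNearest(Directory dir, byte[] target) throws IOException {
    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      // The target vector is passed as byte[]; k = 10 nearest neighbors.
      KnnByteVectorQuery query = new KnnByteVectorQuery("binary", target, 10);
      return searcher.search(query, 10);
    }
  }
}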
@@ -28,7 +28,6 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.TestVectorUtil;

@@ -74,11 +73,6 @@ public class TestKnnVectorQuery extends BaseKnnVectorQueryTestCase {
assertNotEquals(queryVector, q1.getTargetCopy());
}

@Override
VectorEncoding vectorEncoding() {
return VectorEncoding.FLOAT32;
}

public void testScoreNegativeDotProduct() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
@@ -33,7 +33,6 @@ import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;

public class TestVectorScorer extends LuceneTestCase {

@@ -49,7 +48,7 @@ public class TestVectorScorer extends LuceneTestCase {
final VectorScorer vectorScorer;
switch (encoding) {
case BYTE:
vectorScorer = VectorScorer.create(context, fieldInfo, new BytesRef(new byte[] {1, 2}));
vectorScorer = VectorScorer.create(context, fieldInfo, new byte[] {1, 2});
break;
case FLOAT32:
vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});

@@ -76,9 +75,9 @@ public class TestVectorScorer extends LuceneTestCase {
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
if (encoding == VectorEncoding.BYTE) {
BytesRef v = new BytesRef(new byte[contents[i].length]);
byte[] v = new byte[contents[i].length];
for (int j = 0; j < v.length; j++) {
v.bytes[j] = (byte) contents[i][j];
v[j] = (byte) contents[i][j];
}
doc.add(new KnnByteVectorField(field, v, EUCLIDEAN));
} else {
@@ -131,19 +131,19 @@ public class TestVectorUtil extends LuceneTestCase {
return u;
}

private static BytesRef negative(BytesRef v) {
BytesRef u = new BytesRef(new byte[v.length]);
private static byte[] negative(byte[] v) {
byte[] u = new byte[v.length];
for (int i = 0; i < v.length; i++) {
// what is (byte) -(-128)? 127?
u.bytes[i] = (byte) -v.bytes[i];
u[i] = (byte) -v[i];
}
return u;
}

private static float l2(BytesRef v) {
private static float l2(byte[] v) {
float l2 = 0;
for (int i = v.offset; i < v.offset + v.length; i++) {
l2 += v.bytes[i] * v.bytes[i];
for (int i = 0; i < v.length; i++) {
l2 += v[i] * v[i];
}
return l2;
}

@@ -161,7 +161,7 @@ public class TestVectorUtil extends LuceneTestCase {
return v;
}

private static BytesRef randomVectorBytes() {
private static byte[] randomVectorBytes() {
BytesRef v = TestUtil.randomBinaryTerm(random(), TestUtil.nextInt(random(), 1, 100));
// clip at -127 to avoid overflow
for (int i = v.offset; i < v.offset + v.length; i++) {

@@ -169,10 +169,11 @@ public class TestVectorUtil extends LuceneTestCase {
v.bytes[i] = -127;
}
}
return v;
assert v.offset == 0;
return v.bytes;
}

public static BytesRef randomVectorBytes(int dim) {
public static byte[] randomVectorBytes(int dim) {
BytesRef v = TestUtil.randomBinaryTerm(random(), dim);
// clip at -127 to avoid overflow
for (int i = v.offset; i < v.offset + v.length; i++) {

@@ -180,22 +181,22 @@ public class TestVectorUtil extends LuceneTestCase {
v.bytes[i] = -127;
}
}
return v;
return v.bytes;
}

public void testBasicDotProductBytes() {
BytesRef a = new BytesRef(new byte[] {1, 2, 3});
BytesRef b = new BytesRef(new byte[] {-10, 0, 5});
byte[] a = new byte[] {1, 2, 3};
byte[] b = new byte[] {-10, 0, 5};
assertEquals(5, VectorUtil.dotProduct(a, b), 0);
float denom = a.length * (1 << 15);
assertEquals(0.5 + 5 / denom, VectorUtil.dotProductScore(a, b), DELTA);

// dot product 0 maps to dotProductScore 0.5
BytesRef zero = new BytesRef(new byte[] {0, 0, 0});
byte[] zero = new byte[] {0, 0, 0};
assertEquals(0.5, VectorUtil.dotProductScore(a, zero), DELTA);

BytesRef min = new BytesRef(new byte[] {-128, -128});
BytesRef max = new BytesRef(new byte[] {127, 127});
byte[] min = new byte[] {-128, -128};
byte[] max = new byte[] {127, 127};
// minimum dot product score is not quite zero because 127 < 128
assertEquals(0.0039, VectorUtil.dotProductScore(min, max), DELTA);

@@ -205,55 +206,48 @@ public class TestVectorUtil extends LuceneTestCase {
public void testSelfDotProductBytes() {
// the dot product of a vector with itself is equal to the sum of the squares of its components
BytesRef v = randomVectorBytes();
byte[] v = randomVectorBytes();
assertEquals(l2(v), VectorUtil.dotProduct(v, v), DELTA);
}

public void testOrthogonalDotProductBytes() {
// the dot product of two perpendicular vectors is 0
byte[] v = new byte[4];
v[0] = (byte) random().nextInt(100);
v[1] = (byte) random().nextInt(100);
v[2] = v[1];
v[3] = (byte) -v[0];
// also test computing using BytesRef with nonzero offset
assertEquals(0, VectorUtil.dotProduct(new BytesRef(v, 0, 2), new BytesRef(v, 2, 2)), DELTA);
byte[] a = new byte[2];
a[0] = (byte) random().nextInt(100);
a[1] = (byte) random().nextInt(100);
byte[] b = new byte[2];
b[0] = a[1];
b[1] = (byte) -a[0];
assertEquals(0, VectorUtil.dotProduct(a, b), DELTA);
}

public void testSelfSquareDistanceBytes() {
// the l2 distance of a vector with itself is zero
BytesRef v = randomVectorBytes();
byte[] v = randomVectorBytes();
assertEquals(0, VectorUtil.squareDistance(v, v), DELTA);
}

public void testBasicSquareDistanceBytes() {
assertEquals(
12,
VectorUtil.squareDistance(
new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-1, 0, 5})),
0);
assertEquals(12, VectorUtil.squareDistance(new byte[] {1, 2, 3}, new byte[] {-1, 0, 5}), 0);
}

public void testRandomSquareDistanceBytes() {
// the square distance of a vector with its inverse is equal to four times the sum of squares of
// its components
BytesRef v = randomVectorBytes();
BytesRef u = negative(v);
byte[] v = randomVectorBytes();
byte[] u = negative(v);
assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA);
}

public void testBasicCosineBytes() {
assertEquals(
0.11952f,
VectorUtil.cosine(new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-10, 0, 5})),
DELTA);
assertEquals(0.11952f, VectorUtil.cosine(new byte[] {1, 2, 3}, new byte[] {-10, 0, 5}), DELTA);
}

public void testSelfCosineBytes() {
// the dot product of a vector with itself is always equal to 1
BytesRef v = randomVectorBytes();
byte[] v = randomVectorBytes();
// ensure the vector is non-zero so that cosine is defined
v.bytes[0] = (byte) (random().nextInt(126) + 1);
v[0] = (byte) (random().nextInt(126) + 1);
assertEquals(1.0f, VectorUtil.cosine(v, v), DELTA);
}
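The VectorUtil tests above exercise the byte[] overloads directly. A compact sketch of the same calls outside the test harness, using the vectors from testBasicDotProductBytes:

// Sketch: VectorUtil's byte-vector helpers now take byte[] operands.
import org.apache.lucene.util.VectorUtil;

public class VectorUtilBytesSketch {
  public static void main(String[] args) {
    byte[] a = {1, 2, 3};
    byte[] b = {-10, 0, 5};
    System.out.println("dotProduct      = " + VectorUtil.dotProduct(a, b));      // 5
    System.out.println("dotProductScore = " + VectorUtil.dotProductScore(a, b)); // slightly above 0.5
    System.out.println("cosine          = " + VectorUtil.cosine(a, b));          // ~0.11952
    System.out.println("squareDistance  = " + VectorUtil.squareDistance(a, b));
  }
}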
@ -61,7 +61,6 @@ import org.apache.lucene.tests.util.LuceneTestCase;
|
|||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
@ -293,9 +292,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
NeighborQueue nn =
|
||||
switch (getVectorEncoding()) {
|
||||
case BYTE -> HnswGraphSearcher.search(
|
||||
(BytesRef) getTargetVector(),
|
||||
(byte[]) getTargetVector(),
|
||||
10,
|
||||
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||
(RandomAccessVectorValues<byte[]>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
|
@ -346,9 +345,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
NeighborQueue nn =
|
||||
switch (getVectorEncoding()) {
|
||||
case BYTE -> HnswGraphSearcher.search(
|
||||
(BytesRef) getTargetVector(),
|
||||
(byte[]) getTargetVector(),
|
||||
10,
|
||||
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||
(RandomAccessVectorValues<byte[]>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
|
@ -405,9 +404,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
acceptOrds,
|
||||
Integer.MAX_VALUE);
|
||||
case BYTE -> HnswGraphSearcher.search(
|
||||
(BytesRef) getTargetVector(),
|
||||
(byte[]) getTargetVector(),
|
||||
numAccepted,
|
||||
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||
(RandomAccessVectorValues<byte[]>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
|
@ -446,9 +445,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
createRandomAcceptOrds(0, nDoc),
|
||||
visitedLimit);
|
||||
case BYTE -> HnswGraphSearcher.search(
|
||||
(BytesRef) getTargetVector(),
|
||||
(byte[]) getTargetVector(),
|
||||
topK,
|
||||
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||
(RandomAccessVectorValues<byte[]>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
|
@ -663,9 +662,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
actual =
|
||||
switch (getVectorEncoding()) {
|
||||
case BYTE -> HnswGraphSearcher.search(
|
||||
(BytesRef) query,
|
||||
(byte[]) query,
|
||||
100,
|
||||
(RandomAccessVectorValues<BytesRef>) vectors,
|
||||
(RandomAccessVectorValues<byte[]>) vectors,
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
|
@ -689,9 +688,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
for (int j = 0; j < size; j++) {
|
||||
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
|
||||
if (getVectorEncoding() == VectorEncoding.BYTE) {
|
||||
assert query instanceof BytesRef;
|
||||
assert query instanceof byte[];
|
||||
expected.add(
|
||||
j, similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(j)));
|
||||
j, similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(j)));
|
||||
} else {
|
||||
assert query instanceof float[];
|
||||
expected.add(
|
||||
|
@ -789,17 +788,17 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
|
||||
/** Returns vectors evenly distributed around the upper unit semicircle. */
|
||||
static class CircularByteVectorValues extends ByteVectorValues
|
||||
implements RandomAccessVectorValues<BytesRef> {
|
||||
implements RandomAccessVectorValues<byte[]> {
|
||||
private final int size;
|
||||
private final float[] value;
|
||||
private final BytesRef bValue;
|
||||
private final byte[] bValue;
|
||||
|
||||
int doc = -1;
|
||||
|
||||
CircularByteVectorValues(int size) {
|
||||
this.size = size;
|
||||
value = new float[2];
|
||||
bValue = new BytesRef(new byte[2]);
|
||||
bValue = new byte[2];
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -818,7 +817,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() {
|
||||
public byte[] vectorValue() {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
|
@ -843,10 +842,10 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue(int ord) {
|
||||
public byte[] vectorValue(int ord) {
|
||||
unitVector2d(ord / (double) size, value);
|
||||
for (int i = 0; i < value.length; i++) {
|
||||
bValue.bytes[i] = (byte) (value[i] * 127);
|
||||
bValue[i] = (byte) (value[i] * 127);
|
||||
}
|
||||
return bValue;
|
||||
}
|
||||
|
@ -881,21 +880,17 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
break;
|
||||
}
|
||||
switch (getVectorEncoding()) {
|
||||
case BYTE:
|
||||
assertArrayEquals(
|
||||
"vectors do not match for doc=" + uDoc,
|
||||
((BytesRef) u.vectorValue()).bytes,
|
||||
((BytesRef) v.vectorValue()).bytes);
|
||||
break;
|
||||
case FLOAT32:
|
||||
assertArrayEquals(
|
||||
"vectors do not match for doc=" + uDoc,
|
||||
(float[]) u.vectorValue(),
|
||||
(float[]) v.vectorValue(),
|
||||
1e-4f);
|
||||
break;
|
||||
default:
|
||||
throw new IllegalArgumentException("unknown vector encoding: " + getVectorEncoding());
|
||||
case BYTE -> assertArrayEquals(
|
||||
"vectors do not match for doc=" + uDoc,
|
||||
(byte[]) u.vectorValue(),
|
||||
(byte[]) v.vectorValue());
|
||||
case FLOAT32 -> assertArrayEquals(
|
||||
"vectors do not match for doc=" + uDoc,
|
||||
(float[]) u.vectorValue(),
|
||||
(float[]) v.vectorValue(),
|
||||
1e-4f);
|
||||
default -> throw new IllegalArgumentException(
|
||||
"unknown vector encoding: " + getVectorEncoding());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -71,9 +71,7 @@ import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.apache.lucene.util.SuppressForbidden;

@@ -516,12 +514,10 @@ public class KnnGraphTester {
private static class VectorReaderByte extends VectorReader {
private final byte[] scratch;
private final BytesRef bytesRef;

VectorReaderByte(FileChannel input, int dim, int bufferSize) {
super(input, dim, bufferSize);
scratch = new byte[dim];
bytesRef = new BytesRef(scratch);
}

@Override

@@ -534,10 +530,10 @@ public class KnnGraphTester {
return target;
}

BytesRef nextBytes() throws IOException {
byte[] nextBytes() throws IOException {
readNext();
bytes.get(scratch);
return bytesRef;
return scratch;
}
}

@@ -750,40 +746,7 @@ public class KnnGraphTester {
System.exit(1);
}

static class NeighborArraySorter extends IntroSorter {
private final int[] node;
private final float[] score;

NeighborArraySorter(NeighborArray neighbors) {
node = neighbors.node;
score = neighbors.score;
}

int pivot;

@Override
protected void swap(int i, int j) {
int tmpNode = node[i];
float tmpScore = score[i];
node[i] = node[j];
score[i] = score[j];
node[j] = tmpNode;
score[j] = tmpScore;
}

@Override
protected void setPivot(int i) {
pivot = i;
}

@Override
protected int comparePivot(int j) {
return Float.compare(score[pivot], score[j]);
}
}

private static class BitSetQuery extends Query {

private final FixedBitSet docs;
private final int cardinality;
@@ -19,20 +19,16 @@ package org.apache.lucene.util.hnsw;

import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;

class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
class MockByteVectorValues extends AbstractMockVectorValues<byte[]> {
private final byte[] scratch;

static MockByteVectorValues fromValues(byte[][] byteValues) {
int dimension = byteValues[0].length;
BytesRef[] values = new BytesRef[byteValues.length];
for (int i = 0; i < byteValues.length; i++) {
values[i] = byteValues[i] == null ? null : new BytesRef(byteValues[i]);
}
BytesRef[] denseValues = new BytesRef[values.length];
static MockByteVectorValues fromValues(byte[][] values) {
int dimension = values[0].length;
int maxDoc = values.length;
byte[][] denseValues = new byte[maxDoc][];
int count = 0;
for (int i = 0; i < byteValues.length; i++) {
for (int i = 0; i < maxDoc; i++) {
if (values[i] != null) {
denseValues[count++] = values[i];
}

@@ -40,7 +36,7 @@ class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
return new MockByteVectorValues(values, dimension, denseValues, count);
}

MockByteVectorValues(BytesRef[] values, int dimension, BytesRef[] denseValues, int numVectors) {
MockByteVectorValues(byte[][] values, int dimension, byte[][] denseValues, int numVectors) {
super(values, dimension, denseValues, numVectors);
scratch = new byte[dimension];
}

@@ -55,7 +51,7 @@ class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
}

@Override
public BytesRef vectorValue() {
public byte[] vectorValue() {
if (LuceneTestCase.random().nextBoolean()) {
return values[pos];
} else {

@@ -63,8 +59,8 @@ class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
// This should help us catch cases of aliasing where the same ByteVectorValues source is used
// twice in a
// single computation.
System.arraycopy(values[pos].bytes, values[pos].offset, scratch, 0, dimension);
return new BytesRef(scratch);
System.arraycopy(values[pos], 0, scratch, 0, dimension);
return scratch;
}
}
}
@@ -30,11 +30,10 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.junit.Before;

/** Tests HNSW KNN graphs */
public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {

@Before
public void setup() {

@@ -47,17 +46,17 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
}

@Override
Query knnQuery(String field, BytesRef vector, int k) {
Query knnQuery(String field, byte[] vector, int k) {
return new KnnByteVectorQuery(field, vector, k);
}

@Override
BytesRef randomVector(int dim) {
return new BytesRef(randomVector8(random(), dim));
byte[] randomVector(int dim) {
return randomVector8(random(), dim);
}

@Override
AbstractMockVectorValues<BytesRef> vectorValues(int size, int dimension) {
AbstractMockVectorValues<byte[]> vectorValues(int size, int dimension) {
return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, random()));
}

@@ -66,7 +65,7 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
}

@Override
AbstractMockVectorValues<BytesRef> vectorValues(float[][] values) {
AbstractMockVectorValues<byte[]> vectorValues(float[][] values) {
byte[][] bValues = new byte[values.length][];
// The case when all floats fit within a byte already.
boolean scaleSimple = fitsInByte(values[0][0]);

@@ -87,32 +86,30 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
}

@Override
AbstractMockVectorValues<BytesRef> vectorValues(LeafReader reader, String fieldName)
AbstractMockVectorValues<byte[]> vectorValues(LeafReader reader, String fieldName)
throws IOException {
ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);
byte[][] vectors = new byte[reader.maxDoc()][];
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
vectors[vectorValues.docID()] =
ArrayUtil.copyOfSubArray(
vectorValues.vectorValue().bytes,
vectorValues.vectorValue().offset,
vectorValues.vectorValue().offset + vectorValues.vectorValue().length);
vectorValues.vectorValue(), 0, vectorValues.vectorValue().length);
}
return MockByteVectorValues.fromValues(vectors);
}

@Override
Field knnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
Field knnVectorField(String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
return new KnnByteVectorField(name, vector, similarityFunction);
}

@Override
RandomAccessVectorValues<BytesRef> circularVectorValues(int nDoc) {
RandomAccessVectorValues<byte[]> circularVectorValues(int nDoc) {
return new CircularByteVectorValues(nDoc);
}

@Override
BytesRef getTargetVector() {
return new BytesRef(new byte[] {1, 0});
byte[] getTargetVector() {
return new byte[] {1, 0};
}
}
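The vectorValues(LeafReader, String) override above also illustrates the updated ByteVectorValues iteration: vectorValue() now hands back a byte[] with no offset bookkeeping. A hedged helper in the same style (assumes the field exists and is indexed with byte vectors):

// Sketch: iterating a segment's byte vectors with the byte[]-based API.
import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;

public class ByteVectorValuesSketch {
  static long sumFirstComponents(LeafReader reader, String field) throws IOException {
    ByteVectorValues values = reader.getByteVectorValues(field);
    long sum = 0;
    while (values.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
      sum += values.vectorValue()[0]; // the current document's vector as byte[]
    }
    return sum;
  }
}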
@@ -42,7 +42,6 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;

/**

@@ -179,7 +178,7 @@ public class TermVectorLeafReader extends LeafReader {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}
@@ -1408,7 +1408,7 @@ public class MemoryIndex {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}
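MemoryIndex and the other LeafReader implementations in this commit all adopt the same searchNearestVectors signature with a byte[] target. A hedged sketch of a direct call to that entry point (passing null acceptDocs is assumed here to mean no filtering, and Integer.MAX_VALUE effectively disables the visited-node limit):

// Sketch: calling the low-level per-leaf nearest-vector search with a byte[] target.
import java.io.IOException;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.TopDocs;

public class NearestByteVectorsSketch {
  static TopDocs nearest(LeafReader reader, String field, byte[] target) throws IOException {
    return reader.searchNearestVectors(field, target, 10, null, Integer.MAX_VALUE);
  }
}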
@@ -34,7 +34,6 @@ import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;

/** Wraps the default KnnVectorsFormat and provides additional assertions. */
public class AssertingKnnVectorsFormat extends KnnVectorsFormat {

@@ -153,7 +152,7 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
}

@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldInfo fi = fis.fieldInfo(field);
assert fi != null
@ -81,8 +81,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
@Override
|
||||
protected void addRandomFields(Document doc) {
|
||||
switch (vectorEncoding) {
|
||||
case BYTE -> doc.add(
|
||||
new KnnByteVectorField("v2", new BytesRef(randomVector8(30)), similarityFunction));
|
||||
case BYTE -> doc.add(new KnnByteVectorField("v2", randomVector8(30), similarityFunction));
|
||||
case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
|
||||
}
|
||||
}
|
||||
|
@ -632,9 +631,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
switch (fieldVectorEncodings[field]) {
|
||||
case BYTE -> {
|
||||
byte[] b = randomVector8(fieldDims[field]);
|
||||
doc.add(
|
||||
new KnnByteVectorField(
|
||||
fieldName, new BytesRef(b), fieldSimilarityFunctions[field]));
|
||||
doc.add(new KnnByteVectorField(fieldName, b, fieldSimilarityFunctions[field]));
|
||||
fieldTotals[field] += b[0];
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
|
@ -660,7 +657,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
if (byteVectorValues != null) {
|
||||
docCount += byteVectorValues.size();
|
||||
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += byteVectorValues.vectorValue().bytes[0];
|
||||
checksum += byteVectorValues.vectorValue()[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -766,10 +763,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
String fieldName = "field";
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter iw = new IndexWriter(dir, iwc)) {
|
||||
add(iw, fieldName, 1, 1, new BytesRef(new byte[] {-1, 0}));
|
||||
add(iw, fieldName, 4, 4, new BytesRef(new byte[] {0, 1}));
|
||||
add(iw, fieldName, 3, 3, (BytesRef) null);
|
||||
add(iw, fieldName, 2, 2, new BytesRef(new byte[] {1, 0}));
|
||||
add(iw, fieldName, 1, 1, new byte[] {-1, 0});
|
||||
add(iw, fieldName, 4, 4, new byte[] {0, 1});
|
||||
add(iw, fieldName, 3, 3, (byte[]) null);
|
||||
add(iw, fieldName, 2, 2, new byte[] {1, 0});
|
||||
iw.forceMerge(1);
|
||||
try (IndexReader reader = DirectoryReader.open(iw)) {
|
||||
LeafReader leaf = getOnlyLeafReader(reader);
|
||||
|
@ -779,11 +776,11 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
assertEquals(2, vectorValues.dimension());
|
||||
assertEquals(3, vectorValues.size());
|
||||
assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(-1, vectorValues.vectorValue().bytes[0], 0);
|
||||
assertEquals(-1, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(1, vectorValues.vectorValue().bytes[0], 0);
|
||||
assertEquals(1, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(0, vectorValues.vectorValue().bytes[0], 0);
|
||||
assertEquals(0, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||
}
|
||||
}
|
||||
|
@ -928,8 +925,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
IndexWriter iw = new IndexWriter(dir, iwc)) {
|
||||
int numDoc = atLeast(100);
|
||||
int dimension = atLeast(10);
|
||||
BytesRef scratch = new BytesRef(dimension);
|
||||
scratch.length = dimension;
|
||||
byte[] scratch = new byte[dimension];
|
||||
int numValues = 0;
|
||||
BytesRef[] values = new BytesRef[numDoc];
|
||||
for (int i = 0; i < numDoc; i++) {
|
||||
|
@ -940,10 +936,11 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
}
|
||||
if (random().nextBoolean() && values[i] != null) {
|
||||
// sometimes use a shared scratch array
|
||||
System.arraycopy(values[i].bytes, 0, scratch.bytes, 0, dimension);
|
||||
System.arraycopy(values[i].bytes, 0, scratch, 0, dimension);
|
||||
add(iw, fieldName, i, scratch, similarityFunction);
|
||||
} else {
|
||||
add(iw, fieldName, i, values[i], similarityFunction);
|
||||
BytesRef value = values[i];
|
||||
add(iw, fieldName, i, value == null ? null : value.bytes, similarityFunction);
|
||||
}
|
||||
if (random().nextInt(10) == 2) {
|
||||
// sometimes delete a random document
|
||||
|
@ -971,12 +968,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
int docId;
|
||||
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
BytesRef v = vectorValues.vectorValue();
|
||||
byte[] v = vectorValues.vectorValue();
|
||||
assertEquals(dimension, v.length);
|
||||
String idString = storedFields.document(docId).getField("id").stringValue();
|
||||
int id = Integer.parseInt(idString);
|
||||
if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) {
|
||||
assertEquals(idString, 0, values[id].compareTo(v));
|
||||
assertEquals(idString, 0, values[id].compareTo(new BytesRef(v)));
|
||||
++valueCount;
|
||||
} else {
|
||||
++numDeletes;
|
||||
|
@ -1141,12 +1138,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
}
|
||||
|
||||
private void add(
|
||||
IndexWriter iw, String field, int id, BytesRef vector, VectorSimilarityFunction similarity)
|
||||
IndexWriter iw, String field, int id, byte[] vector, VectorSimilarityFunction similarity)
|
||||
throws IOException {
|
||||
add(iw, field, id, random().nextInt(100), vector, similarity);
|
||||
}
|
||||
|
||||
private void add(IndexWriter iw, String field, int id, int sortKey, BytesRef vector)
|
||||
private void add(IndexWriter iw, String field, int id, int sortKey, byte[] vector)
|
||||
throws IOException {
|
||||
add(iw, field, id, sortKey, vector, VectorSimilarityFunction.EUCLIDEAN);
|
||||
}
|
||||
|
@ -1156,7 +1153,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
String field,
|
||||
int id,
|
||||
int sortKey,
|
||||
BytesRef vector,
|
||||
byte[] vector,
|
||||
VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
Document doc = new Document();
|
||||
|
@ -1319,7 +1316,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
case BYTE -> {
|
||||
byte[] b = randomVector8(dim);
|
||||
fieldValuesCheckSum += b[0];
|
||||
doc.add(new KnnByteVectorField("knn_vector", new BytesRef(b), similarityFunction));
|
||||
doc.add(new KnnByteVectorField("knn_vector", b, similarityFunction));
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
float[] v = randomVector(dim);
|
||||
|
@ -1349,7 +1346,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
docCount += byteVectorValues.size();
|
||||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += byteVectorValues.vectorValue().bytes[0];
|
||||
checksum += byteVectorValues.vectorValue()[0];
|
||||
Document doc = storedFields.document(byteVectorValues.docID(), Set.of("id"));
|
||||
sumDocIds += Integer.parseInt(doc.get("id"));
|
||||
}
|
||||
|
|
|
@@ -42,7 +42,6 @@ import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;

/**
* This is a hack to make index sorting fast, with a {@link LeafReader} that always returns merge

@@ -236,7 +235,7 @@ class MergeReaderWrapper extends LeafReader {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@@ -55,7 +55,6 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
import org.junit.Assert;

@@ -243,7 +242,7 @@ public class QueryUtils {
@Override
public TopDocs searchNearestVectors(
String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) {
return null;
}