mirror of https://github.com/apache/lucene.git
Add safety checks to KnnVectorField; fixed issue with copying BytesRef (#1076)
This commit is contained in:
parent
9ae3498f82
commit
f9680c6807
|
@ -594,7 +594,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
|
||||
@Override
|
||||
public BytesRef copyValue(BytesRef value) {
|
||||
return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim));
|
||||
return new BytesRef(
|
||||
ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + dim));
|
||||
}
|
||||
};
|
||||
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {
|
||||
|
|
|
@ -157,6 +157,13 @@ public class KnnVectorField extends Field {
|
|||
*/
|
||||
public KnnVectorField(String name, float[] vector, FieldType fieldType) {
|
||||
super(name, fieldType);
|
||||
if (fieldType.vectorEncoding() != VectorEncoding.FLOAT32) {
|
||||
throw new IllegalArgumentException(
|
||||
"Attempt to create a vector for field "
|
||||
+ name
|
||||
+ " using float[] but the field encoding is "
|
||||
+ fieldType.vectorEncoding());
|
||||
}
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
|
@ -172,6 +179,13 @@ public class KnnVectorField extends Field {
|
|||
*/
|
||||
public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
|
||||
super(name, fieldType);
|
||||
if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
|
||||
throw new IllegalArgumentException(
|
||||
"Attempt to create a vector for field "
|
||||
+ name
|
||||
+ " using BytesRef but the field encoding is "
|
||||
+ fieldType.vectorEncoding());
|
||||
}
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
|
|
|
@ -21,12 +21,10 @@ package org.apache.lucene.index;
|
|||
public enum VectorEncoding {
|
||||
|
||||
/**
|
||||
* Encodes vector using 8 bits of precision per sample. Use only with DOT_PRODUCT similarity.
|
||||
* NOTE: this can enable significant storage savings and faster searches, at the cost of some
|
||||
* possible loss of precision. In order to use it, all vectors must be of the same norm, as
|
||||
* measured by the sum of the squares of the scalar values, and those values must be in the range
|
||||
* [-128, 127]. This applies to both document and query vectors. Using nonconforming vectors can
|
||||
* result in errors or poor search results.
|
||||
* Encodes vector using 8 bits of precision per sample. Values provided with higher precision (e.g.,
|
||||
* queries provided as float) *must* be in the range [-128, 127]. NOTE: this can enable
|
||||
* significant storage savings and faster searches, at the cost of some possible loss of
|
||||
* precision.
|
||||
*/
|
||||
BYTE(1),
|
||||
|
||||
|
@ -34,8 +32,8 @@ public enum VectorEncoding {
|
|||
FLOAT32(4);
|
||||
|
||||
/**
|
||||
* The number of bytes required to encode a scalar in this format. A vector will require dimension
|
||||
* * byteSize.
|
||||
* The number of bytes required to encode a scalar in this format. A vector will nominally require
|
||||
* dimension * byteSize bytes of storage.
|
||||
*/
|
||||
public final int byteSize;
|
||||
|
||||
|
|
|
@ -42,9 +42,10 @@ public enum VectorSimilarityFunction {
|
|||
|
||||
/**
|
||||
* Dot product. NOTE: this similarity is intended as an optimized way to perform cosine
|
||||
* similarity. In order to use it, all vectors must be of unit length, including both document and
|
||||
* query vectors. Using dot product with vectors that are not unit length can result in errors or
|
||||
* poor search results.
|
||||
* similarity. In order to use it, all vectors must be normalized, including both document and
|
||||
* query vectors. Using dot product with vectors that are not normalized can result in errors or
|
||||
* poor search results. Floating point vectors must be normalized to be of unit length, while byte
|
||||
* vectors should simply all have the same norm.
|
||||
*/
|
||||
DOT_PRODUCT {
|
||||
@Override
|
||||
|
|
|
@ -283,6 +283,10 @@ public final class VectorUtil {
|
|||
public static BytesRef toBytesRef(float[] vector) {
|
||||
BytesRef b = new BytesRef(new byte[vector.length]);
|
||||
for (int i = 0; i < vector.length; i++) {
|
||||
if (vector[i] < -128 || vector[i] > 127) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vector value at " + i + " is out of range [-128.127]: " + vector[i]);
|
||||
}
|
||||
b.bytes[i] = (byte) vector[i];
|
||||
}
|
||||
return b;
|
||||
|
|
|
@ -16,10 +16,16 @@
|
|||
*/
|
||||
package org.apache.lucene.document;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.StringReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
|
@ -506,6 +512,43 @@ public class TestField extends LuceneTestCase {
|
|||
dir.close();
|
||||
}
|
||||
|
||||
public void testKnnVectorField() throws Exception {
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
Document doc = new Document();
|
||||
BytesRef br = newBytesRef(new byte[5]);
|
||||
Field field = new KnnVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new KnnVectorField("bogus", new float[] {1}, (FieldType) field.fieldType()));
|
||||
float[] vector = new float[] {1, 2};
|
||||
Field field2 = new KnnVectorField("float", vector);
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new KnnVectorField("bogus", br, (FieldType) field2.fieldType()));
|
||||
assertEquals(br, field.binaryValue());
|
||||
doc.add(field);
|
||||
doc.add(field2);
|
||||
w.addDocument(doc);
|
||||
try (IndexReader r = DirectoryReader.open(w)) {
|
||||
VectorValues binary = r.leaves().get(0).reader().getVectorValues("binary");
|
||||
assertEquals(1, binary.size());
|
||||
assertNotEquals(NO_MORE_DOCS, binary.nextDoc());
|
||||
assertEquals(br, binary.binaryValue());
|
||||
assertNotNull(binary.vectorValue());
|
||||
assertEquals(NO_MORE_DOCS, binary.nextDoc());
|
||||
|
||||
VectorValues floatValues = r.leaves().get(0).reader().getVectorValues("float");
|
||||
assertEquals(1, floatValues.size());
|
||||
assertNotEquals(NO_MORE_DOCS, floatValues.nextDoc());
|
||||
assertNotNull(floatValues.binaryValue());
|
||||
assertEquals(vector.length, floatValues.vectorValue().length);
|
||||
assertEquals(vector[0], floatValues.vectorValue()[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, floatValues.nextDoc());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void trySetByteValue(Field f) {
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
|
|
|
@ -244,4 +244,16 @@ public class TestVectorUtil extends LuceneTestCase {
|
|||
u[1] = -v[0];
|
||||
assertEquals(0, VectorUtil.cosine(u, v), DELTA);
|
||||
}
|
||||
|
||||
public void testToBytesRef() {
|
||||
assertEquals(
|
||||
new BytesRef(new byte[] {-128, 0, 127}),
|
||||
VectorUtil.toBytesRef(new float[] {-128f, 0, 127f}));
|
||||
assertEquals(
|
||||
new BytesRef(new byte[] {-19, 0, 33}),
|
||||
VectorUtil.toBytesRef(new float[] {-19.9f, 0.5f, 33.7f}));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class, () -> VectorUtil.toBytesRef(new float[] {-128.1f}));
|
||||
expectThrows(IllegalArgumentException.class, () -> VectorUtil.toBytesRef(new float[] {127.1f}));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue