Add safety checks to KnnVectorField; fixed issue with copying BytesRef (#1076)

This commit is contained in:
Michael Sokolov 2022-08-20 08:38:42 -04:00 committed by GitHub
parent 9ae3498f82
commit f9680c6807
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 85 additions and 12 deletions

View File

@ -594,7 +594,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
@Override
public BytesRef copyValue(BytesRef value) {
return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim));
return new BytesRef(
ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + dim));
}
};
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {

View File

@ -157,6 +157,13 @@ public class KnnVectorField extends Field {
*/
public KnnVectorField(String name, float[] vector, FieldType fieldType) {
super(name, fieldType);
if (fieldType.vectorEncoding() != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"Attempt to create a vector for field "
+ name
+ " using float[] but the field encoding is "
+ fieldType.vectorEncoding());
}
fieldsData = vector;
}
@ -172,6 +179,13 @@ public class KnnVectorField extends Field {
*/
public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
super(name, fieldType);
if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"Attempt to create a vector for field "
+ name
+ " using BytesRef but the field encoding is "
+ fieldType.vectorEncoding());
}
fieldsData = vector;
}

View File

@ -21,12 +21,10 @@ package org.apache.lucene.index;
public enum VectorEncoding {
/**
* Encodes vector using 8 bits of precision per sample. Use only with DOT_PRODUCT similarity.
* NOTE: this can enable significant storage savings and faster searches, at the cost of some
* possible loss of precision. In order to use it, all vectors must be of the same norm, as
* measured by the sum of the squares of the scalar values, and those values must be in the range
* [-128, 127]. This applies to both document and query vectors. Using nonconforming vectors can
* result in errors or poor search results.
* Encodes vector using 8 bits of precision per sample. Values provided with higher precision (eg:
* queries provided as float) *must* be in the range [-128, 127]. NOTE: this can enable
* significant storage savings and faster searches, at the cost of some possible loss of
* precision.
*/
BYTE(1),
@ -34,8 +32,8 @@ public enum VectorEncoding {
FLOAT32(4);
/**
* The number of bytes required to encode a scalar in this format. A vector will require dimension
* * byteSize.
* The number of bytes required to encode a scalar in this format. A vector will nominally require
* dimension * byteSize bytes of storage.
*/
public final int byteSize;

View File

@ -42,9 +42,10 @@ public enum VectorSimilarityFunction {
/**
* Dot product. NOTE: this similarity is intended as an optimized way to perform cosine
* similarity. In order to use it, all vectors must be of unit length, including both document and
* query vectors. Using dot product with vectors that are not unit length can result in errors or
* poor search results.
* similarity. In order to use it, all vectors must be normalized, including both document and
* query vectors. Using dot product with vectors that are not normalized can result in errors or
* poor search results. Floating point vectors must be normalized to be of unit length, while byte
* vectors should simply all have the same norm.
*/
DOT_PRODUCT {
@Override

View File

@ -283,6 +283,10 @@ public final class VectorUtil {
public static BytesRef toBytesRef(float[] vector) {
BytesRef b = new BytesRef(new byte[vector.length]);
for (int i = 0; i < vector.length; i++) {
if (vector[i] < -128 || vector[i] > 127) {
throw new IllegalArgumentException(
"Vector value at " + i + " is out of range [-128.127]: " + vector[i]);
}
b.bytes[i] = (byte) vector[i];
}
return b;

View File

@ -16,10 +16,16 @@
*/
package org.apache.lucene.document;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
@ -506,6 +512,43 @@ public class TestField extends LuceneTestCase {
dir.close();
}
public void testKnnVectorField() throws Exception {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
BytesRef br = newBytesRef(new byte[5]);
Field field = new KnnVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
expectThrows(
IllegalArgumentException.class,
() -> new KnnVectorField("bogus", new float[] {1}, (FieldType) field.fieldType()));
float[] vector = new float[] {1, 2};
Field field2 = new KnnVectorField("float", vector);
expectThrows(
IllegalArgumentException.class,
() -> new KnnVectorField("bogus", br, (FieldType) field2.fieldType()));
assertEquals(br, field.binaryValue());
doc.add(field);
doc.add(field2);
w.addDocument(doc);
try (IndexReader r = DirectoryReader.open(w)) {
VectorValues binary = r.leaves().get(0).reader().getVectorValues("binary");
assertEquals(1, binary.size());
assertNotEquals(NO_MORE_DOCS, binary.nextDoc());
assertEquals(br, binary.binaryValue());
assertNotNull(binary.vectorValue());
assertEquals(NO_MORE_DOCS, binary.nextDoc());
VectorValues floatValues = r.leaves().get(0).reader().getVectorValues("float");
assertEquals(1, floatValues.size());
assertNotEquals(NO_MORE_DOCS, floatValues.nextDoc());
assertNotNull(floatValues.binaryValue());
assertEquals(vector.length, floatValues.vectorValue().length);
assertEquals(vector[0], floatValues.vectorValue()[0], 0);
assertEquals(NO_MORE_DOCS, floatValues.nextDoc());
}
}
}
private void trySetByteValue(Field f) {
expectThrows(
IllegalArgumentException.class,

View File

@ -244,4 +244,16 @@ public class TestVectorUtil extends LuceneTestCase {
u[1] = -v[0];
assertEquals(0, VectorUtil.cosine(u, v), DELTA);
}
public void testToBytesRef() {
assertEquals(
new BytesRef(new byte[] {-128, 0, 127}),
VectorUtil.toBytesRef(new float[] {-128f, 0, 127f}));
assertEquals(
new BytesRef(new byte[] {-19, 0, 33}),
VectorUtil.toBytesRef(new float[] {-19.9f, 0.5f, 33.7f}));
expectThrows(
IllegalArgumentException.class, () -> VectorUtil.toBytesRef(new float[] {-128.1f}));
expectThrows(IllegalArgumentException.class, () -> VectorUtil.toBytesRef(new float[] {127.1f}));
}
}