From f9680c68075b89bc497e6bb0fb2e104565702097 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Sat, 20 Aug 2022 08:38:42 -0400 Subject: [PATCH] Add safety checks to KnnVectorField; fixed issue with copying BytesRef (#1076) --- .../lucene94/Lucene94HnswVectorsWriter.java | 3 +- .../lucene/document/KnnVectorField.java | 14 ++++++ .../apache/lucene/index/VectorEncoding.java | 14 +++--- .../index/VectorSimilarityFunction.java | 7 +-- .../org/apache/lucene/util/VectorUtil.java | 4 ++ .../org/apache/lucene/document/TestField.java | 43 +++++++++++++++++++ .../apache/lucene/util/TestVectorUtil.java | 12 ++++++ 7 files changed, 85 insertions(+), 12 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94HnswVectorsWriter.java index 6e3cf62ba75..c83fc58e1ed 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -594,7 +594,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { case BYTE -> new FieldWriter(fieldInfo, M, beamWidth, infoStream) { @Override public BytesRef copyValue(BytesRef value) { - return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim)); + return new BytesRef( + ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + dim)); } }; case FLOAT32 -> new FieldWriter(fieldInfo, M, beamWidth, infoStream) { diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java index 1ddb025592a..2376f1806d3 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java @@ -157,6 +157,13 @@ public class KnnVectorField extends Field { */ public KnnVectorField(String name, float[] vector, FieldType fieldType) { super(name, fieldType); + if (fieldType.vectorEncoding() != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "Attempt to create a vector for field " + + name + + " using float[] but the field encoding is " + + fieldType.vectorEncoding()); + } fieldsData = vector; } @@ -172,6 +179,13 @@ public class KnnVectorField extends Field { */ public KnnVectorField(String name, BytesRef vector, FieldType fieldType) { super(name, fieldType); + if (fieldType.vectorEncoding() != VectorEncoding.BYTE) { + throw new IllegalArgumentException( + "Attempt to create a vector for field " + + name + + " using BytesRef but the field encoding is " + + fieldType.vectorEncoding()); + } fieldsData = vector; } diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorEncoding.java b/lucene/core/src/java/org/apache/lucene/index/VectorEncoding.java index f527aee061d..8ae6dd40e34 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorEncoding.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorEncoding.java @@ -21,12 +21,10 @@ package org.apache.lucene.index; public enum VectorEncoding { /** - * Encodes vector using 8 bits of precision per sample. Use only with DOT_PRODUCT similarity. - * NOTE: this can enable significant storage savings and faster searches, at the cost of some - * possible loss of precision. In order to use it, all vectors must be of the same norm, as - * measured by the sum of the squares of the scalar values, and those values must be in the range - * [-128, 127]. This applies to both document and query vectors. Using nonconforming vectors can - * result in errors or poor search results. + * Encodes vector using 8 bits of precision per sample. Values provided with higher precision (eg: + * queries provided as float) *must* be in the range [-128, 127]. NOTE: this can enable + * significant storage savings and faster searches, at the cost of some possible loss of + * precision. */ BYTE(1), @@ -34,8 +32,8 @@ public enum VectorEncoding { FLOAT32(4); /** - * The number of bytes required to encode a scalar in this format. A vector will require dimension - * * byteSize. + * The number of bytes required to encode a scalar in this format. A vector will nominally require + * dimension * byteSize bytes of storage. */ public final int byteSize; diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java index f21a27c1511..ad793facaeb 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java @@ -42,9 +42,10 @@ public enum VectorSimilarityFunction { /** * Dot product. NOTE: this similarity is intended as an optimized way to perform cosine - * similarity. In order to use it, all vectors must be of unit length, including both document and - * query vectors. Using dot product with vectors that are not unit length can result in errors or - * poor search results. + * similarity. In order to use it, all vectors must be normalized, including both document and + * query vectors. Using dot product with vectors that are not normalized can result in errors or + * poor search results. Floating point vectors must be normalized to be of unit length, while byte + * vectors should simply all have the same norm. */ DOT_PRODUCT { @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 132de83fca3..5e7c313fd7f 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -283,6 +283,10 @@ public final class VectorUtil { public static BytesRef toBytesRef(float[] vector) { BytesRef b = new BytesRef(new byte[vector.length]); for (int i = 0; i < vector.length; i++) { + if (vector[i] < -128 || vector[i] > 127) { + throw new IllegalArgumentException( + "Vector value at " + i + " is out of range [-128.127]: " + vector[i]); + } b.bytes[i] = (byte) vector[i]; } return b; diff --git a/lucene/core/src/test/org/apache/lucene/document/TestField.java b/lucene/core/src/test/org/apache/lucene/document/TestField.java index de57596912d..781f2b613c6 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestField.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestField.java @@ -16,10 +16,16 @@ */ package org.apache.lucene.document; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + import java.io.StringReader; import java.nio.charset.StandardCharsets; +import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.index.VectorValues; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; @@ -506,6 +512,43 @@ public class TestField extends LuceneTestCase { dir.close(); } + public void testKnnVectorField() throws Exception { + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + BytesRef br = newBytesRef(new byte[5]); + Field field = new KnnVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN); + expectThrows( + IllegalArgumentException.class, + () -> new KnnVectorField("bogus", new float[] {1}, (FieldType) field.fieldType())); + float[] vector = new float[] {1, 2}; + Field field2 = new KnnVectorField("float", vector); + expectThrows( + IllegalArgumentException.class, + () -> new KnnVectorField("bogus", br, (FieldType) field2.fieldType())); + assertEquals(br, field.binaryValue()); + doc.add(field); + doc.add(field2); + w.addDocument(doc); + try (IndexReader r = DirectoryReader.open(w)) { + VectorValues binary = r.leaves().get(0).reader().getVectorValues("binary"); + assertEquals(1, binary.size()); + assertNotEquals(NO_MORE_DOCS, binary.nextDoc()); + assertEquals(br, binary.binaryValue()); + assertNotNull(binary.vectorValue()); + assertEquals(NO_MORE_DOCS, binary.nextDoc()); + + VectorValues floatValues = r.leaves().get(0).reader().getVectorValues("float"); + assertEquals(1, floatValues.size()); + assertNotEquals(NO_MORE_DOCS, floatValues.nextDoc()); + assertNotNull(floatValues.binaryValue()); + assertEquals(vector.length, floatValues.vectorValue().length); + assertEquals(vector[0], floatValues.vectorValue()[0], 0); + assertEquals(NO_MORE_DOCS, floatValues.nextDoc()); + } + } + } + private void trySetByteValue(Field f) { expectThrows( IllegalArgumentException.class, diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java index d5033e3ee3c..9a1a15db018 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java @@ -244,4 +244,16 @@ public class TestVectorUtil extends LuceneTestCase { u[1] = -v[0]; assertEquals(0, VectorUtil.cosine(u, v), DELTA); } + + public void testToBytesRef() { + assertEquals( + new BytesRef(new byte[] {-128, 0, 127}), + VectorUtil.toBytesRef(new float[] {-128f, 0, 127f})); + assertEquals( + new BytesRef(new byte[] {-19, 0, 33}), + VectorUtil.toBytesRef(new float[] {-19.9f, 0.5f, 33.7f})); + expectThrows( + IllegalArgumentException.class, () -> VectorUtil.toBytesRef(new float[] {-128.1f})); + expectThrows(IllegalArgumentException.class, () -> VectorUtil.toBytesRef(new float[] {127.1f})); + } }