From 071461ece541e919ab9a6addc6b45e4093e38299 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Tue, 13 Jun 2023 01:40:03 -0700 Subject: [PATCH] Add checks in KNNVectorField / KNNVectorQuery to only allow non-null, non-empty and finite vectors (#12281) --------- Co-authored-by: Uwe Schindler --- lucene/CHANGES.txt | 2 ++ .../lucene/document/KnnByteVectorField.java | 8 +++++- .../lucene/document/KnnFloatVectorField.java | 10 +++++-- .../lucene/search/AbstractKnnVectorQuery.java | 2 +- .../lucene/search/KnnByteVectorQuery.java | 2 +- .../lucene/search/KnnFloatVectorQuery.java | 4 ++- .../org/apache/lucene/util/VectorUtil.java | 28 +++++++++++++++++-- .../util/VectorUtilDefaultProvider.java | 2 +- .../lucene/util/hnsw/NeighborArray.java | 5 +++- .../lucene/util/VectorUtilPanamaProvider.java | 2 +- .../index/TestExitableDirectoryReader.java | 13 +++++++-- .../lucene/util/hnsw/TestNeighborArray.java | 4 +-- 12 files changed, 66 insertions(+), 16 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 513edda2aba..8b82e3c79ea 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -170,6 +170,8 @@ Improvements * GITHUB#12320: Add "direct to binary" option for DaciukMihovAutomatonBuilder and use it in TermInSetQuery#visit. (Greg Miller) +* GITHUB#12281: Require indexed KNN float vectors and query vectors to be finite. (Jonathan Ellis, Uwe Schindler) + Optimizations --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java index fabbc5259e3..87cb6a9f056 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; +import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -100,7 +101,7 @@ public class KnnByteVectorField extends Field { public KnnByteVectorField( String name, byte[] vector, VectorSimilarityFunction similarityFunction) { super(name, createType(vector, similarityFunction)); - fieldsData = vector; + fieldsData = vector; // null-check done above } /** @@ -136,6 +137,11 @@ public class KnnByteVectorField extends Field { + " using byte[] but the field encoding is " + fieldType.vectorEncoding()); } + Objects.requireNonNull(vector, "vector value must not be null"); + if (vector.length != fieldType.vectorDimension()) { + throw new IllegalArgumentException( + "The number of vector dimensions does not match the field type"); + } fieldsData = vector; } diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java index d6673293c72..9d1cd02c013 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; +import java.util.Objects; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -101,7 +102,7 @@ public class KnnFloatVectorField extends Field { public KnnFloatVectorField( String name, float[] vector, VectorSimilarityFunction similarityFunction) { super(name, createType(vector, similarityFunction)); - fieldsData = vector; + fieldsData = VectorUtil.checkFinite(vector); // null check done above } /** @@ -137,7 +138,12 @@ public class KnnFloatVectorField extends Field { + " using float[] but the field encoding is " + fieldType.vectorEncoding()); } - fieldsData = vector; + Objects.requireNonNull(vector, "vector value must not be null"); + if (vector.length != fieldType.vectorDimension()) { + throw new IllegalArgumentException( + "The number of vector dimensions does not match the field type"); + } + fieldsData = VectorUtil.checkFinite(vector); } /** Return the vector value of this field */ diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index cd8d73b8c26..eb51b623831 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -56,7 +56,7 @@ abstract class AbstractKnnVectorQuery extends Query { private final Query filter; public AbstractKnnVectorQuery(String field, int k, Query filter) { - this.field = field; + this.field = Objects.requireNonNull(field, "field"); this.k = k; if (k < 1) { throw new IllegalArgumentException("k must be at least 1, got: " + k); diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 4ec617c2447..10345cd7adf 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -71,7 +71,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { */ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) { super(field, k, filter); - this.target = target; + this.target = Objects.requireNonNull(target, "target"); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 2b1b3a69582..3036e7c4516 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -18,6 +18,7 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.Arrays; +import java.util.Objects; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.FieldInfo; @@ -25,6 +26,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.VectorUtil; /** * Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest @@ -70,7 +72,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { */ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) { super(field, k, filter); - this.target = target; + this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target")); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 068a6edc035..c9e1d368334 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -34,7 +34,9 @@ public final class VectorUtil { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.dotProduct(a, b); + float r = PROVIDER.dotProduct(a, b); + assert Float.isFinite(r); + return r; } /** @@ -46,7 +48,9 @@ public final class VectorUtil { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.cosine(a, b); + float r = PROVIDER.cosine(a, b); + assert Float.isFinite(r); + return r; } /** Returns the cosine similarity between the two vectors. */ @@ -66,7 +70,9 @@ public final class VectorUtil { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.squareDistance(a, b); + float r = PROVIDER.squareDistance(a, b); + assert Float.isFinite(r); + return r; } /** Returns the sum of squared differences of the two vectors. */ @@ -154,4 +160,20 @@ public final class VectorUtil { float denom = (float) (a.length * (1 << 15)); return 0.5f + dotProduct(a, b) / denom; } + + /** + * Checks if a float vector only has finite components. + * + * @param v bytes containing a vector + * @return the vector for call-chaining + * @throws IllegalArgumentException if any component of vector is not finite + */ + public static float[] checkFinite(float[] v) { + for (int i = 0; i < v.length; i++) { + if (!Float.isFinite(v[i])) { + throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + v[i]); + } + } + return v; + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java index da8483ed04d..665181e8678 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java @@ -102,7 +102,7 @@ final class VectorUtilDefaultProvider implements VectorUtilProvider { norm1 += elem1 * elem1; norm2 += elem2 * elem2; } - return (float) (sum / Math.sqrt(norm1 * norm2)); + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index a23b9b5254e..b44f7da8b8a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -56,7 +56,10 @@ public class NeighborArray { float previousScore = score[size - 1]; assert ((scoresDescOrder && (previousScore >= newScore)) || (scoresDescOrder == false && (previousScore <= newScore))) - : "Nodes are added in the incorrect order!"; + : "Nodes are added in the incorrect order! Comparing " + + newScore + + " to " + + Arrays.toString(ArrayUtil.copyOfSubArray(score, 0, size)); } node[size] = newNode; score[size] = newScore; diff --git a/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java index fd599c232fb..61ec15e0d22 100644 --- a/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java +++ b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java @@ -217,7 +217,7 @@ final class VectorUtilPanamaProvider implements VectorUtilProvider { norm1 += elem1 * elem1; norm2 += elem2 * elem2; } - return (float) (sum / Math.sqrt(norm1 * norm2)); + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index 5dc11a52fb4..4a6365b5309 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -40,6 +40,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TestVectorUtil; /** * Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running @@ -463,13 +464,21 @@ public class TestExitableDirectoryReader extends LuceneTestCase { ExitingReaderException.class, () -> leaf.searchNearestVectors( - "vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE)); + "vector", + TestVectorUtil.randomVector(dimension), + 5, + leaf.getLiveDocs(), + Integer.MAX_VALUE)); } else { DocIdSetIterator iter = leaf.getFloatVectorValues("vector"); scanAndRetrieve(leaf, iter); leaf.searchNearestVectors( - "vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE); + "vector", + TestVectorUtil.randomVector(dimension), + 5, + leaf.getLiveDocs(), + Integer.MAX_VALUE); } reader.close(); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java index 039f69c9dc4..c81077aa6da 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java @@ -27,7 +27,7 @@ public class TestNeighborArray extends LuceneTestCase { neighbors.addInOrder(1, 0.8f); AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f)); - assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); + assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage(); neighbors.insertSorted(3, 0.9f); assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); @@ -76,7 +76,7 @@ public class TestNeighborArray extends LuceneTestCase { neighbors.addInOrder(1, 0.3f); AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f)); - assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); + assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage(); neighbors.insertSorted(3, 0.3f); assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);