Add checks in KNNVectorField / KNNVectorQuery to only allow non-null, non-empty and finite vectors (#12281)

---------

Co-authored-by: Uwe Schindler <uschindler@apache.org>
This commit is contained in:
Jonathan Ellis 2023-06-13 01:40:03 -07:00 committed by GitHub
parent 30eba6df56
commit 071461ece5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 66 additions and 16 deletions

View File

@ -170,6 +170,8 @@ Improvements
* GITHUB#12320: Add "direct to binary" option for DaciukMihovAutomatonBuilder and use it in TermInSetQuery#visit.
(Greg Miller)
* GITHUB#12281: Require indexed KNN float vectors and query vectors to be finite. (Jonathan Ellis, Uwe Schindler)
Optimizations
---------------------

View File

@ -17,6 +17,7 @@
package org.apache.lucene.document;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -100,7 +101,7 @@ public class KnnByteVectorField extends Field {
public KnnByteVectorField(
String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
fieldsData = vector; // null-check done above
}
/**
@ -136,6 +137,11 @@ public class KnnByteVectorField extends Field {
+ " using byte[] but the field encoding is "
+ fieldType.vectorEncoding());
}
Objects.requireNonNull(vector, "vector value must not be null");
if (vector.length != fieldType.vectorDimension()) {
throw new IllegalArgumentException(
"The number of vector dimensions does not match the field type");
}
fieldsData = vector;
}

View File

@ -17,6 +17,7 @@
package org.apache.lucene.document;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -101,7 +102,7 @@ public class KnnFloatVectorField extends Field {
public KnnFloatVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
fieldsData = VectorUtil.checkFinite(vector); // null check done above
}
/**
@ -137,7 +138,12 @@ public class KnnFloatVectorField extends Field {
+ " using float[] but the field encoding is "
+ fieldType.vectorEncoding());
}
fieldsData = vector;
Objects.requireNonNull(vector, "vector value must not be null");
if (vector.length != fieldType.vectorDimension()) {
throw new IllegalArgumentException(
"The number of vector dimensions does not match the field type");
}
fieldsData = VectorUtil.checkFinite(vector);
}
/** Return the vector value of this field */

View File

@ -56,7 +56,7 @@ abstract class AbstractKnnVectorQuery extends Query {
private final Query filter;
public AbstractKnnVectorQuery(String field, int k, Query filter) {
this.field = field;
this.field = Objects.requireNonNull(field, "field");
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);

View File

@ -71,7 +71,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
*/
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, k, filter);
this.target = target;
this.target = Objects.requireNonNull(target, "target");
}
@Override

View File

@ -18,6 +18,7 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
@ -25,6 +26,7 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
/**
* Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest
@ -70,7 +72,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
*/
public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) {
super(field, k, filter);
this.target = target;
this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target"));
}
@Override

View File

@ -34,7 +34,9 @@ public final class VectorUtil {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.dotProduct(a, b);
float r = PROVIDER.dotProduct(a, b);
assert Float.isFinite(r);
return r;
}
/**
@ -46,7 +48,9 @@ public final class VectorUtil {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.cosine(a, b);
float r = PROVIDER.cosine(a, b);
assert Float.isFinite(r);
return r;
}
/** Returns the cosine similarity between the two vectors. */
@ -66,7 +70,9 @@ public final class VectorUtil {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.squareDistance(a, b);
float r = PROVIDER.squareDistance(a, b);
assert Float.isFinite(r);
return r;
}
/** Returns the sum of squared differences of the two vectors. */
@ -154,4 +160,20 @@ public final class VectorUtil {
float denom = (float) (a.length * (1 << 15));
return 0.5f + dotProduct(a, b) / denom;
}
/**
* Checks if a float vector only has finite components.
*
* @param v bytes containing a vector
* @return the vector for call-chaining
* @throws IllegalArgumentException if any component of vector is not finite
*/
public static float[] checkFinite(float[] v) {
for (int i = 0; i < v.length; i++) {
if (!Float.isFinite(v[i])) {
throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + v[i]);
}
}
return v;
}
}

View File

@ -102,7 +102,7 @@ final class VectorUtilDefaultProvider implements VectorUtilProvider {
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
return (float) (sum / Math.sqrt(norm1 * norm2));
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}
@Override

View File

@ -56,7 +56,10 @@ public class NeighborArray {
float previousScore = score[size - 1];
assert ((scoresDescOrder && (previousScore >= newScore))
|| (scoresDescOrder == false && (previousScore <= newScore)))
: "Nodes are added in the incorrect order!";
: "Nodes are added in the incorrect order! Comparing "
+ newScore
+ " to "
+ Arrays.toString(ArrayUtil.copyOfSubArray(score, 0, size));
}
node[size] = newNode;
score[size] = newScore;

View File

@ -217,7 +217,7 @@ final class VectorUtilPanamaProvider implements VectorUtilProvider {
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
return (float) (sum / Math.sqrt(norm1 * norm2));
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}
@Override

View File

@ -40,6 +40,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;
/**
* Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running
@ -463,13 +464,21 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
ExitingReaderException.class,
() ->
leaf.searchNearestVectors(
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE));
"vector",
TestVectorUtil.randomVector(dimension),
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE));
} else {
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
scanAndRetrieve(leaf, iter);
leaf.searchNearestVectors(
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE);
"vector",
TestVectorUtil.randomVector(dimension),
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE);
}
reader.close();

View File

@ -27,7 +27,7 @@ public class TestNeighborArray extends LuceneTestCase {
neighbors.addInOrder(1, 0.8f);
AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage();
neighbors.insertSorted(3, 0.9f);
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
@ -76,7 +76,7 @@ public class TestNeighborArray extends LuceneTestCase {
neighbors.addInOrder(1, 0.3f);
AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage();
neighbors.insertSorted(3, 0.3f);
assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);