mirror of https://github.com/apache/lucene.git
Add checks in KNNVectorField / KNNVectorQuery to only allow non-null, non-empty and finite vectors (#12281)
--------- Co-authored-by: Uwe Schindler <uschindler@apache.org>
This commit is contained in:
parent
30eba6df56
commit
071461ece5
|
@ -170,6 +170,8 @@ Improvements
|
|||
* GITHUB#12320: Add "direct to binary" option for DaciukMihovAutomatonBuilder and use it in TermInSetQuery#visit.
|
||||
(Greg Miller)
|
||||
|
||||
* GITHUB#12281: Require indexed KNN float vectors and query vectors to be finite. (Jonathan Ellis, Uwe Schindler)
|
||||
|
||||
Optimizations
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.lucene.document;
|
||||
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -100,7 +101,7 @@ public class KnnByteVectorField extends Field {
|
|||
public KnnByteVectorField(
|
||||
String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
|
||||
super(name, createType(vector, similarityFunction));
|
||||
fieldsData = vector;
|
||||
fieldsData = vector; // null-check done above
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -136,6 +137,11 @@ public class KnnByteVectorField extends Field {
|
|||
+ " using byte[] but the field encoding is "
|
||||
+ fieldType.vectorEncoding());
|
||||
}
|
||||
Objects.requireNonNull(vector, "vector value must not be null");
|
||||
if (vector.length != fieldType.vectorDimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The number of vector dimensions does not match the field type");
|
||||
}
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.lucene.document;
|
||||
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -101,7 +102,7 @@ public class KnnFloatVectorField extends Field {
|
|||
public KnnFloatVectorField(
|
||||
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
|
||||
super(name, createType(vector, similarityFunction));
|
||||
fieldsData = vector;
|
||||
fieldsData = VectorUtil.checkFinite(vector); // null check done above
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -137,7 +138,12 @@ public class KnnFloatVectorField extends Field {
|
|||
+ " using float[] but the field encoding is "
|
||||
+ fieldType.vectorEncoding());
|
||||
}
|
||||
fieldsData = vector;
|
||||
Objects.requireNonNull(vector, "vector value must not be null");
|
||||
if (vector.length != fieldType.vectorDimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The number of vector dimensions does not match the field type");
|
||||
}
|
||||
fieldsData = VectorUtil.checkFinite(vector);
|
||||
}
|
||||
|
||||
/** Return the vector value of this field */
|
||||
|
|
|
@ -56,7 +56,7 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
private final Query filter;
|
||||
|
||||
public AbstractKnnVectorQuery(String field, int k, Query filter) {
|
||||
this.field = field;
|
||||
this.field = Objects.requireNonNull(field, "field");
|
||||
this.k = k;
|
||||
if (k < 1) {
|
||||
throw new IllegalArgumentException("k must be at least 1, got: " + k);
|
||||
|
|
|
@ -71,7 +71,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
|
|||
*/
|
||||
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
|
||||
super(field, k, filter);
|
||||
this.target = target;
|
||||
this.target = Objects.requireNonNull(target, "target");
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.apache.lucene.search;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
|
@ -25,6 +26,7 @@ import org.apache.lucene.index.LeafReaderContext;
|
|||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
/**
|
||||
* Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest
|
||||
|
@ -70,7 +72,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
|
|||
*/
|
||||
public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) {
|
||||
super(field, k, filter);
|
||||
this.target = target;
|
||||
this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target"));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -34,7 +34,9 @@ public final class VectorUtil {
|
|||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
|
||||
}
|
||||
return PROVIDER.dotProduct(a, b);
|
||||
float r = PROVIDER.dotProduct(a, b);
|
||||
assert Float.isFinite(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -46,7 +48,9 @@ public final class VectorUtil {
|
|||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
|
||||
}
|
||||
return PROVIDER.cosine(a, b);
|
||||
float r = PROVIDER.cosine(a, b);
|
||||
assert Float.isFinite(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
/** Returns the cosine similarity between the two vectors. */
|
||||
|
@ -66,7 +70,9 @@ public final class VectorUtil {
|
|||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
|
||||
}
|
||||
return PROVIDER.squareDistance(a, b);
|
||||
float r = PROVIDER.squareDistance(a, b);
|
||||
assert Float.isFinite(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
/** Returns the sum of squared differences of the two vectors. */
|
||||
|
@ -154,4 +160,20 @@ public final class VectorUtil {
|
|||
float denom = (float) (a.length * (1 << 15));
|
||||
return 0.5f + dotProduct(a, b) / denom;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks that every component of a float vector is finite (neither NaN nor infinite).
|
||||
*
|
||||
* @param v the float vector to check; must not be null
|
||||
* @return the same array {@code v}, unmodified, for call-chaining
|
||||
* @throws IllegalArgumentException if any component of the vector is not finite
|
||||
*/
|
||||
public static float[] checkFinite(float[] v) {
|
||||
for (int i = 0; i < v.length; i++) {
|
||||
if (!Float.isFinite(v[i])) {
|
||||
throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + v[i]);
|
||||
}
|
||||
}
|
||||
return v;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -102,7 +102,7 @@ final class VectorUtilDefaultProvider implements VectorUtilProvider {
|
|||
norm1 += elem1 * elem1;
|
||||
norm2 += elem2 * elem2;
|
||||
}
|
||||
return (float) (sum / Math.sqrt(norm1 * norm2));
|
||||
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -56,7 +56,10 @@ public class NeighborArray {
|
|||
float previousScore = score[size - 1];
|
||||
assert ((scoresDescOrder && (previousScore >= newScore))
|
||||
|| (scoresDescOrder == false && (previousScore <= newScore)))
|
||||
: "Nodes are added in the incorrect order!";
|
||||
: "Nodes are added in the incorrect order! Comparing "
|
||||
+ newScore
|
||||
+ " to "
|
||||
+ Arrays.toString(ArrayUtil.copyOfSubArray(score, 0, size));
|
||||
}
|
||||
node[size] = newNode;
|
||||
score[size] = newScore;
|
||||
|
|
|
@ -217,7 +217,7 @@ final class VectorUtilPanamaProvider implements VectorUtilProvider {
|
|||
norm1 += elem1 * elem1;
|
||||
norm2 += elem2 * elem2;
|
||||
}
|
||||
return (float) (sum / Math.sqrt(norm1 * norm2));
|
||||
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -40,6 +40,7 @@ import org.apache.lucene.store.Directory;
|
|||
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.TestVectorUtil;
|
||||
|
||||
/**
|
||||
* Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running
|
||||
|
@ -463,13 +464,21 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
|
|||
ExitingReaderException.class,
|
||||
() ->
|
||||
leaf.searchNearestVectors(
|
||||
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE));
|
||||
"vector",
|
||||
TestVectorUtil.randomVector(dimension),
|
||||
5,
|
||||
leaf.getLiveDocs(),
|
||||
Integer.MAX_VALUE));
|
||||
} else {
|
||||
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
|
||||
scanAndRetrieve(leaf, iter);
|
||||
|
||||
leaf.searchNearestVectors(
|
||||
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE);
|
||||
"vector",
|
||||
TestVectorUtil.randomVector(dimension),
|
||||
5,
|
||||
leaf.getLiveDocs(),
|
||||
Integer.MAX_VALUE);
|
||||
}
|
||||
|
||||
reader.close();
|
||||
|
|
|
@ -27,7 +27,7 @@ public class TestNeighborArray extends LuceneTestCase {
|
|||
neighbors.addInOrder(1, 0.8f);
|
||||
|
||||
AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f));
|
||||
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
|
||||
assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage();
|
||||
|
||||
neighbors.insertSorted(3, 0.9f);
|
||||
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
|
||||
|
@ -76,7 +76,7 @@ public class TestNeighborArray extends LuceneTestCase {
|
|||
neighbors.addInOrder(1, 0.3f);
|
||||
|
||||
AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f));
|
||||
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
|
||||
assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage();
|
||||
|
||||
neighbors.insertSorted(3, 0.3f);
|
||||
assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);
|
||||
|
|
Loading…
Reference in New Issue