Fix NPE when LeafReader returns null VectorValues (#13162)

### Description
`LeafReader#getXXXVectorValues` may return `null`, and callers that dereference the result without a null check fail with a `NullPointerException`.

**Reproduction**:
```
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
  public void testVectorEncodingMismatch() throws IOException {
    try (Directory indexStore =
        getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      AbstractKnnVectorQuery query =
          new KnnFloatVectorQuery("field", new float[] {0, 1}, 10);
      IndexSearcher searcher = newSearcher(reader);
      searcher.search(query, 10);
    }
  }
}
```
**Output**:
```
java.lang.NullPointerException: Cannot invoke "org.apache.lucene.index.FloatVectorValues.size()" because the return value of "org.apache.lucene.index.LeafReader.getFloatVectorValues(String)" is null
```
This commit is contained in:
panguixin 2024-03-11 20:07:04 +08:00 committed by GitHub
parent 6445bc0a14
commit 5b5815a26d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 182 additions and 51 deletions

View File

@ -230,6 +230,8 @@ Bug Fixes
* GITHUB#13154: Hunspell GeneratingSuggester: ensure there are never more than 100 roots to process (Peter Gromov)
* GITHUB#13162: Fix NPE when LeafReader returns null VectorValues (Pan Guixin)
Other
---------------------

View File

@ -254,7 +254,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
}

View File

@ -270,7 +270,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(
fieldEntry.ordToDocVectorValues,

View File

@ -232,7 +232,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(
fieldEntry.ordToDoc,

View File

@ -54,4 +54,25 @@ public abstract class ByteVectorValues extends DocIdSetIterator {
* @return the vector value
*/
public abstract byte[] vectorValue() throws IOException;
/**
 * Verifies that {@code field}, if it has vector values in the given reader, is encoded as {@link
 * VectorEncoding#BYTE}. Returns silently when the reader has no {@link FieldInfo} for the field
 * or when the field has no vector values.
 *
 * @param in the leaf reader whose field infos are consulted
 * @param field the name of the field to check
 * @throws IllegalStateException if {@code field} has vectors, but using a different encoding
 * @lucene.internal
 * @lucene.experimental
 */
public static void checkField(LeafReader in, String field) {
  final FieldInfo info = in.getFieldInfos().fieldInfo(field);
  if (info == null || !info.hasVectorValues()) {
    // Unknown field, or a field without vectors: nothing to validate.
    return;
  }
  final VectorEncoding actual = info.getVectorEncoding();
  if (actual == VectorEncoding.BYTE) {
    return;
  }
  throw new IllegalStateException(
      "Unexpected vector encoding ("
          + actual
          + ") for field "
          + field
          + "(expected="
          + VectorEncoding.BYTE
          + ")");
}
}

View File

@ -246,7 +246,9 @@ public abstract class CodecReader extends LeafReader {
String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
if (fi == null
|| fi.getVectorDimension() == 0
|| fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
// Field does not exist, does not index vectors, or is not encoded as FLOAT32
return;
}
@ -258,7 +260,9 @@ public abstract class CodecReader extends LeafReader {
String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
if (fi == null
|| fi.getVectorDimension() == 0
|| fi.getVectorEncoding() != VectorEncoding.BYTE) {
// Field does not exist, does not index vectors, or is not encoded as BYTE
return;
}

View File

@ -54,4 +54,25 @@ public abstract class FloatVectorValues extends DocIdSetIterator {
* @return the vector value
*/
public abstract float[] vectorValue() throws IOException;
/**
 * Verifies that {@code field}, if it has vector values in the given reader, is encoded as {@link
 * VectorEncoding#FLOAT32}. Returns silently when the reader has no {@link FieldInfo} for the
 * field or when the field has no vector values.
 *
 * @param in the leaf reader whose field infos are consulted
 * @param field the name of the field to check
 * @throws IllegalStateException if {@code field} has vectors, but using a different encoding
 * @lucene.internal
 * @lucene.experimental
 */
public static void checkField(LeafReader in, String field) {
  final FieldInfo info = in.getFieldInfos().fieldInfo(field);
  if (info == null || !info.hasVectorValues()) {
    // Unknown field, or a field without vectors: nothing to validate.
    return;
  }
  final VectorEncoding actual = info.getVectorEncoding();
  if (actual == VectorEncoding.FLOAT32) {
    return;
  }
  throw new IllegalStateException(
      "Unexpected vector encoding ("
          + actual
          + ") for field "
          + field
          + "(expected="
          + VectorEncoding.FLOAT32
          + ")");
}
}

View File

@ -246,11 +246,11 @@ public abstract non-sealed class LeafReader extends IndexReader {
public final TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
FloatVectorValues floatVectorValues = getFloatVectorValues(fi.name);
if (floatVectorValues == null) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
k = Math.min(k, getFloatVectorValues(fi.name).size());
k = Math.min(k, floatVectorValues.size());
if (k == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
@ -287,11 +287,11 @@ public abstract non-sealed class LeafReader extends IndexReader {
public final TopDocs searchNearestVectors(
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
ByteVectorValues byteVectorValues = getByteVectorValues(fi.name);
if (byteVectorValues == null) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
k = Math.min(k, getByteVectorValues(fi.name).size());
k = Math.min(k, byteVectorValues.size());
if (k == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
}

View File

@ -187,6 +187,9 @@ abstract class AbstractKnnVectorQuery extends Query {
}
VectorScorer vectorScorer = createVectorScorer(context, fi);
if (vectorScorer == null) {
return NO_RESULTS;
}
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;

View File

@ -105,6 +105,9 @@ abstract class AbstractVectorSimilarityQuery extends Query {
if (filterWeight == null) {
// Return exhaustive results
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
if (results.scoreDocs.length == 0) {
return null;
}
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
}
@ -148,6 +151,8 @@ abstract class AbstractVectorSimilarityQuery extends Query {
createVectorScorer(context),
new BitSetIterator(acceptDocs, cardinality),
resultSimilarity);
} else if (results.scoreDocs.length == 0) {
return null;
} else {
// Return an iterator over the collected results
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);

View File

@ -39,6 +39,10 @@ class ByteVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
if (vectorValues == null) {
ByteVectorValues.checkField(ctx.reader(), fieldName);
return DoubleValues.EMPTY;
}
VectorSimilarityFunction function =
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
return new DoubleValues() {

View File

@ -128,12 +128,13 @@ public class FieldExistsQuery extends Query {
break;
}
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
int numVectors =
DocIdSetIterator vectorValues =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> leaf.getFloatVectorValues(field).size();
case BYTE -> leaf.getByteVectorValues(field).size();
case FLOAT32 -> leaf.getFloatVectorValues(field);
case BYTE -> leaf.getByteVectorValues(field);
};
if (numVectors != leaf.maxDoc()) {
assert vectorValues != null : "unexpected null vector values";
if (vectorValues != null && vectorValues.cost() != leaf.maxDoc()) {
allReadersRewritable = false;
break;
}

View File

@ -40,6 +40,10 @@ class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
if (vectorValues == null) {
FloatVectorValues.checkField(ctx.reader(), fieldName);
return DoubleValues.EMPTY;
}
VectorSimilarityFunction function =
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
return new DoubleValues() {

View File

@ -21,9 +21,9 @@ import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
@ -83,13 +83,13 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return TopDocsCollector.EMPTY_TOPDOCS;
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
if (byteVectorValues == null) {
ByteVectorValues.checkField(context.reader(), field);
return NO_RESULTS;
}
if (Math.min(knnCollector.k(), context.reader().getByteVectorValues(fi.name).size()) == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
if (Math.min(knnCollector.k(), byteVectorValues.size()) == 0) {
return NO_RESULTS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
@ -98,9 +98,6 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
if (fi.getVectorEncoding() != VectorEncoding.BYTE) {
return null;
}
return VectorScorer.create(context, fi, target);
}

View File

@ -22,8 +22,8 @@ import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
@ -84,13 +84,13 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return TopDocsCollector.EMPTY_TOPDOCS;
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field);
if (floatVectorValues == null) {
FloatVectorValues.checkField(context.reader(), field);
return NO_RESULTS;
}
if (Math.min(knnCollector.k(), context.reader().getFloatVectorValues(fi.name).size()) == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
return NO_RESULTS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
@ -99,9 +99,6 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
if (fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
return null;
}
return VectorScorer.create(context, fi, target);
}

View File

@ -41,6 +41,10 @@ abstract class VectorScorer {
static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
throws IOException {
FloatVectorValues values = context.reader().getFloatVectorValues(fi.name);
if (values == null) {
FloatVectorValues.checkField(context.reader(), fi.name);
return null;
}
final VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return new FloatVectorScorer(values, query, similarity);
}
@ -48,6 +52,10 @@ abstract class VectorScorer {
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, byte[] query)
throws IOException {
ByteVectorValues values = context.reader().getByteVectorValues(fi.name);
if (values == null) {
ByteVectorValues.checkField(context.reader(), fi.name);
return null;
}
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return new ByteVectorScorer(values, query, similarity);
}

View File

@ -87,6 +87,21 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
assertNotSame(queryVectorBytes, q1.getTargetCopy());
}
/**
 * Runs a {@link KnnFloatVectorQuery} against the index built by {@code getIndexStore} and expects
 * the search to fail with an {@link IllegalStateException}. The query is randomly built with or
 * without a match-all filter.
 */
public void testVectorEncodingMismatch() throws IOException {
  try (Directory dir =
          getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
      IndexReader indexReader = DirectoryReader.open(dir)) {
    // Randomly pick between no filter and a filter that matches every document.
    Query filter = random().nextBoolean() ? new MatchAllDocsQuery() : null;
    AbstractKnnVectorQuery query =
        new KnnFloatVectorQuery("field", new float[] {0, 1}, 10, filter);
    IndexSearcher searcher = newSearcher(indexReader);
    expectThrows(IllegalStateException.class, () -> searcher.search(query, 10));
  }
}
private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {

View File

@ -79,6 +79,20 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
}
}
/**
 * Runs a {@link KnnByteVectorQuery} against the index built by {@code getIndexStore} and expects
 * the search to fail with an {@link IllegalStateException}. The query is randomly built with or
 * without a match-all filter.
 */
public void testVectorEncodingMismatch() throws IOException {
  try (Directory dir =
          getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
      IndexReader indexReader = DirectoryReader.open(dir)) {
    // Randomly pick between no filter and a filter that matches every document.
    Query filter = random().nextBoolean() ? new MatchAllDocsQuery() : null;
    AbstractKnnVectorQuery query = new KnnByteVectorQuery("field", new byte[] {0, 1}, 10, filter);
    IndexSearcher searcher = newSearcher(indexReader);
    expectThrows(IllegalStateException.class, () -> searcher.search(query, 10));
  }
}
public void testGetTarget() {
float[] queryVector = new float[] {0, 1};
KnnFloatVectorQuery q1 = new KnnFloatVectorQuery("f1", queryVector, 10);

View File

@ -22,7 +22,6 @@ import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
@ -80,21 +79,21 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
if (byteVectorValues == null) {
ByteVectorValues.checkField(context.reader(), field);
return NO_RESULTS;
}
if (fi.getVectorEncoding() != VectorEncoding.BYTE) {
return null;
}
BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) {
return NO_RESULTS;
}
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
ParentBlockJoinByteVectorScorer vectorScorer =
new ParentBlockJoinByteVectorScorer(
context.reader().getByteVectorValues(field),
byteVectorValues,
acceptIterator,
parentBitSet,
query,

View File

@ -22,7 +22,6 @@ import java.util.Objects;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
@ -80,21 +79,21 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field);
if (floatVectorValues == null) {
FloatVectorValues.checkField(context.reader(), field);
return NO_RESULTS;
}
if (fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
return null;
}
BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) {
return NO_RESULTS;
}
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
DiversifyingChildrenFloatVectorScorer vectorScorer =
new DiversifyingChildrenFloatVectorScorer(
context.reader().getFloatVectorValues(field),
floatVectorValues,
acceptIterator,
parentBitSet,
query,

View File

@ -17,10 +17,17 @@
package org.apache.lucene.search.join;
import java.io.IOException;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase {
@ -46,6 +53,20 @@ public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVec
return new KnnByteVectorField(name, fromFloat(vector), vectorSimilarityFunction);
}
/**
 * Runs a {@link DiversifyingChildrenFloatKnnVectorQuery} against the index built by {@code
 * getIndexStore} and expects the search to fail with an {@link IllegalStateException}.
 */
public void testVectorEncodingMismatch() throws IOException {
  try (Directory dir =
          getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
      IndexReader indexReader = DirectoryReader.open(dir)) {
    IndexSearcher searcher = newSearcher(indexReader);
    Query childFilter = new TermQuery(new Term("other", "value"));
    BitSetProducer parents = parentFilter(indexReader);
    Query mismatchedQuery =
        new DiversifyingChildrenFloatKnnVectorQuery(
            "field", new float[] {1, 2}, childFilter, 2, parents);
    assertThrows(IllegalStateException.class, () -> searcher.search(mismatchedQuery, 3));
  }
}
private static byte[] fromFloat(float[] queryVector) {
byte[] query = new byte[queryVector.length];
for (int i = 0; i < queryVector.length; i++) {

View File

@ -29,9 +29,11 @@ import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase {
@ -47,6 +49,20 @@ public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVe
fieldName, queryVector, childFilter, k, parentBitSet);
}
/**
 * Runs a {@link DiversifyingChildrenByteKnnVectorQuery} against the index built by {@code
 * getIndexStore} and expects the search to fail with an {@link IllegalStateException}.
 */
public void testVectorEncodingMismatch() throws IOException {
  try (Directory dir =
          getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
      IndexReader indexReader = DirectoryReader.open(dir)) {
    IndexSearcher searcher = newSearcher(indexReader);
    Query childFilter = new TermQuery(new Term("other", "value"));
    BitSetProducer parents = parentFilter(indexReader);
    Query mismatchedQuery =
        new DiversifyingChildrenByteKnnVectorQuery(
            "field", new byte[] {1, 2}, childFilter, 2, parents);
    assertThrows(IllegalStateException.class, () -> searcher.search(mismatchedQuery, 3));
  }
}
public void testScoreCosine() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w =