Fix ByteKnnVectorFieldSource & FloatKnnVectorFieldSource to work correctly when a segment does not contain any docs with vectors (#13105)

This commit is contained in:
Chris Hostetter 2024-02-26 12:12:40 -07:00
parent 6eba1fb537
commit bf6f38665e
5 changed files with 83 additions and 9 deletions

View File

@ -206,7 +206,9 @@ Optimizations
Bug Fixes
---------------------
(No changes)
* GITHUB#13105: Fix ByteKnnVectorFieldSource & FloatKnnVectorFieldSource to work correctly when a segment does not contain
any docs with vectors (hossman)
Other
---------------------

View File

@ -20,7 +20,9 @@ import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.DocIdSetIterator;
@ -39,11 +41,25 @@ public class ByteKnnVectorFieldSource extends ValueSource {
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {
final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName);
final LeafReader reader = readerContext.reader();
final ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);
if (vectorValues == null) {
throw new IllegalArgumentException(
"no byte vector value is indexed for field '" + fieldName + "'");
VectorFieldFunction.checkField(reader, fieldName, VectorEncoding.BYTE);
return new VectorFieldFunction(this) {
private final DocIdSetIterator empty = DocIdSetIterator.empty();
@Override
public byte[] byteVectorVal(int doc) throws IOException {
return null;
}
@Override
protected DocIdSetIterator getVectorIterator() {
return empty;
}
};
}
return new VectorFieldFunction(this) {

View File

@ -20,7 +20,9 @@ import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.DocIdSetIterator;
@ -39,12 +41,26 @@ public class FloatKnnVectorFieldSource extends ValueSource {
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {
final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName);
final LeafReader reader = readerContext.reader();
final FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldName);
if (vectorValues == null) {
throw new IllegalArgumentException(
"no float vector value is indexed for field '" + fieldName + "'");
VectorFieldFunction.checkField(reader, fieldName, VectorEncoding.FLOAT32);
return new VectorFieldFunction(this) {
private final DocIdSetIterator empty = DocIdSetIterator.empty();
@Override
public float[] floatVectorVal(int doc) throws IOException {
return null;
}
@Override
protected DocIdSetIterator getVectorIterator() {
return empty;
}
};
}
return new VectorFieldFunction(this) {
@Override

View File

@ -17,6 +17,9 @@
package org.apache.lucene.queries.function.valuesource;
import java.io.IOException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.DocIdSetIterator;
@ -53,4 +56,29 @@ public abstract class VectorFieldFunction extends FunctionValues {
}
return doc == curDocID;
}
/**
 * Verifies that {@code field}, when it is known to the given reader, was indexed with vectors
 * using the expected encoding. A field with no {@link FieldInfo} in the reader passes the check
 * without error.
 *
 * @param in reader whose field infos are consulted
 * @param field name of the field to check
 * @param expectedEncoding the vector encoding the caller requires
 * @throws IllegalStateException if {@code field} exists, but was not indexed with vectors, or if
 *     {@code field} has vectors, but using a different encoding
 * @lucene.internal
 * @lucene.experimental
 */
static void checkField(LeafReader in, String field, VectorEncoding expectedEncoding) {
  final FieldInfo fi = in.getFieldInfos().fieldInfo(field);
  if (fi == null) {
    // The reader has no info for this field: nothing to validate.
    return;
  }
  // A field without vector values is represented by a null encoding, which differs from any
  // non-null expected encoding and therefore triggers the exception below.
  final VectorEncoding actual = fi.hasVectorValues() ? fi.getVectorEncoding() : null;
  if (expectedEncoding != actual) {
    throw new IllegalStateException(
        "Unexpected vector encoding ("
            + actual
            + ") for field "
            + field
            + " (expected="
            + expectedEncoding
            + ")");
  }
}
}

View File

@ -78,6 +78,10 @@ public class TestKnnVectorSimilarityFunctions extends LuceneTestCase {
document.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3}));
iw.addDocument(document);
if (usually(random())) {
iw.commit();
}
Document document2 = new Document();
document2.add(new StringField("id", "2", Field.Store.NO));
document2.add(new SortedDocValuesField("id", new BytesRef("2")));
@ -232,7 +236,7 @@ public class TestKnnVectorSimilarityFunctions extends LuceneTestCase {
new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2);
assertThrows(
IllegalArgumentException.class,
IllegalStateException.class,
() -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10));
v1 = new FloatKnnVectorFieldSource("knnByteField1");
@ -241,8 +245,16 @@ public class TestKnnVectorSimilarityFunctions extends LuceneTestCase {
new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2);
assertThrows(
IllegalArgumentException.class,
IllegalStateException.class,
() -> searcher.search(new FunctionQuery(floatVectorSimilarityFunction), 10));
v1 = new FloatKnnVectorFieldSource("id");
FloatVectorSimilarityFunction idVectorSimilarityFunction =
new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2);
assertThrows(
IllegalStateException.class,
() -> searcher.search(new FunctionQuery(idVectorSimilarityFunction), 10));
}
private static void assertHits(Query q, float[] scores) throws Exception {