mirror of https://github.com/apache/lucene.git
Fix ByteKnnVectorFieldSource & FloatKnnVectorFieldSource to work correctly when a segment does not contain any docs with vectors (#13105)
This commit is contained in:
parent
6eba1fb537
commit
bf6f38665e
|
@ -206,7 +206,9 @@ Optimizations
|
|||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
(No changes)
|
||||
|
||||
* GITHUB#13105: Fix ByteKnnVectorFieldSource & FloatKnnVectorFieldSource to work correctly when a segment does not contain
|
||||
any docs with vectors (hossman)
|
||||
|
||||
Other
|
||||
---------------------
|
||||
|
|
|
@ -20,7 +20,9 @@ import java.io.IOException;
|
|||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.queries.function.FunctionValues;
|
||||
import org.apache.lucene.queries.function.ValueSource;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
|
@ -39,11 +41,25 @@ public class ByteKnnVectorFieldSource extends ValueSource {
|
|||
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
|
||||
throws IOException {
|
||||
|
||||
final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName);
|
||||
final LeafReader reader = readerContext.reader();
|
||||
final ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);
|
||||
|
||||
if (vectorValues == null) {
|
||||
throw new IllegalArgumentException(
|
||||
"no byte vector value is indexed for field '" + fieldName + "'");
|
||||
VectorFieldFunction.checkField(reader, fieldName, VectorEncoding.BYTE);
|
||||
|
||||
return new VectorFieldFunction(this) {
|
||||
private final DocIdSetIterator empty = DocIdSetIterator.empty();
|
||||
|
||||
@Override
|
||||
public byte[] byteVectorVal(int doc) throws IOException {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected DocIdSetIterator getVectorIterator() {
|
||||
return empty;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return new VectorFieldFunction(this) {
|
||||
|
|
|
@ -20,7 +20,9 @@ import java.io.IOException;
|
|||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.queries.function.FunctionValues;
|
||||
import org.apache.lucene.queries.function.ValueSource;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
|
@ -39,12 +41,26 @@ public class FloatKnnVectorFieldSource extends ValueSource {
|
|||
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
|
||||
throws IOException {
|
||||
|
||||
final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName);
|
||||
final LeafReader reader = readerContext.reader();
|
||||
final FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldName);
|
||||
|
||||
if (vectorValues == null) {
|
||||
throw new IllegalArgumentException(
|
||||
"no float vector value is indexed for field '" + fieldName + "'");
|
||||
VectorFieldFunction.checkField(reader, fieldName, VectorEncoding.FLOAT32);
|
||||
return new VectorFieldFunction(this) {
|
||||
private final DocIdSetIterator empty = DocIdSetIterator.empty();
|
||||
|
||||
@Override
|
||||
public float[] floatVectorVal(int doc) throws IOException {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected DocIdSetIterator getVectorIterator() {
|
||||
return empty;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return new VectorFieldFunction(this) {
|
||||
|
||||
@Override
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
package org.apache.lucene.queries.function.valuesource;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.queries.function.FunctionValues;
|
||||
import org.apache.lucene.queries.function.ValueSource;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
|
@ -53,4 +56,29 @@ public abstract class VectorFieldFunction extends FunctionValues {
|
|||
}
|
||||
return doc == curDocID;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks that {@code field}, when present in {@code in}, was indexed with {@code expectedEncoding}.
|
||||
* A field that has no {@code FieldInfo} in this reader passes the check without any exception.
|
||||
* @throws IllegalStateException if {@code field} exists, but was not indexed with vectors.
|
||||
* @throws IllegalStateException if {@code field} has vectors, but using a different encoding
|
||||
* @lucene.internal
|
||||
* @lucene.experimental
|
||||
*/
|
||||
static void checkField(LeafReader in, String field, VectorEncoding expectedEncoding) {
|
||||
FieldInfo fi = in.getFieldInfos().fieldInfo(field);
|
||||
if (fi != null) { // a field with no FieldInfo in this reader is not validated at all
|
||||
final VectorEncoding actual = fi.hasVectorValues() ? fi.getVectorEncoding() : null; // null: field exists but has no vector values
|
||||
if (expectedEncoding != actual) { // for a non-null expectedEncoding this also fires when actual is null
|
||||
throw new IllegalStateException(
|
||||
"Unexpected vector encoding ("
|
||||
+ actual
|
||||
+ ") for field "
|
||||
+ field
|
||||
+ "(expected="
|
||||
+ expectedEncoding
|
||||
+ ")");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,6 +78,10 @@ public class TestKnnVectorSimilarityFunctions extends LuceneTestCase {
|
|||
document.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3}));
|
||||
iw.addDocument(document);
|
||||
|
||||
if (usually(random())) {
|
||||
iw.commit();
|
||||
}
|
||||
|
||||
Document document2 = new Document();
|
||||
document2.add(new StringField("id", "2", Field.Store.NO));
|
||||
document2.add(new SortedDocValuesField("id", new BytesRef("2")));
|
||||
|
@ -232,7 +236,7 @@ public class TestKnnVectorSimilarityFunctions extends LuceneTestCase {
|
|||
new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2);
|
||||
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
IllegalStateException.class,
|
||||
() -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10));
|
||||
|
||||
v1 = new FloatKnnVectorFieldSource("knnByteField1");
|
||||
|
@ -241,8 +245,16 @@ public class TestKnnVectorSimilarityFunctions extends LuceneTestCase {
|
|||
new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2);
|
||||
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
IllegalStateException.class,
|
||||
() -> searcher.search(new FunctionQuery(floatVectorSimilarityFunction), 10));
|
||||
|
||||
v1 = new FloatKnnVectorFieldSource("id");
|
||||
FloatVectorSimilarityFunction idVectorSimilarityFunction =
|
||||
new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2);
|
||||
|
||||
assertThrows(
|
||||
IllegalStateException.class,
|
||||
() -> searcher.search(new FunctionQuery(idVectorSimilarityFunction), 10));
|
||||
}
|
||||
|
||||
private static void assertHits(Query q, float[] scores) throws Exception {
|
||||
|
|
Loading…
Reference in New Issue