Unify how missing field entries are handled in knn formats (#13641)

During segment merge we must verify that a given field has vectors and exists. The typical knn format checks assume the per-field format is used and thus only check for `null`. 

But we should check for field existence in the field info and verify it has dense vectors

Additionally, this commit unifies how the knn formats work: they will now throw if a non-existent field is queried, except for the PerField format, which will return null (like the other per-field formats).
This commit is contained in:
Benjamin Trent 2024-09-11 16:46:38 -04:00 committed by GitHub
parent 8760654faa
commit 74e3c44063
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 205 additions and 21 deletions

View File

@ -452,6 +452,9 @@ Bug Fixes
* GITHUB#13703: Fix bug in LatLonPoint queries where narrow polygons close to latitude 90 don't
match any points due to an Integer overflow. (Ignacio Vera)
* GITHUB#13641: Unify how KnnFormats handle missing fields and correctly handle missing vector fields when
merging segments. (Ben Trent)
Build
---------------------

View File

@ -224,6 +224,9 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
  // Look up the per-field metadata; unknown fields are rejected explicitly
  // rather than silently returning null.
  final FieldEntry entry = fields.get(field);
  if (entry != null) {
    return getOffHeapVectorValues(entry);
  }
  throw new IllegalArgumentException("field=\"" + field + "\" not found");
}

View File

@ -218,6 +218,9 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
  final FieldEntry entry = fields.get(field);
  // Guard clause: reject queries for fields this reader never saw.
  if (entry == null) {
    throw new IllegalArgumentException("field=\"" + field + "\" not found");
  }
  return getOffHeapVectorValues(entry);
}

View File

@ -215,6 +215,9 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
  // Resolve the field's on-disk metadata; a missing entry means the field was
  // never written by this format, which callers must not silently ignore.
  final FieldEntry entry = fields.get(field);
  if (entry != null) {
    return OffHeapFloatVectorValues.load(entry, vectorData);
  }
  throw new IllegalArgumentException("field=\"" + field + "\" not found");
}

View File

@ -233,6 +233,9 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"field=\""
@ -248,6 +251,9 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""

View File

@ -241,6 +241,9 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"field=\""
@ -264,6 +267,9 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""

View File

@ -78,4 +78,9 @@ public class TestLucene90HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
public void testEmptyByteVectorData() {
  // unimplemented — NOTE(review): presumably this legacy codec cannot index
  // byte vectors, so the base-class byte-vector test does not apply; confirm
}
@Override
public void testMergingWithDifferentByteKnnFields() {
  // unimplemented — NOTE(review): presumably this legacy codec cannot index
  // byte vectors, so the byte-vector merge test is disabled here; confirm
}
}

View File

@ -77,4 +77,9 @@ public class TestLucene91HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
public void testEmptyByteVectorData() {
  // unimplemented — NOTE(review): presumably this legacy codec cannot index
  // byte vectors, so the base-class byte-vector test does not apply; confirm
}
@Override
public void testMergingWithDifferentByteKnnFields() {
  // unimplemented — NOTE(review): presumably this legacy codec cannot index
  // byte vectors, so the byte-vector merge test is disabled here; confirm
}
}

View File

@ -67,4 +67,9 @@ public class TestLucene92HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
public void testEmptyByteVectorData() {
  // unimplemented — NOTE(review): presumably this legacy codec cannot index
  // byte vectors, so the base-class byte-vector test does not apply; confirm
}
@Override
public void testMergingWithDifferentByteKnnFields() {
  // unimplemented — NOTE(review): presumably this legacy codec cannot index
  // byte vectors, so the byte-vector merge test is disabled here; confirm
}
}

View File

@ -18,6 +18,7 @@
package org.apache.lucene.backward_codecs.lucene95;
import static org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
@ -476,8 +477,10 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
IncrementalHnswGraphMerger merger =
new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
merger.addReader(
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) {
merger.addReader(
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
}
}
DocIdSetIterator mergedVectorIterator = null;
switch (fieldInfo.getVectorEncoding()) {

View File

@ -28,6 +28,7 @@ import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
@ -212,14 +213,35 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
}
}
/**
 * Checks whether the given segment's field infos contain vector values for {@code fieldName}.
 *
 * @param fieldInfos fieldInfos for the segment
 * @param fieldName field name
 * @return true if the field exists in this segment and has vector values
 */
public static boolean hasVectorValues(FieldInfos fieldInfos, String fieldName) {
  // Cheap segment-wide check first: skip the per-field lookup entirely when
  // no field in this segment has vectors at all.
  if (fieldInfos.hasVectorValues()) {
    final FieldInfo info = fieldInfos.fieldInfo(fieldName);
    return info != null && info.hasVectorValues();
  }
  return false;
}
private static <V, S> List<S> mergeVectorValues(
KnnVectorsReader[] knnVectorsReaders,
MergeState.DocMap[] docMaps,
FieldInfo mergingField,
FieldInfos[] sourceFieldInfos,
IOFunction<KnnVectorsReader, V> valuesSupplier,
BiFunction<MergeState.DocMap, V, S> newSub)
throws IOException {
List<S> subs = new ArrayList<>();
for (int i = 0; i < knnVectorsReaders.length; i++) {
FieldInfos sourceFieldInfo = sourceFieldInfos[i];
if (hasVectorValues(sourceFieldInfo, mergingField.name) == false) {
continue;
}
KnnVectorsReader knnVectorsReader = knnVectorsReaders[i];
if (knnVectorsReader != null) {
V values = valuesSupplier.apply(knnVectorsReader);
@ -239,12 +261,10 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
mergeVectorValues(
mergeState.knnVectorsReaders,
mergeState.docMaps,
knnVectorsReader -> {
return knnVectorsReader.getFloatVectorValues(fieldInfo.name);
},
(docMap, values) -> {
return new FloatVectorValuesSub(docMap, values);
}),
fieldInfo,
mergeState.fieldInfos,
knnVectorsReader -> knnVectorsReader.getFloatVectorValues(fieldInfo.name),
FloatVectorValuesSub::new),
mergeState);
}
@ -256,12 +276,10 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
mergeVectorValues(
mergeState.knnVectorsReaders,
mergeState.docMaps,
knnVectorsReader -> {
return knnVectorsReader.getByteVectorValues(fieldInfo.name);
},
(docMap, values) -> {
return new ByteVectorValuesSub(docMap, values);
}),
fieldInfo,
mergeState.fieldInfos,
knnVectorsReader -> knnVectorsReader.getByteVectorValues(fieldInfo.name),
ByteVectorValuesSub::new),
mergeState);
}

View File

@ -174,6 +174,9 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"field=\""
@ -197,6 +200,9 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""

View File

@ -17,6 +17,7 @@
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
@ -353,8 +354,10 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
: new TaskExecutor(mergeState.intraMergeTaskExecutor),
numMergeWorkers);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
merger.addReader(
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) {
merger.addReader(
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
}
}
DocIdSetIterator mergedVectorIterator = null;
switch (fieldInfo.getVectorEncoding()) {

View File

@ -165,8 +165,17 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
}
final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field);
OffHeapQuantizedByteVectorValues quantizedByteVectorValues =

View File

@ -17,6 +17,7 @@
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues;
import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
@ -630,7 +631,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
IntArrayList segmentSizes = new IntArrayList(mergeState.liveDocs.length);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
FloatVectorValues fvv;
if (mergeState.knnVectorsReaders[i] != null
if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)
&& (fvv = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name)) != null
&& fvv.size() > 0) {
ScalarQuantizer quantizationState =
@ -928,8 +929,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
List<QuantizedByteVectorValueSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
if (mergeState.knnVectorsReaders[i] != null
&& mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name) != null) {
if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) {
QuantizedVectorsReader reader =
getQuantizedKnnVectorsReader(mergeState.knnVectorsReaders[i], fieldInfo.name);
assert scalarQuantizer != null;

View File

@ -27,11 +27,14 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
@ -53,6 +56,9 @@ import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MergePolicy;
import org.apache.lucene.index.MergeScheduler;
import org.apache.lucene.index.MergeTrigger;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.SegmentWriteState;
@ -230,6 +236,106 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
public void testMergingWithDifferentKnnFields() throws Exception {
  // Two committed segments, each with float vectors in a *different* field;
  // force-merging them must not fail on the field absent from one segment.
  try (var dir = newDirectory()) {
    var config = new IndexWriterConfig();
    Codec codec = getCodec();
    if (codec.knnVectorsFormat() instanceof PerFieldKnnVectorsFormat perField) {
      // Pin the concrete per-field format so the non-per-field code path is exercised too.
      final KnnVectorsFormat format = perField.getKnnVectorsFormatForField("field");
      config.setCodec(
          new FilterCodec(codec.getName(), codec) {
            @Override
            public KnnVectorsFormat knnVectorsFormat() {
              return format;
            }
          });
    }
    var scheduler = new TestMergeScheduler();
    config.setMergeScheduler(scheduler);
    config.setMergePolicy(new ForceMergePolicy(config.getMergePolicy()));
    try (var writer = new IndexWriter(dir, config)) {
      for (int i = 0; i < 10; i++) {
        var doc = new Document();
        doc.add(new KnnFloatVectorField("field", new float[] {i, i + 1, i + 2, i + 3}));
        writer.addDocument(doc);
      }
      writer.commit();
      for (int i = 0; i < 10; i++) {
        var doc = new Document();
        doc.add(new KnnFloatVectorField("otherVector", new float[] {i, i, i, i}));
        writer.addDocument(doc);
      }
      writer.commit();
      writer.forceMerge(1);
      // Any merge-time failure is captured by the scheduler instead of thrown here.
      assertNull(scheduler.ex.get());
    }
  }
}
public void testMergingWithDifferentByteKnnFields() throws Exception {
  // Same scenario as the float variant, but with byte vectors: segments whose
  // byte-vector fields differ must force-merge without error.
  try (var dir = newDirectory()) {
    var config = new IndexWriterConfig();
    Codec codec = getCodec();
    if (codec.knnVectorsFormat() instanceof PerFieldKnnVectorsFormat perField) {
      // Pin the concrete per-field format so the non-per-field code path is exercised too.
      final KnnVectorsFormat format = perField.getKnnVectorsFormatForField("field");
      config.setCodec(
          new FilterCodec(codec.getName(), codec) {
            @Override
            public KnnVectorsFormat knnVectorsFormat() {
              return format;
            }
          });
    }
    var scheduler = new TestMergeScheduler();
    config.setMergeScheduler(scheduler);
    config.setMergePolicy(new ForceMergePolicy(config.getMergePolicy()));
    try (var writer = new IndexWriter(dir, config)) {
      for (int i = 0; i < 10; i++) {
        var doc = new Document();
        doc.add(
            new KnnByteVectorField("field", new byte[] {(byte) i, (byte) i, (byte) i, (byte) i}));
        writer.addDocument(doc);
      }
      writer.commit();
      for (int i = 0; i < 10; i++) {
        var doc = new Document();
        doc.add(
            new KnnByteVectorField(
                "otherVector", new byte[] {(byte) i, (byte) i, (byte) i, (byte) i}));
        writer.addDocument(doc);
      }
      writer.commit();
      writer.forceMerge(1);
      // Any merge-time failure is captured by the scheduler instead of thrown here.
      assertNull(scheduler.ex.get());
    }
  }
}
/**
 * Merge scheduler that runs merges inline and records the first {@link IllegalStateException}
 * or {@link IllegalArgumentException} instead of letting it propagate, so tests can assert
 * that merging completed cleanly.
 */
private static final class TestMergeScheduler extends MergeScheduler {
  // First captured merge failure, or null if all merges succeeded.
  AtomicReference<Exception> ex = new AtomicReference<>();

  @Override
  public void merge(MergeSource mergeSource, MergeTrigger trigger) throws IOException {
    // Drain pending merges one at a time; stop at the first captured failure.
    for (MergePolicy.OneMerge merge = mergeSource.getNextMerge();
        merge != null;
        merge = mergeSource.getNextMerge()) {
      try {
        mergeSource.merge(merge);
      } catch (IllegalStateException | IllegalArgumentException e) {
        ex.set(e);
        return;
      }
    }
  }

  @Override
  public void close() {}
}
@SuppressWarnings("unchecked")
public void testWriterRamEstimate() throws Exception {
final FieldInfos fieldInfos = new FieldInfos(new FieldInfo[0]);