diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 193d697dff0..010702cf44d 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -452,6 +452,9 @@ Bug Fixes * GITHUB#13703: Fix bug in LatLonPoint queries where narrow polygons close to latitude 90 don't match any points due to an Integer overflow. (Ignacio Vera) +* GITHUB#13641: Unify how KnnFormats handle missing fields and correctly handle missing vector fields when + merging segments. (Ben Trent) + Build --------------------- diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index ab2486f4518..665d3140321 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -224,6 +224,9 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } return getOffHeapVectorValues(fieldEntry); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 048280466d4..81f8d97a9a0 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -218,6 +218,9 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } return getOffHeapVectorValues(fieldEntry); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java index 833efdf8025..39fe109a9f1 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java @@ -215,6 +215,9 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader { @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } return OffHeapFloatVectorValues.load(fieldEntry, vectorData); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index a948ab7bee3..d5beae1e681 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -233,6 +233,9 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader { @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { throw new IllegalArgumentException( "field=\"" @@ -248,6 +251,9 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader { @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) { throw new IllegalArgumentException( "field=\"" diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java index 1b74ff94c18..2e6714d6eb8 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java @@ -241,6 +241,9 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { throw new IllegalArgumentException( "field=\"" @@ -264,6 +267,9 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) { throw new IllegalArgumentException( "field=\"" diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java index 720e1f56468..2c689d5c0e5 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java @@ -78,4 +78,9 @@ public class TestLucene90HnswVectorsFormat extends BaseKnnVectorsFormatTestCase public void testEmptyByteVectorData() { // unimplemented } + + @Override + public void testMergingWithDifferentByteKnnFields() { + // unimplemented + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java index 09b5a50b4bc..df79316db0a 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java @@ -77,4 +77,9 @@ public class TestLucene91HnswVectorsFormat extends BaseKnnVectorsFormatTestCase public void testEmptyByteVectorData() { // unimplemented } + + @Override + public void testMergingWithDifferentByteKnnFields() { + // unimplemented + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java index 5189791ef17..0e003dafc3b 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java @@ -67,4 +67,9 @@ public class TestLucene92HnswVectorsFormat extends BaseKnnVectorsFormatTestCase public void testEmptyByteVectorData() { // unimplemented } + + @Override + public void testMergingWithDifferentByteKnnFields() { + // unimplemented + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java index 6b0075d45df..37c39d311d6 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -18,6 +18,7 @@ package org.apache.lucene.backward_codecs.lucene95; import static org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; +import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; @@ -476,8 +477,10 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { IncrementalHnswGraphMerger merger = new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth); for (int i = 0; i < mergeState.liveDocs.length; i++) { - merger.addReader( - mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]); + if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) { + merger.addReader( + mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]); + } } DocIdSetIterator mergedVectorIterator = null; switch (fieldInfo.getVectorEncoding()) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 28f9995b11e..3b185fd13a0 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -28,6 +28,7 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; @@ -212,14 +213,35 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable { } } + /** + * Returns true if the fieldInfos has vector values for the field. + * + * @param fieldInfos fieldInfos for the segment + * @param fieldName field name + * @return true if the fieldInfos has vector values for the field. + */ + public static boolean hasVectorValues(FieldInfos fieldInfos, String fieldName) { + if (fieldInfos.hasVectorValues() == false) { + return false; + } + FieldInfo info = fieldInfos.fieldInfo(fieldName); + return info != null && info.hasVectorValues(); + } + private static List mergeVectorValues( KnnVectorsReader[] knnVectorsReaders, MergeState.DocMap[] docMaps, + FieldInfo mergingField, + FieldInfos[] sourceFieldInfos, IOFunction valuesSupplier, BiFunction newSub) throws IOException { List subs = new ArrayList<>(); for (int i = 0; i < knnVectorsReaders.length; i++) { + FieldInfos sourceFieldInfo = sourceFieldInfos[i]; + if (hasVectorValues(sourceFieldInfo, mergingField.name) == false) { + continue; + } KnnVectorsReader knnVectorsReader = knnVectorsReaders[i]; if (knnVectorsReader != null) { V values = valuesSupplier.apply(knnVectorsReader); @@ -239,12 +261,10 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable { mergeVectorValues( mergeState.knnVectorsReaders, mergeState.docMaps, - knnVectorsReader -> { - return knnVectorsReader.getFloatVectorValues(fieldInfo.name); - }, - (docMap, values) -> { - return new FloatVectorValuesSub(docMap, values); - }), + fieldInfo, + mergeState.fieldInfos, + knnVectorsReader -> knnVectorsReader.getFloatVectorValues(fieldInfo.name), + FloatVectorValuesSub::new), mergeState); } @@ -256,12 +276,10 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable { mergeVectorValues( mergeState.knnVectorsReaders, mergeState.docMaps, - knnVectorsReader -> { - return knnVectorsReader.getByteVectorValues(fieldInfo.name); - }, - (docMap, values) -> { - return new ByteVectorValuesSub(docMap, values); - }), + fieldInfo, + mergeState.fieldInfos, + knnVectorsReader -> knnVectorsReader.getByteVectorValues(fieldInfo.name), + ByteVectorValuesSub::new), mergeState); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java index 0613c9c82b8..b334298cb8f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java @@ -174,6 +174,9 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { throw new IllegalArgumentException( "field=\"" @@ -197,6 +200,9 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) { throw new IllegalArgumentException( "field=\"" diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index a5cfe0943db..dc0fb7184c7 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -17,6 +17,7 @@ package org.apache.lucene.codecs.lucene99; +import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; @@ -353,8 +354,10 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { : new TaskExecutor(mergeState.intraMergeTaskExecutor), numMergeWorkers); for (int i = 0; i < mergeState.liveDocs.length; i++) { - merger.addReader( - mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]); + if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) { + merger.addReader( + mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]); + } } DocIdSetIterator mergedVectorIterator = null; switch (fieldInfo.getVectorEncoding()) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index c1a5c706549..b8188a43bae 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -165,8 +165,17 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { - return null; + if (fieldEntry == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } + if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fieldEntry.vectorEncoding + + " expected: " + + VectorEncoding.FLOAT32); } final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field); OffHeapQuantizedByteVectorValues quantizedByteVectorValues = diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index e477fec75e5..bb333ad45c2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -17,6 +17,7 @@ package org.apache.lucene.codecs.lucene99; +import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues; import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL; import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; @@ -630,7 +631,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite IntArrayList segmentSizes = new IntArrayList(mergeState.liveDocs.length); for (int i = 0; i < mergeState.liveDocs.length; i++) { FloatVectorValues fvv; - if (mergeState.knnVectorsReaders[i] != null + if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name) && (fvv = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name)) != null && fvv.size() > 0) { ScalarQuantizer quantizationState = @@ -928,8 +929,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite List subs = new ArrayList<>(); for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { - if (mergeState.knnVectorsReaders[i] != null - && mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name) != null) { + if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) { QuantizedVectorsReader reader = getQuantizedKnnVectorsReader(mergeState.knnVectorsReaders[i], fieldInfo.name); assert scalarQuantizer != null; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 8b7f538de00..ed4abb6f2c6 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -27,11 +27,14 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnByteVectorField; @@ -53,6 +56,9 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.MergePolicy; +import org.apache.lucene.index.MergeScheduler; +import org.apache.lucene.index.MergeTrigger; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentWriteState; @@ -230,6 +236,106 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe } } + public void testMergingWithDifferentKnnFields() throws Exception { + try (var dir = newDirectory()) { + IndexWriterConfig iwc = new IndexWriterConfig(); + Codec codec = getCodec(); + if (codec.knnVectorsFormat() instanceof PerFieldKnnVectorsFormat perFieldKnnVectorsFormat) { + final KnnVectorsFormat format = + perFieldKnnVectorsFormat.getKnnVectorsFormatForField("field"); + iwc.setCodec( + new FilterCodec(codec.getName(), codec) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return format; + } + }); + } + TestMergeScheduler mergeScheduler = new TestMergeScheduler(); + iwc.setMergeScheduler(mergeScheduler); + iwc.setMergePolicy(new ForceMergePolicy(iwc.getMergePolicy())); + try (var writer = new IndexWriter(dir, iwc)) { + for (int i = 0; i < 10; i++) { + var doc = new Document(); + doc.add(new KnnFloatVectorField("field", new float[] {i, i + 1, i + 2, i + 3})); + writer.addDocument(doc); + } + writer.commit(); + for (int i = 0; i < 10; i++) { + var doc = new Document(); + doc.add(new KnnFloatVectorField("otherVector", new float[] {i, i, i, i})); + writer.addDocument(doc); + } + writer.commit(); + writer.forceMerge(1); + assertNull(mergeScheduler.ex.get()); + } + } + } + + public void testMergingWithDifferentByteKnnFields() throws Exception { + try (var dir = newDirectory()) { + IndexWriterConfig iwc = new IndexWriterConfig(); + Codec codec = getCodec(); + if (codec.knnVectorsFormat() instanceof PerFieldKnnVectorsFormat perFieldKnnVectorsFormat) { + final KnnVectorsFormat format = + perFieldKnnVectorsFormat.getKnnVectorsFormatForField("field"); + iwc.setCodec( + new FilterCodec(codec.getName(), codec) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return format; + } + }); + } + TestMergeScheduler mergeScheduler = new TestMergeScheduler(); + iwc.setMergeScheduler(mergeScheduler); + iwc.setMergePolicy(new ForceMergePolicy(iwc.getMergePolicy())); + try (var writer = new IndexWriter(dir, iwc)) { + for (int i = 0; i < 10; i++) { + var doc = new Document(); + doc.add( + new KnnByteVectorField("field", new byte[] {(byte) i, (byte) i, (byte) i, (byte) i})); + writer.addDocument(doc); + } + writer.commit(); + for (int i = 0; i < 10; i++) { + var doc = new Document(); + doc.add( + new KnnByteVectorField( + "otherVector", new byte[] {(byte) i, (byte) i, (byte) i, (byte) i})); + writer.addDocument(doc); + } + writer.commit(); + writer.forceMerge(1); + assertNull(mergeScheduler.ex.get()); + } + } + } + + private static final class TestMergeScheduler extends MergeScheduler { + AtomicReference ex = new AtomicReference<>(); + + @Override + public void merge(MergeSource mergeSource, MergeTrigger trigger) throws IOException { + while (true) { + MergePolicy.OneMerge merge = mergeSource.getNextMerge(); + if (merge == null) { + break; + } + try { + mergeSource.merge(merge); + } catch (IllegalStateException | IllegalArgumentException e) { + ex.set(e); + break; + } + } + } + + @Override + public void close() {} + } + @SuppressWarnings("unchecked") public void testWriterRamEstimate() throws Exception { final FieldInfos fieldInfos = new FieldInfos(new FieldInfo[0]);