diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 9e84b7a636b..ea2c6648259 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -50,6 +50,9 @@ API Changes org.apache.lucene.* to org.apache.lucene.tests.* to avoid package name conflicts with the core module. (Dawid Weiss) +* LUCENE-10183: KnnVectorsWriter#writeField to take KnnVectorsReader instead of VectorValues. + (Zach Chen, Michael Sokolov, Julie Tibshirani, Adrien Grand) + * LUCENE-10335: Deprecate helper methods for resource loading in IOUtils and StopwordAnalyzerBase that are not compatible with module system (Class#getResourceAsStream() and Class#getResource() are caller sensitive in Java 11). Instead add utility method IOUtils#requireResourceNonNull(T) diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java index 270e9db54cb..8b527e0f27b 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexFileNames; @@ -74,7 +75,9 @@ public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter { } @Override - public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException { + public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader) + throws IOException { + VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name); long vectorDataOffset = vectorData.getFilePointer(); List docIds = new ArrayList<>(); int docV; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index cd104c41961..4afa933112c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -31,6 +31,8 @@ import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorValues; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; /** Writes vectors to an index. */ @@ -40,7 +42,8 @@ public abstract class KnnVectorsWriter implements Closeable { protected KnnVectorsWriter() {} /** Write all values contained in the provided reader */ - public abstract void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException; + public abstract void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader) + throws IOException; /** Called once at the end before close */ public abstract void finish() throws IOException; @@ -67,47 +70,77 @@ public abstract class KnnVectorsWriter implements Closeable { if (mergeState.infoStream.isEnabled("VV")) { mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo); } - List subs = new ArrayList<>(); - int dimension = -1; - VectorSimilarityFunction similarityFunction = null; - int nonEmptySegmentIndex = 0; - for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { - KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; - if (knnVectorsReader != null) { - if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) { - int segmentDimension = mergeFieldInfo.getVectorDimension(); - VectorSimilarityFunction segmentSimilarityFunction = - mergeFieldInfo.getVectorSimilarityFunction(); - if (dimension == -1) { - dimension = segmentDimension; - similarityFunction = mergeFieldInfo.getVectorSimilarityFunction(); - } else if (dimension != segmentDimension) { - throw new IllegalStateException( - "Varying dimensions for vector-valued field " - + mergeFieldInfo.name - + ": " - + dimension - + "!=" - + segmentDimension); - } else if (similarityFunction != segmentSimilarityFunction) { - throw new IllegalStateException( - "Varying similarity functions for vector-valued field " - + mergeFieldInfo.name - + ": " - + similarityFunction - + "!=" - + segmentSimilarityFunction); - } - VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name); - if (values != null) { - subs.add(new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values)); - } - } - } - } // Create a new VectorValues by iterating over the sub vectors, mapping the resulting // docids using docMaps in the mergeState. - writeField(mergeFieldInfo, new VectorValuesMerger(subs, mergeState)); + writeField( + mergeFieldInfo, + new KnnVectorsReader() { + @Override + public long ramBytesUsed() { + return 0; + } + + @Override + public void close() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void checkIntegrity() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public VectorValues getVectorValues(String field) throws IOException { + List subs = new ArrayList<>(); + int dimension = -1; + VectorSimilarityFunction similarityFunction = null; + int nonEmptySegmentIndex = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader != null) { + if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) { + int segmentDimension = mergeFieldInfo.getVectorDimension(); + VectorSimilarityFunction segmentSimilarityFunction = + mergeFieldInfo.getVectorSimilarityFunction(); + if (dimension == -1) { + dimension = segmentDimension; + similarityFunction = mergeFieldInfo.getVectorSimilarityFunction(); + } else if (dimension != segmentDimension) { + throw new IllegalStateException( + "Varying dimensions for vector-valued field " + + mergeFieldInfo.name + + ": " + + dimension + + "!=" + + segmentDimension); + } else if (similarityFunction != segmentSimilarityFunction) { + throw new IllegalStateException( + "Varying similarity functions for vector-valued field " + + mergeFieldInfo.name + + ": " + + similarityFunction + + "!=" + + segmentSimilarityFunction); + } + VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name); + if (values != null) { + subs.add( + new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values)); + } + } + } + } + return new VectorValuesMerger(subs, mergeState); + } + + @Override + public TopDocs search(String field, float[] target, int k, Bits acceptDocs) + throws IOException { + throw new UnsupportedOperationException(); + } + }); + if (mergeState.infoStream.isEnabled("VV")) { mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java index 0c2832bf5cf..f5124078e28 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java @@ -22,6 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; import java.util.Arrays; import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexFileNames; @@ -107,7 +108,9 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter { } @Override - public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException { + public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader) + throws IOException { + VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name); long pos = vectorData.getFilePointer(); // write floats aligned at 4 bytes. This will not survive CFS, but it shows a small benefit when // CFS is not used, eg for larger indexes diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java index 1ec03dac70d..ee2f9313bf7 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java @@ -98,8 +98,9 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat { } @Override - public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException { - getInstance(fieldInfo).writeField(fieldInfo, values); + public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader) + throws IOException { + getInstance(fieldInfo).writeField(fieldInfo, knnVectorsReader); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java b/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java index 4b403a34af1..673f39a3404 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java @@ -22,9 +22,12 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; import java.util.List; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.Counter; import org.apache.lucene.util.RamUsageEstimator; @@ -107,13 +110,38 @@ class VectorValuesWriter { * @throws IOException if there is an error writing the field and its values */ public void flush(Sorter.DocMap sortMap, KnnVectorsWriter knnVectorsWriter) throws IOException { - VectorValues vectorValues = - new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension()); - if (sortMap != null) { - knnVectorsWriter.writeField(fieldInfo, new SortingVectorValues(vectorValues, sortMap)); - } else { - knnVectorsWriter.writeField(fieldInfo, vectorValues); - } + KnnVectorsReader knnVectorsReader = + new KnnVectorsReader() { + @Override + public long ramBytesUsed() { + return 0; + } + + @Override + public void close() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void checkIntegrity() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public VectorValues getVectorValues(String field) throws IOException { + VectorValues vectorValues = + new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension()); + return sortMap != null ? new SortingVectorValues(vectorValues, sortMap) : vectorValues; + } + + @Override + public TopDocs search(String field, float[] target, int k, Bits acceptDocs) + throws IOException { + throw new UnsupportedOperationException(); + } + }; + + knnVectorsWriter.writeField(fieldInfo, knnVectorsReader); } static class SortingVectorValues extends VectorValues diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java index 1181247ae1a..8584cc3d15a 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java @@ -39,7 +39,6 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.VectorValues; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; @@ -172,9 +171,10 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase { KnnVectorsWriter writer = delegate.fieldsWriter(state); return new KnnVectorsWriter() { @Override - public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException { + public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader) + throws IOException { fieldsWritten.add(fieldInfo.name); - writer.writeField(fieldInfo, values); + writer.writeField(fieldInfo, knnVectorsReader); } @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java index 55b17d5bbbd..a38b19ce28f 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java @@ -58,10 +58,15 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat { } @Override - public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException { + public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader) + throws IOException { assert fieldInfo != null; - assert values != null; - delegate.writeField(fieldInfo, values); + assert knnVectorsReader != null; + // assert that knnVectorsReader#getVectorValues returns different instances upon repeated + // calls + assert knnVectorsReader.getVectorValues(fieldInfo.name) + != knnVectorsReader.getVectorValues(fieldInfo.name); + delegate.writeField(fieldInfo, knnVectorsReader); } @Override