mirror of https://github.com/apache/lucene.git
LUCENE-10183: KnnVectorsWriter#writeField to take KnnVectorsReader instead of VectorValues (#534)
This commit is contained in:
parent
f2e00bb9e0
commit
d0ad9f5bfc
|
@ -50,6 +50,9 @@ API Changes
|
|||
org.apache.lucene.* to org.apache.lucene.tests.* to avoid package name conflicts with the
|
||||
core module. (Dawid Weiss)
|
||||
|
||||
* LUCENE-10183: KnnVectorsWriter#writeField to take KnnVectorsReader instead of VectorValues.
|
||||
(Zach Chen, Michael Sokolov, Julie Tibshirani, Adrien Grand)
|
||||
|
||||
* LUCENE-10335: Deprecate helper methods for resource loading in IOUtils and StopwordAnalyzerBase
|
||||
that are not compatible with the module system (Class#getResourceAsStream() and Class#getResource()
|
||||
are caller sensitive in Java 11). Instead add utility method IOUtils#requireResourceNonNull(T)
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.io.IOException;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
|
@ -74,7 +75,9 @@ public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException {
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
|
||||
long vectorDataOffset = vectorData.getFilePointer();
|
||||
List<Integer> docIds = new ArrayList<>();
|
||||
int docV;
|
||||
|
|
|
@ -31,6 +31,8 @@ import org.apache.lucene.index.RandomAccessVectorValues;
|
|||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/** Writes vectors to an index. */
|
||||
|
@ -40,7 +42,8 @@ public abstract class KnnVectorsWriter implements Closeable {
|
|||
protected KnnVectorsWriter() {}
|
||||
|
||||
/** Write all values contained in the provided reader */
|
||||
public abstract void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException;
|
||||
public abstract void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException;
|
||||
|
||||
/** Called once at the end before close */
|
||||
public abstract void finish() throws IOException;
|
||||
|
@ -67,47 +70,77 @@ public abstract class KnnVectorsWriter implements Closeable {
|
|||
if (mergeState.infoStream.isEnabled("VV")) {
|
||||
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
|
||||
}
|
||||
List<VectorValuesSub> subs = new ArrayList<>();
|
||||
int dimension = -1;
|
||||
VectorSimilarityFunction similarityFunction = null;
|
||||
int nonEmptySegmentIndex = 0;
|
||||
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
||||
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
|
||||
if (knnVectorsReader != null) {
|
||||
if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) {
|
||||
int segmentDimension = mergeFieldInfo.getVectorDimension();
|
||||
VectorSimilarityFunction segmentSimilarityFunction =
|
||||
mergeFieldInfo.getVectorSimilarityFunction();
|
||||
if (dimension == -1) {
|
||||
dimension = segmentDimension;
|
||||
similarityFunction = mergeFieldInfo.getVectorSimilarityFunction();
|
||||
} else if (dimension != segmentDimension) {
|
||||
throw new IllegalStateException(
|
||||
"Varying dimensions for vector-valued field "
|
||||
+ mergeFieldInfo.name
|
||||
+ ": "
|
||||
+ dimension
|
||||
+ "!="
|
||||
+ segmentDimension);
|
||||
} else if (similarityFunction != segmentSimilarityFunction) {
|
||||
throw new IllegalStateException(
|
||||
"Varying similarity functions for vector-valued field "
|
||||
+ mergeFieldInfo.name
|
||||
+ ": "
|
||||
+ similarityFunction
|
||||
+ "!="
|
||||
+ segmentSimilarityFunction);
|
||||
}
|
||||
VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name);
|
||||
if (values != null) {
|
||||
subs.add(new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Create a new VectorValues by iterating over the sub vectors, mapping the resulting
|
||||
// docids using docMaps in the mergeState.
|
||||
writeField(mergeFieldInfo, new VectorValuesMerger(subs, mergeState));
|
||||
writeField(
|
||||
mergeFieldInfo,
|
||||
new KnnVectorsReader() {
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
List<VectorValuesSub> subs = new ArrayList<>();
|
||||
int dimension = -1;
|
||||
VectorSimilarityFunction similarityFunction = null;
|
||||
int nonEmptySegmentIndex = 0;
|
||||
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
||||
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
|
||||
if (knnVectorsReader != null) {
|
||||
if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) {
|
||||
int segmentDimension = mergeFieldInfo.getVectorDimension();
|
||||
VectorSimilarityFunction segmentSimilarityFunction =
|
||||
mergeFieldInfo.getVectorSimilarityFunction();
|
||||
if (dimension == -1) {
|
||||
dimension = segmentDimension;
|
||||
similarityFunction = mergeFieldInfo.getVectorSimilarityFunction();
|
||||
} else if (dimension != segmentDimension) {
|
||||
throw new IllegalStateException(
|
||||
"Varying dimensions for vector-valued field "
|
||||
+ mergeFieldInfo.name
|
||||
+ ": "
|
||||
+ dimension
|
||||
+ "!="
|
||||
+ segmentDimension);
|
||||
} else if (similarityFunction != segmentSimilarityFunction) {
|
||||
throw new IllegalStateException(
|
||||
"Varying similarity functions for vector-valued field "
|
||||
+ mergeFieldInfo.name
|
||||
+ ": "
|
||||
+ similarityFunction
|
||||
+ "!="
|
||||
+ segmentSimilarityFunction);
|
||||
}
|
||||
VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name);
|
||||
if (values != null) {
|
||||
subs.add(
|
||||
new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return new VectorValuesMerger(subs, mergeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
|
||||
throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
});
|
||||
|
||||
if (mergeState.infoStream.isEnabled("VV")) {
|
||||
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
|
@ -107,7 +108,9 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException {
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
|
||||
long pos = vectorData.getFilePointer();
|
||||
// write floats aligned at 4 bytes. This will not survive CFS, but it shows a small benefit when
|
||||
// CFS is not used, e.g. for larger indexes
|
||||
|
|
|
@ -98,8 +98,9 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
|
||||
getInstance(fieldInfo).writeField(fieldInfo, values);
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
getInstance(fieldInfo).writeField(fieldInfo, knnVectorsReader);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -22,9 +22,12 @@ import java.nio.ByteBuffer;
|
|||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.Counter;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
@ -107,13 +110,38 @@ class VectorValuesWriter {
|
|||
* @throws IOException if there is an error writing the field and its values
|
||||
*/
|
||||
public void flush(Sorter.DocMap sortMap, KnnVectorsWriter knnVectorsWriter) throws IOException {
|
||||
VectorValues vectorValues =
|
||||
new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
|
||||
if (sortMap != null) {
|
||||
knnVectorsWriter.writeField(fieldInfo, new SortingVectorValues(vectorValues, sortMap));
|
||||
} else {
|
||||
knnVectorsWriter.writeField(fieldInfo, vectorValues);
|
||||
}
|
||||
KnnVectorsReader knnVectorsReader =
|
||||
new KnnVectorsReader() {
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
VectorValues vectorValues =
|
||||
new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
|
||||
return sortMap != null ? new SortingVectorValues(vectorValues, sortMap) : vectorValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
|
||||
throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
|
||||
knnVectorsWriter.writeField(fieldInfo, knnVectorsReader);
|
||||
}
|
||||
|
||||
static class SortingVectorValues extends VectorValues
|
||||
|
|
|
@ -39,7 +39,6 @@ import org.apache.lucene.index.LeafReader;
|
|||
import org.apache.lucene.index.NoMergePolicy;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||
|
@ -172,9 +171,10 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
|||
KnnVectorsWriter writer = delegate.fieldsWriter(state);
|
||||
return new KnnVectorsWriter() {
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
fieldsWritten.add(fieldInfo.name);
|
||||
writer.writeField(fieldInfo, values);
|
||||
writer.writeField(fieldInfo, knnVectorsReader);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -58,10 +58,15 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
assert fieldInfo != null;
|
||||
assert values != null;
|
||||
delegate.writeField(fieldInfo, values);
|
||||
assert knnVectorsReader != null;
|
||||
// assert that knnVectorsReader#getVectorValues returns different instances upon repeated
|
||||
// calls
|
||||
assert knnVectorsReader.getVectorValues(fieldInfo.name)
|
||||
!= knnVectorsReader.getVectorValues(fieldInfo.name);
|
||||
delegate.writeField(fieldInfo, knnVectorsReader);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
Loading…
Reference in New Issue