LUCENE-10183: KnnVectorsWriter#writeField to take KnnVectorsReader instead of VectorValues (#534)

This commit is contained in:
zacharymorn 2022-01-06 22:14:41 -08:00 committed by GitHub
parent f2e00bb9e0
commit d0ad9f5bfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 133 additions and 57 deletions

View File

@ -50,6 +50,9 @@ API Changes
org.apache.lucene.* to org.apache.lucene.tests.* to avoid package name conflicts with the
core module. (Dawid Weiss)
* LUCENE-10183: KnnVectorsWriter#writeField to take KnnVectorsReader instead of VectorValues.
(Zach Chen, Michael Sokolov, Julie Tibshirani, Adrien Grand)
* LUCENE-10335: Deprecate helper methods for resource loading in IOUtils and StopwordAnalyzerBase
that are not compatible with module system (Class#getResourceAsStream() and Class#getResource()
are caller sensitive in Java 11). Instead add utility method IOUtils#requireResourceNonNull(T)

View File

@ -23,6 +23,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
@ -74,7 +75,9 @@ public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter {
}
@Override
public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException {
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
long vectorDataOffset = vectorData.getFilePointer();
List<Integer> docIds = new ArrayList<>();
int docV;

View File

@ -31,6 +31,8 @@ import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/** Writes vectors to an index. */
@ -40,7 +42,8 @@ public abstract class KnnVectorsWriter implements Closeable {
protected KnnVectorsWriter() {}
/** Write all values contained in the provided reader */
public abstract void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException;
public abstract void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException;
/** Called once at the end before close */
public abstract void finish() throws IOException;
@ -67,47 +70,77 @@ public abstract class KnnVectorsWriter implements Closeable {
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
}
List<VectorValuesSub> subs = new ArrayList<>();
int dimension = -1;
VectorSimilarityFunction similarityFunction = null;
int nonEmptySegmentIndex = 0;
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
if (knnVectorsReader != null) {
if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) {
int segmentDimension = mergeFieldInfo.getVectorDimension();
VectorSimilarityFunction segmentSimilarityFunction =
mergeFieldInfo.getVectorSimilarityFunction();
if (dimension == -1) {
dimension = segmentDimension;
similarityFunction = mergeFieldInfo.getVectorSimilarityFunction();
} else if (dimension != segmentDimension) {
throw new IllegalStateException(
"Varying dimensions for vector-valued field "
+ mergeFieldInfo.name
+ ": "
+ dimension
+ "!="
+ segmentDimension);
} else if (similarityFunction != segmentSimilarityFunction) {
throw new IllegalStateException(
"Varying similarity functions for vector-valued field "
+ mergeFieldInfo.name
+ ": "
+ similarityFunction
+ "!="
+ segmentSimilarityFunction);
}
VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name);
if (values != null) {
subs.add(new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values));
}
}
}
}
// Create a new VectorValues by iterating over the sub vectors, mapping the resulting
// docids using docMaps in the mergeState.
writeField(mergeFieldInfo, new VectorValuesMerger(subs, mergeState));
writeField(
mergeFieldInfo,
new KnnVectorsReader() {
/** Always reports zero bytes for this anonymous, merge-time reader. */
@Override
public long ramBytesUsed() {
return 0;
}
/**
 * Not supported by this anonymous, merge-time reader.
 *
 * @throws UnsupportedOperationException on every call
 */
@Override
public void close() throws IOException {
throw new UnsupportedOperationException();
}
/**
 * Not supported by this anonymous, merge-time reader.
 *
 * @throws UnsupportedOperationException on every call
 */
@Override
public void checkIntegrity() throws IOException {
throw new UnsupportedOperationException();
}
/**
 * Builds a merged view over the vector values of the segments taking part in the merge.
 *
 * <p>Every entry of {@code mergeState.knnVectorsReaders} is visited. Null readers, and readers
 * that return no values for {@code mergeFieldInfo.name}, contribute nothing; each remaining
 * segment is wrapped in a {@code VectorValuesSub} together with its doc map, and the collected
 * subs are handed to a new {@code VectorValuesMerger}.
 *
 * <p>NOTE(review): the {@code field} argument is never consulted; the lookup always uses
 * {@code mergeFieldInfo.name}. The dimension and similarity checks also read both the recorded
 * and the per-iteration value from {@code mergeFieldInfo} -- confirm whether per-segment field
 * infos were intended.
 *
 * @param field requested field name (not used by this implementation)
 * @return a newly created {@code VectorValuesMerger} on every call
 * @throws IllegalStateException if the dimension or similarity function seen on a later
 *     iteration differs from the one recorded first
 */
@Override
public VectorValues getVectorValues(String field) throws IOException {
  List<VectorValuesSub> mergedSubs = new ArrayList<>();
  int expectedDimension = -1;
  VectorSimilarityFunction expectedSimilarity = null;
  int subOrdinal = 0;
  for (int segment = 0; segment < mergeState.knnVectorsReaders.length; segment++) {
    KnnVectorsReader segmentReader = mergeState.knnVectorsReaders[segment];
    // Segments without a vectors reader have nothing to contribute.
    if (segmentReader == null) {
      continue;
    }
    // Nothing to merge when the field is absent or carries no vectors.
    if (mergeFieldInfo == null || !mergeFieldInfo.hasVectorValues()) {
      continue;
    }

    int currentDimension = mergeFieldInfo.getVectorDimension();
    VectorSimilarityFunction currentSimilarity = mergeFieldInfo.getVectorSimilarityFunction();
    if (expectedDimension == -1) {
      // First contributing segment: remember what the others must match.
      expectedDimension = currentDimension;
      expectedSimilarity = mergeFieldInfo.getVectorSimilarityFunction();
    } else if (expectedDimension != currentDimension) {
      throw new IllegalStateException(
          "Varying dimensions for vector-valued field "
              + mergeFieldInfo.name
              + ": "
              + expectedDimension
              + "!="
              + currentDimension);
    } else if (expectedSimilarity != currentSimilarity) {
      throw new IllegalStateException(
          "Varying similarity functions for vector-valued field "
              + mergeFieldInfo.name
              + ": "
              + expectedSimilarity
              + "!="
              + currentSimilarity);
    }

    VectorValues segmentValues = segmentReader.getVectorValues(mergeFieldInfo.name);
    if (segmentValues != null) {
      // Only segments that actually hold values consume a sub ordinal.
      mergedSubs.add(
          new VectorValuesSub(subOrdinal++, mergeState.docMaps[segment], segmentValues));
    }
  }
  return new VectorValuesMerger(mergedSubs, mergeState);
}
/**
 * Not supported by this anonymous, merge-time reader.
 *
 * @throws UnsupportedOperationException on every call
 */
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
throw new UnsupportedOperationException();
}
});
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
}

View File

@ -22,6 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
@ -107,7 +108,9 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
}
@Override
public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException {
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
long pos = vectorData.getFilePointer();
// write floats aligned at 4 bytes. This will not survive CFS, but it shows a small benefit when
// CFS is not used, eg for larger indexes

View File

@ -98,8 +98,9 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
getInstance(fieldInfo).writeField(fieldInfo, values);
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
getInstance(fieldInfo).writeField(fieldInfo, knnVectorsReader);
}
@Override

View File

@ -22,9 +22,12 @@ import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Counter;
import org.apache.lucene.util.RamUsageEstimator;
@ -107,13 +110,38 @@ class VectorValuesWriter {
* @throws IOException if there is an error writing the field and its values
*/
public void flush(Sorter.DocMap sortMap, KnnVectorsWriter knnVectorsWriter) throws IOException {
VectorValues vectorValues =
new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
if (sortMap != null) {
knnVectorsWriter.writeField(fieldInfo, new SortingVectorValues(vectorValues, sortMap));
} else {
knnVectorsWriter.writeField(fieldInfo, vectorValues);
}
KnnVectorsReader knnVectorsReader =
new KnnVectorsReader() {
/** Always reports zero bytes for this anonymous, flush-time reader. */
@Override
public long ramBytesUsed() {
return 0;
}
/**
 * Not supported by this anonymous, flush-time reader.
 *
 * @throws UnsupportedOperationException on every call
 */
@Override
public void close() throws IOException {
throw new UnsupportedOperationException();
}
/**
 * Not supported by this anonymous, flush-time reader.
 *
 * @throws UnsupportedOperationException on every call
 */
@Override
public void checkIntegrity() throws IOException {
throw new UnsupportedOperationException();
}
/**
 * Returns a new view over the vectors buffered for {@code fieldInfo}.
 *
 * <p>A fresh {@code BufferedVectorValues} is created on every call. When a {@code sortMap} was
 * supplied to the enclosing {@code flush} call, the view is additionally wrapped in a
 * {@code SortingVectorValues}.
 *
 * @param field requested field name (not used by this implementation)
 * @return the buffered values, wrapped for sorting when {@code sortMap} is non-null
 */
@Override
public VectorValues getVectorValues(String field) throws IOException {
  VectorValues buffered =
      new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
  if (sortMap == null) {
    return buffered;
  }
  return new SortingVectorValues(buffered, sortMap);
}
/**
 * Not supported by this anonymous, flush-time reader.
 *
 * @throws UnsupportedOperationException on every call
 */
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
throw new UnsupportedOperationException();
}
};
knnVectorsWriter.writeField(fieldInfo, knnVectorsReader);
}
static class SortingVectorValues extends VectorValues

View File

@ -39,7 +39,6 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
@ -172,9 +171,10 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
KnnVectorsWriter writer = delegate.fieldsWriter(state);
return new KnnVectorsWriter() {
@Override
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
fieldsWritten.add(fieldInfo.name);
writer.writeField(fieldInfo, values);
writer.writeField(fieldInfo, knnVectorsReader);
}
@Override

View File

@ -58,10 +58,15 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
assert fieldInfo != null;
assert values != null;
delegate.writeField(fieldInfo, values);
assert knnVectorsReader != null;
// assert that knnVectorsReader#getVectorValues returns different instances upon repeated
// calls
assert knnVectorsReader.getVectorValues(fieldInfo.name)
!= knnVectorsReader.getVectorValues(fieldInfo.name);
delegate.writeField(fieldInfo, knnVectorsReader);
}
@Override