in KnnVectorsWriter reduce code duplication w.r.t. MergedVectorValues.merge(Float|Byte)VectorValues (#13539)

Co-authored-by: Vigya Sharma <vigyaspeaks@gmail.com>
This commit is contained in:
Christine Poerschke 2024-07-12 10:48:10 +01:00 committed by GitHub
parent cc14555395
commit c55d664b3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 56 additions and 34 deletions

View File

@ -23,6 +23,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.DocsWithFieldSet;
@ -35,6 +36,7 @@ import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOFunction;
/** Writes vectors to an index. */
public abstract class KnnVectorsWriter implements Accountable, Closeable {
@ -111,11 +113,11 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
}
/** Tracks state of one sub-reader that we are merging */
private static class VectorValuesSub extends DocIDMerger.Sub {
private static class FloatVectorValuesSub extends DocIDMerger.Sub {
final FloatVectorValues values;
VectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) {
FloatVectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) {
super(docMap);
this.values = values;
assert values.docID() == -1;
@ -201,61 +203,81 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
public static final class MergedVectorValues {
private MergedVectorValues() {}
/** Returns a merged view over all the segment's {@link FloatVectorValues}. */
public static FloatVectorValues mergeFloatVectorValues(
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
private static void validateFieldEncoding(FieldInfo fieldInfo, VectorEncoding expected) {
assert fieldInfo != null && fieldInfo.hasVectorValues();
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
VectorEncoding fieldEncoding = fieldInfo.getVectorEncoding();
if (fieldEncoding != expected) {
throw new UnsupportedOperationException(
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32");
"Cannot merge vectors encoded as [" + fieldEncoding + "] as " + expected);
}
List<VectorValuesSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
}
private static <V, S> List<S> mergeVectorValues(
KnnVectorsReader[] knnVectorsReaders,
MergeState.DocMap[] docMaps,
IOFunction<KnnVectorsReader, V> valuesSupplier,
BiFunction<MergeState.DocMap, V, S> newSub)
throws IOException {
List<S> subs = new ArrayList<>();
for (int i = 0; i < knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = knnVectorsReaders[i];
if (knnVectorsReader != null) {
FloatVectorValues values = knnVectorsReader.getFloatVectorValues(fieldInfo.name);
V values = valuesSupplier.apply(knnVectorsReader);
if (values != null) {
subs.add(new VectorValuesSub(mergeState.docMaps[i], values));
subs.add(newSub.apply(docMaps[i], values));
}
}
}
return new MergedFloat32VectorValues(subs, mergeState);
return subs;
}
/**
 * Returns a merged view over all the segment's {@link FloatVectorValues}.
 *
 * <p>The field's encoding is checked first via {@code validateFieldEncoding}; the per-segment
 * values are then gathered by {@code mergeVectorValues} and wrapped in a single merged instance.
 *
 * @param fieldInfo the vector field to merge; its name is used to look up each reader's values
 * @param mergeState supplies the per-segment readers and doc maps
 * @return a {@link FloatVectorValues} view over the collected per-segment values
 * @throws IOException if reading a segment's vector values fails
 */
public static FloatVectorValues mergeFloatVectorValues(
    FieldInfo fieldInfo, MergeState mergeState) throws IOException {
  validateFieldEncoding(fieldInfo, VectorEncoding.FLOAT32);

  // How to obtain the float vectors of this field from one segment reader.
  IOFunction<KnnVectorsReader, FloatVectorValues> readSegmentValues =
      segmentReader -> segmentReader.getFloatVectorValues(fieldInfo.name);
  // How to pair a segment's doc map with its values.
  BiFunction<MergeState.DocMap, FloatVectorValues, FloatVectorValuesSub> toSub =
      FloatVectorValuesSub::new;

  List<FloatVectorValuesSub> perSegmentSubs =
      mergeVectorValues(
          mergeState.knnVectorsReaders, mergeState.docMaps, readSegmentValues, toSub);
  return new MergedFloat32VectorValues(perSegmentSubs, mergeState);
}
/** Returns a merged view over all the segment's {@link ByteVectorValues}. */
public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState)
throws IOException {
assert fieldInfo != null && fieldInfo.hasVectorValues();
if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
throw new UnsupportedOperationException(
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE");
}
List<ByteVectorValuesSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
if (knnVectorsReader != null) {
ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.name);
if (values != null) {
subs.add(new ByteVectorValuesSub(mergeState.docMaps[i], values));
}
}
}
return new MergedByteVectorValues(subs, mergeState);
validateFieldEncoding(fieldInfo, VectorEncoding.BYTE);
return new MergedByteVectorValues(
mergeVectorValues(
mergeState.knnVectorsReaders,
mergeState.docMaps,
knnVectorsReader -> {
return knnVectorsReader.getByteVectorValues(fieldInfo.name);
},
(docMap, values) -> {
return new ByteVectorValuesSub(docMap, values);
}),
mergeState);
}
static class MergedFloat32VectorValues extends FloatVectorValues {
private final List<VectorValuesSub> subs;
private final DocIDMerger<VectorValuesSub> docIdMerger;
private final List<FloatVectorValuesSub> subs;
private final DocIDMerger<FloatVectorValuesSub> docIdMerger;
private final int size;
private int docId;
VectorValuesSub current;
FloatVectorValuesSub current;
private MergedFloat32VectorValues(List<VectorValuesSub> subs, MergeState mergeState)
private MergedFloat32VectorValues(List<FloatVectorValuesSub> subs, MergeState mergeState)
throws IOException {
this.subs = subs;
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
int totalSize = 0;
for (VectorValuesSub sub : subs) {
for (FloatVectorValuesSub sub : subs) {
totalSize += sub.values.size();
}
size = totalSize;