mirror of https://github.com/apache/lucene.git
in KnnVectorsWriter reduce code duplication w.r.t. MergedVectorValues.merge(Float|Byte)VectorValues (#13539)
Co-authored-by: Vigya Sharma <vigyaspeaks@gmail.com>
This commit is contained in:
parent
cc14555395
commit
c55d664b3e
|
@ -23,6 +23,7 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.BiFunction;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocIDMerger;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
|
@ -35,6 +36,7 @@ import org.apache.lucene.internal.hppc.IntIntHashMap;
|
|||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.IOFunction;
|
||||
|
||||
/** Writes vectors to an index. */
|
||||
public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
||||
|
@ -111,11 +113,11 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
}
|
||||
|
||||
/** Tracks state of one sub-reader that we are merging */
|
||||
private static class VectorValuesSub extends DocIDMerger.Sub {
|
||||
private static class FloatVectorValuesSub extends DocIDMerger.Sub {
|
||||
|
||||
final FloatVectorValues values;
|
||||
|
||||
VectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) {
|
||||
FloatVectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) {
|
||||
super(docMap);
|
||||
this.values = values;
|
||||
assert values.docID() == -1;
|
||||
|
@ -201,61 +203,81 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
public static final class MergedVectorValues {
|
||||
private MergedVectorValues() {}
|
||||
|
||||
/** Returns a merged view over all the segments' {@link FloatVectorValues}. */
|
||||
public static FloatVectorValues mergeFloatVectorValues(
|
||||
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
private static void validateFieldEncoding(FieldInfo fieldInfo, VectorEncoding expected) {
|
||||
assert fieldInfo != null && fieldInfo.hasVectorValues();
|
||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||
VectorEncoding fieldEncoding = fieldInfo.getVectorEncoding();
|
||||
if (fieldEncoding != expected) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32");
|
||||
"Cannot merge vectors encoded as [" + fieldEncoding + "] as " + expected);
|
||||
}
|
||||
List<VectorValuesSub> subs = new ArrayList<>();
|
||||
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
||||
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
|
||||
}
|
||||
|
||||
private static <V, S> List<S> mergeVectorValues(
|
||||
KnnVectorsReader[] knnVectorsReaders,
|
||||
MergeState.DocMap[] docMaps,
|
||||
IOFunction<KnnVectorsReader, V> valuesSupplier,
|
||||
BiFunction<MergeState.DocMap, V, S> newSub)
|
||||
throws IOException {
|
||||
List<S> subs = new ArrayList<>();
|
||||
for (int i = 0; i < knnVectorsReaders.length; i++) {
|
||||
KnnVectorsReader knnVectorsReader = knnVectorsReaders[i];
|
||||
if (knnVectorsReader != null) {
|
||||
FloatVectorValues values = knnVectorsReader.getFloatVectorValues(fieldInfo.name);
|
||||
V values = valuesSupplier.apply(knnVectorsReader);
|
||||
if (values != null) {
|
||||
subs.add(new VectorValuesSub(mergeState.docMaps[i], values));
|
||||
subs.add(newSub.apply(docMaps[i], values));
|
||||
}
|
||||
}
|
||||
}
|
||||
return new MergedFloat32VectorValues(subs, mergeState);
|
||||
return subs;
|
||||
}
|
||||
|
||||
/** Returns a merged view over all the segments' {@link FloatVectorValues}. */
|
||||
public static FloatVectorValues mergeFloatVectorValues(
|
||||
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
validateFieldEncoding(fieldInfo, VectorEncoding.FLOAT32);
|
||||
return new MergedFloat32VectorValues(
|
||||
mergeVectorValues(
|
||||
mergeState.knnVectorsReaders,
|
||||
mergeState.docMaps,
|
||||
knnVectorsReader -> {
|
||||
return knnVectorsReader.getFloatVectorValues(fieldInfo.name);
|
||||
},
|
||||
(docMap, values) -> {
|
||||
return new FloatVectorValuesSub(docMap, values);
|
||||
}),
|
||||
mergeState);
|
||||
}
|
||||
|
||||
/** Returns a merged view over all the segments' {@link ByteVectorValues}. */
|
||||
public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState)
|
||||
throws IOException {
|
||||
assert fieldInfo != null && fieldInfo.hasVectorValues();
|
||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE");
|
||||
}
|
||||
List<ByteVectorValuesSub> subs = new ArrayList<>();
|
||||
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
||||
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
|
||||
if (knnVectorsReader != null) {
|
||||
ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.name);
|
||||
if (values != null) {
|
||||
subs.add(new ByteVectorValuesSub(mergeState.docMaps[i], values));
|
||||
}
|
||||
}
|
||||
}
|
||||
return new MergedByteVectorValues(subs, mergeState);
|
||||
validateFieldEncoding(fieldInfo, VectorEncoding.BYTE);
|
||||
return new MergedByteVectorValues(
|
||||
mergeVectorValues(
|
||||
mergeState.knnVectorsReaders,
|
||||
mergeState.docMaps,
|
||||
knnVectorsReader -> {
|
||||
return knnVectorsReader.getByteVectorValues(fieldInfo.name);
|
||||
},
|
||||
(docMap, values) -> {
|
||||
return new ByteVectorValuesSub(docMap, values);
|
||||
}),
|
||||
mergeState);
|
||||
}
|
||||
|
||||
static class MergedFloat32VectorValues extends FloatVectorValues {
|
||||
private final List<VectorValuesSub> subs;
|
||||
private final DocIDMerger<VectorValuesSub> docIdMerger;
|
||||
private final List<FloatVectorValuesSub> subs;
|
||||
private final DocIDMerger<FloatVectorValuesSub> docIdMerger;
|
||||
private final int size;
|
||||
private int docId;
|
||||
VectorValuesSub current;
|
||||
FloatVectorValuesSub current;
|
||||
|
||||
private MergedFloat32VectorValues(List<VectorValuesSub> subs, MergeState mergeState)
|
||||
private MergedFloat32VectorValues(List<FloatVectorValuesSub> subs, MergeState mergeState)
|
||||
throws IOException {
|
||||
this.subs = subs;
|
||||
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
|
||||
int totalSize = 0;
|
||||
for (VectorValuesSub sub : subs) {
|
||||
for (FloatVectorValuesSub sub : subs) {
|
||||
totalSize += sub.values.size();
|
||||
}
|
||||
size = totalSize;
|
||||
|
|
Loading…
Reference in New Issue