LUCENE-10375: Write merged vectors to file before building graph (#601)

When merging segments together, the `KnnVectorsWriter` creates a `VectorValues`
instance with a merged view of all the segments' vectors. This merged instance
is used when constructing the new HNSW graph. Graph building needs random
access, and the merged `VectorValues` supports this by mapping from merged
ordinals to segments and segment ordinals. This mapping can add significant
overhead when building the graph.
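
For reference, the overhead comes from the ordinal translation the merged view performs on every vector access while the graph is built. A condensed sketch of that lookup (simplified from the `MergerRandomAccess` code removed in `KnnVectorsWriter` below; `ordMap`, `ordBase`, and `raSubs` are the merged view's internal state):

// Each random access during graph building pays for an ordinal remap plus a
// binary search over the per-segment ordinal bases to find the owning segment.
float[] vectorValue(int mergedOrd) throws IOException {
  int unmappedOrd = ordMap[mergedOrd];                      // merged ordinal -> global ordinal
  int segment = Arrays.binarySearch(ordBase, unmappedOrd);  // locate the source segment
  if (segment < 0) {
    segment = -2 - segment;                                 // index of the greatest lower bound
  }
  // translate to a segment-local ordinal and read from that segment's reader
  return raSubs.get(segment).vectorValue(unmappedOrd - ordBase[segment]);
}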

This change updates the HNSW merging logic to first write the combined segment
vectors to a file, then use that file to build the graph. This helps speed
up segment merging, and also lets us simplify `VectorValuesMerger`, which
provides the merged view of vector values.
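
The shape of the new path, condensed from the `Lucene90HnswVectorsWriter.mergeField` code added below (error handling, checksumming, and temp-file cleanup omitted):

// 1. Stream the merged view of all segments' vectors to a temporary file (sequential writes only).
VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
IndexOutput tempVectorData =
    segmentWriteState.directory.createTempOutput(vectorData.getName(), "temp", segmentWriteState.context);
int[] docIds = writeVectorData(tempVectorData, vectors);
CodecUtil.writeFooter(tempVectorData);
tempVectorData.close();

// 2. Copy the temporary data into the segment's vector data file.
IndexInput vectorDataInput =
    segmentWriteState.directory.openInput(tempVectorData.getName(), segmentWriteState.context);
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());

// 3. Build the HNSW graph against the flat, file-backed vectors instead of the ordinal-mapped view.
Lucene90HnswVectorsReader.OffHeapVectorValues offHeapVectors =
    new Lucene90HnswVectorsReader.OffHeapVectorValues(vectors.dimension(), docIds, vectorDataInput);
long[] offsets = new long[docIds.length];
writeGraph(vectorIndex, offHeapVectors, fieldInfo.getVectorSimilarityFunction(),
    vectorIndex.getFilePointer(), offsets, maxConn, beamWidth);
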
Julie Tibshirani 2022-01-18 13:53:05 -08:00 committed by GitHub
parent 2e2c4818d1
commit dfca9a5608
7 changed files with 264 additions and 231 deletions

File: CHANGES.txt

@@ -163,6 +163,9 @@ Optimizations
* LUCENE-10379: Count directly into the dense values array in FastTaxonomyFacetCounts#countAll.
(Guo Feng, Greg Miller)
* LUCENE-10375: Speed up HNSW vectors merge by first writing combined vector
data to a file. (Julie Tibshirani, Adrien Grand)
Changes in runtime behavior
---------------------

File: KnnVectorsWriter.java

@@ -17,19 +17,13 @@
package org.apache.lucene.codecs;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
@@ -48,7 +42,11 @@ public abstract class KnnVectorsWriter implements Closeable {
/** Called once at the end before close */
public abstract void finish() throws IOException;
/** Merge the vector values from multiple segments, for all fields */
/**
* Merges the segment vectors for all fields. This default implementation delegates to {@link
* #writeField}, passing a {@link KnnVectorsReader} that combines the vector values and ignores
* deleted documents.
*/
public void merge(MergeState mergeState) throws IOException {
for (int i = 0; i < mergeState.fieldInfos.length; i++) {
KnnVectorsReader reader = mergeState.knnVectorsReaders[i];
@@ -57,142 +55,97 @@ public abstract class KnnVectorsWriter implements Closeable {
reader.checkIntegrity();
}
}
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
if (fieldInfo.hasVectorValues()) {
mergeVectors(fieldInfo, mergeState);
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
}
writeField(
fieldInfo,
new KnnVectorsReader() {
@Override
public long ramBytesUsed() {
return 0;
}
@Override
public void close() {
throw new UnsupportedOperationException();
}
@Override
public void checkIntegrity() {
throw new UnsupportedOperationException();
}
@Override
public VectorValues getVectorValues(String field) throws IOException {
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
throw new UnsupportedOperationException();
}
});
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
}
}
}
finish();
}
private void mergeVectors(FieldInfo mergeFieldInfo, final MergeState mergeState)
throws IOException {
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
}
// Create a new VectorValues by iterating over the sub vectors, mapping the resulting
// docids using docMaps in the mergeState.
writeField(
mergeFieldInfo,
new KnnVectorsReader() {
@Override
public long ramBytesUsed() {
return 0;
}
@Override
public void close() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public void checkIntegrity() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public VectorValues getVectorValues(String field) throws IOException {
List<VectorValuesSub> subs = new ArrayList<>();
int dimension = -1;
VectorSimilarityFunction similarityFunction = null;
int nonEmptySegmentIndex = 0;
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
if (knnVectorsReader != null) {
if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) {
int segmentDimension = mergeFieldInfo.getVectorDimension();
VectorSimilarityFunction segmentSimilarityFunction =
mergeFieldInfo.getVectorSimilarityFunction();
if (dimension == -1) {
dimension = segmentDimension;
similarityFunction = mergeFieldInfo.getVectorSimilarityFunction();
} else if (dimension != segmentDimension) {
throw new IllegalStateException(
"Varying dimensions for vector-valued field "
+ mergeFieldInfo.name
+ ": "
+ dimension
+ "!="
+ segmentDimension);
} else if (similarityFunction != segmentSimilarityFunction) {
throw new IllegalStateException(
"Varying similarity functions for vector-valued field "
+ mergeFieldInfo.name
+ ": "
+ similarityFunction
+ "!="
+ segmentSimilarityFunction);
}
VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name);
if (values != null) {
subs.add(
new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values));
}
}
}
}
return new VectorValuesMerger(subs, mergeState);
}
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
throw new UnsupportedOperationException();
}
});
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
}
}
/** Tracks state of one sub-reader that we are merging */
private static class VectorValuesSub extends DocIDMerger.Sub {
final VectorValues values;
final int segmentIndex;
int count;
VectorValuesSub(int segmentIndex, MergeState.DocMap docMap, VectorValues values) {
VectorValuesSub(MergeState.DocMap docMap, VectorValues values) {
super(docMap);
this.values = values;
this.segmentIndex = segmentIndex;
assert values.docID() == -1;
}
@Override
public int nextDoc() throws IOException {
int docId = values.nextDoc();
if (docId != NO_MORE_DOCS) {
// Note: this does count deleted docs since they are present in the to-be-merged segment
++count;
}
return docId;
return values.nextDoc();
}
}
/**
* View over multiple VectorValues supporting iterator-style access via DocIdMerger. Maintains a
* reverse ordinal mapping for documents having values in order to support random access by dense
* ordinal.
*/
private static class VectorValuesMerger extends VectorValues
implements RandomAccessVectorValuesProducer {
/** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */
public static class MergedVectorValues extends VectorValues {
private final List<VectorValuesSub> subs;
private final DocIDMerger<VectorValuesSub> docIdMerger;
private final int[] ordBase;
private final int cost;
private int size;
private final int size;
private int docId;
private VectorValuesSub current;
/* For each doc with a vector, record its ord in the segments being merged. This enables random
* access into the unmerged segments using the ords from the merged segment.
*/
private int[] ordMap;
private int ord;
VectorValuesMerger(List<VectorValuesSub> subs, MergeState mergeState) throws IOException {
/** Returns a merged view over all the segment's {@link VectorValues}. */
public static MergedVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
throws IOException {
assert fieldInfo != null && fieldInfo.hasVectorValues();
List<VectorValuesSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
if (knnVectorsReader != null) {
VectorValues values = knnVectorsReader.getVectorValues(fieldInfo.name);
if (values != null) {
subs.add(new VectorValuesSub(mergeState.docMaps[i], values));
}
}
}
return new MergedVectorValues(subs, mergeState);
}
private MergedVectorValues(List<VectorValuesSub> subs, MergeState mergeState)
throws IOException {
this.subs = subs;
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
int totalCost = 0, totalSize = 0;
@@ -200,20 +153,8 @@ public abstract class KnnVectorsWriter implements Closeable {
totalCost += sub.values.cost();
totalSize += sub.values.size();
}
/* This size includes deleted docs, but when we iterate over docs here (nextDoc())
* we skip deleted docs. So we sneakily update this size once we observe that iteration is complete.
* That way by the time we are asked to do random access for graph building, we have a correct size.
*/
cost = totalCost;
size = totalSize;
ordMap = new int[size];
ordBase = new int[subs.size()];
int lastBase = 0;
for (int k = 0; k < subs.size(); k++) {
int size = subs.get(k).values.size();
ordBase[k] = lastBase;
lastBase += size;
}
docId = -1;
}
@@ -227,12 +168,8 @@ public abstract class KnnVectorsWriter implements Closeable {
current = docIdMerger.next();
if (current == null) {
docId = NO_MORE_DOCS;
/* update the size to reflect the number of *non-deleted* documents seen so we can support
* random access. */
size = ord;
} else {
docId = current.mappedDocID;
ordMap[ord++] = ordBase[current.segmentIndex] + current.count - 1;
}
return docId;
}
@@ -247,11 +184,6 @@ public abstract class KnnVectorsWriter implements Closeable {
return current.values.binaryValue();
}
@Override
public RandomAccessVectorValues randomAccess() {
return new MergerRandomAccess();
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
@@ -271,52 +203,5 @@ public abstract class KnnVectorsWriter implements Closeable {
public int dimension() {
return subs.get(0).values.dimension();
}
class MergerRandomAccess implements RandomAccessVectorValues {
private final List<RandomAccessVectorValues> raSubs;
MergerRandomAccess() {
raSubs = new ArrayList<>(subs.size());
for (VectorValuesSub sub : subs) {
if (sub.values instanceof RandomAccessVectorValuesProducer) {
raSubs.add(((RandomAccessVectorValuesProducer) sub.values).randomAccess());
} else {
throw new IllegalStateException(
"Cannot merge VectorValues without support for random access");
}
}
}
@Override
public int size() {
return size;
}
@Override
public int dimension() {
return VectorValuesMerger.this.dimension();
}
@Override
public float[] vectorValue(int target) throws IOException {
int unmappedOrd = ordMap[target];
int segmentOrd = Arrays.binarySearch(ordBase, unmappedOrd);
if (segmentOrd < 0) {
// get the index of the greatest lower bound
segmentOrd = -2 - segmentOrd;
}
while (segmentOrd < ordBase.length - 1 && ordBase[segmentOrd + 1] == ordBase[segmentOrd]) {
// forward over empty segments which will share the same ordBase
segmentOrd++;
}
return raSubs.get(segmentOrd).vectorValue(unmappedOrd - ordBase[segmentOrd]);
}
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}
}
}
}

File: Lucene90HnswVectorsReader.java

@@ -271,7 +271,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
return new OffHeapVectorValues(fieldEntry, bytesSlice);
return new OffHeapVectorValues(fieldEntry.dimension, fieldEntry.ordToDoc, bytesSlice);
}
private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
@@ -354,10 +354,11 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
/** Read the vector values from the index input. This supports both iterated and random access. */
private static class OffHeapVectorValues extends VectorValues
public static class OffHeapVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
final FieldEntry fieldEntry;
final int dimension;
final int[] ordToDoc;
final IndexInput dataIn;
final BytesRef binaryValue;
@@ -368,23 +369,25 @@
int ord = -1;
int doc = -1;
OffHeapVectorValues(FieldEntry fieldEntry, IndexInput dataIn) {
this.fieldEntry = fieldEntry;
OffHeapVectorValues(int dimension, int[] ordToDoc, IndexInput dataIn) {
this.dimension = dimension;
this.ordToDoc = ordToDoc;
this.dataIn = dataIn;
byteSize = Float.BYTES * fieldEntry.dimension;
byteSize = Float.BYTES * dimension;
byteBuffer = ByteBuffer.allocate(byteSize);
value = new float[fieldEntry.dimension];
value = new float[dimension];
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
@Override
public int dimension() {
return fieldEntry.dimension;
return dimension;
}
@Override
public int size() {
return fieldEntry.size();
return ordToDoc.length;
}
@Override
@@ -411,7 +414,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
if (++ord >= size()) {
doc = NO_MORE_DOCS;
} else {
doc = fieldEntry.ordToDoc[ord];
doc = ordToDoc[ord];
}
return doc;
}
@@ -419,27 +422,27 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
@Override
public int advance(int target) {
assert docID() < target;
ord = Arrays.binarySearch(fieldEntry.ordToDoc, ord + 1, fieldEntry.ordToDoc.length, target);
ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target);
if (ord < 0) {
ord = -(ord + 1);
}
assert ord <= fieldEntry.ordToDoc.length;
if (ord == fieldEntry.ordToDoc.length) {
assert ord <= ordToDoc.length;
if (ord == ordToDoc.length) {
doc = NO_MORE_DOCS;
} else {
doc = fieldEntry.ordToDoc[ord];
doc = ordToDoc[ord];
}
return doc;
}
@Override
public long cost() {
return fieldEntry.size();
return ordToDoc.length;
}
@Override
public RandomAccessVectorValues randomAccess() {
return new OffHeapVectorValues(fieldEntry, dataIn.clone());
return new OffHeapVectorValues(dimension, ordToDoc, dataIn.clone());
}
@Override

File: Lucene90HnswVectorsWriter.java

@@ -26,11 +26,14 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
@@ -110,26 +113,14 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
long pos = vectorData.getFilePointer();
// write floats aligned at 4 bytes. This will not survive CFS, but it shows a small benefit when
// CFS is not used, eg for larger indexes
long padding = (4 - (pos & 0x3)) & 0x3;
long vectorDataOffset = pos + padding;
for (int i = 0; i < padding; i++) {
vectorData.writeByte((byte) 0);
}
// TODO - use a better data structure; a bitset? DocsWithFieldSet is p.p. in o.a.l.index
int[] docIds = new int[vectors.size()];
int count = 0;
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) {
// write vector
writeVectorValue(vectors);
docIds[count] = docV;
}
// count may be < vectors.size() e,g, if some documents were deleted
long[] offsets = new long[count];
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
int[] docIds = writeVectorData(vectorData, vectors);
assert vectors.size() == docIds.length;
long[] offsets = new long[docIds.length];
long vectorIndexOffset = vectorIndex.getFilePointer();
if (vectors instanceof RandomAccessVectorValuesProducer) {
writeGraph(
@@ -138,13 +129,14 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
fieldInfo.getVectorSimilarityFunction(),
vectorIndexOffset,
offsets,
count,
maxConn,
beamWidth);
} else {
throw new IllegalArgumentException(
"Indexing an HNSW graph requires a random access vector values, got " + vectors);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldInfo,
@@ -152,18 +144,132 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
vectorDataLength,
vectorIndexOffset,
vectorIndexLength,
count,
docIds);
writeGraphOffsets(meta, offsets);
}
@Override
public void merge(MergeState mergeState) throws IOException {
for (int i = 0; i < mergeState.fieldInfos.length; i++) {
KnnVectorsReader reader = mergeState.knnVectorsReaders[i];
assert reader != null || mergeState.fieldInfos[i].hasVectorValues() == false;
if (reader != null) {
reader.checkIntegrity();
}
}
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
if (fieldInfo.hasVectorValues()) {
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
}
mergeField(fieldInfo, mergeState);
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
}
}
}
finish();
}
private void mergeField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
}
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
IndexOutput tempVectorData =
segmentWriteState.directory.createTempOutput(
vectorData.getName(), "temp", segmentWriteState.context);
IndexInput vectorDataInput = null;
boolean success = false;
try {
// write the merged vector data to a temporary file
int[] docIds = writeVectorData(tempVectorData, vectors);
CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData);
// copy the temporary file vectors to the actual data file
vectorDataInput =
segmentWriteState.directory.openInput(
tempVectorData.getName(), segmentWriteState.context);
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
CodecUtil.retrieveChecksum(vectorDataInput);
// build the graph using the temporary vector data
Lucene90HnswVectorsReader.OffHeapVectorValues offHeapVectors =
new Lucene90HnswVectorsReader.OffHeapVectorValues(
vectors.dimension(), docIds, vectorDataInput);
long[] offsets = new long[docIds.length];
long vectorIndexOffset = vectorIndex.getFilePointer();
writeGraph(
vectorIndex,
offHeapVectors,
fieldInfo.getVectorSimilarityFunction(),
vectorIndexOffset,
offsets,
maxConn,
beamWidth);
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldInfo,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,
vectorIndexLength,
docIds);
writeGraphOffsets(meta, offsets);
success = true;
} finally {
IOUtils.close(vectorDataInput);
if (success) {
segmentWriteState.directory.deleteFile(tempVectorData.getName());
} else {
IOUtils.closeWhileHandlingException(tempVectorData);
IOUtils.deleteFilesIgnoringExceptions(
segmentWriteState.directory, tempVectorData.getName());
}
}
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
}
}
/**
* Writes the vector values to the output and returns a mapping from dense ordinals to document
* IDs. The length of the returned array matches the total number of documents with a vector
* (which excludes deleted documents), so it may be less than {@link VectorValues#size()}.
*/
private static int[] writeVectorData(IndexOutput output, VectorValues vectors)
throws IOException {
int[] docIds = new int[vectors.size()];
int count = 0;
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) {
// write vector
BytesRef binaryValue = vectors.binaryValue();
assert binaryValue.length == vectors.dimension() * Float.BYTES;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docIds[count] = docV;
}
if (docIds.length > count) {
return ArrayUtil.copyOfSubArray(docIds, 0, count);
}
return docIds;
}
private void writeMeta(
FieldInfo field,
long vectorDataOffset,
long vectorDataLength,
long indexDataOffset,
long indexDataLength,
int size,
int[] docIds)
throws IOException {
meta.writeInt(field.number);
@@ -173,20 +279,13 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
meta.writeVLong(indexDataOffset);
meta.writeVLong(indexDataLength);
meta.writeInt(field.getVectorDimension());
meta.writeInt(size);
for (int i = 0; i < size; i++) {
meta.writeInt(docIds.length);
for (int docId : docIds) {
// TODO: delta-encode, or write as bitset
meta.writeVInt(docIds[i]);
meta.writeVInt(docId);
}
}
private void writeVectorValue(VectorValues vectors) throws IOException {
// write vector value
BytesRef binaryValue = vectors.binaryValue();
assert binaryValue.length == vectors.dimension() * Float.BYTES;
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
}
private void writeGraphOffsets(IndexOutput out, long[] offsets) throws IOException {
long last = 0;
for (long offset : offsets) {
@@ -201,7 +300,6 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
VectorSimilarityFunction similarityFunction,
long graphDataOffset,
long[] offsets,
int count,
int maxConn,
int beamWidth)
throws IOException {
@@ -211,7 +309,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
for (int ord = 0; ord < count; ord++) {
for (int ord = 0; ord < offsets.length; ord++) {
// write graph
offsets[ord] = graphData.getFilePointer() - graphDataOffset;

File: PerFieldKnnVectorsFormat.java

@@ -19,7 +19,10 @@ package org.apache.lucene.codecs.perfield;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.TreeMap;
@@ -27,6 +30,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
@@ -103,6 +107,31 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
getInstance(fieldInfo).writeField(fieldInfo, knnVectorsReader);
}
@Override
public final void merge(MergeState mergeState) throws IOException {
Map<KnnVectorsWriter, Collection<String>> writersToFields = new IdentityHashMap<>();
// Group each writer by the fields it handles
for (FieldInfo fi : mergeState.mergeFieldInfos) {
if (fi.hasVectorValues() == false) {
continue;
}
KnnVectorsWriter writer = getInstance(fi);
Collection<String> fields = writersToFields.computeIfAbsent(writer, k -> new ArrayList<>());
fields.add(fi.name);
}
// Delegate the merge to the appropriate writer
PerFieldMergeState pfMergeState = new PerFieldMergeState(mergeState);
try {
for (Map.Entry<KnnVectorsWriter, Collection<String>> e : writersToFields.entrySet()) {
e.getKey().merge(pfMergeState.apply(e.getValue()));
}
} finally {
pfMergeState.reset();
}
}
@Override
public void finish() throws IOException {
for (WriterAndSuffix was : formats.values()) {

File: TestPerFieldKnnVectorsFormat.java

@@ -36,6 +36,7 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
@@ -177,6 +178,14 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
writer.writeField(fieldInfo, knnVectorsReader);
}
@Override
public void merge(MergeState mergeState) throws IOException {
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
fieldsWritten.add(fieldInfo.name);
}
writer.merge(mergeState);
}
@Override
public void finish() throws IOException {
writer.finish();

File: AssertingKnnVectorsFormat.java

@@ -23,6 +23,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
@@ -69,6 +70,11 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
delegate.writeField(fieldInfo, knnVectorsReader);
}
@Override
public void merge(MergeState mergeState) throws IOException {
delegate.merge(mergeState);
}
@Override
public void finish() throws IOException {
delegate.finish();