mirror of https://github.com/apache/lucene.git
LUCENE-10592 Build HNSW Graph on indexing (#992)
Currently, when indexing knn vectors, we buffer them in memory and build an HNSW graph on flush, during segment construction. Because building an HNSW graph is very expensive, this makes the flush operation take a long time. It also makes overall indexing performance quite unpredictable – some indexing operations return almost instantly, while others that trigger a flush take a long time. This happens because flushes are unpredictable, triggered by memory usage, the presence of concurrent searches, etc. Building the HNSW graph as we index documents avoids these problems, since the cost of graph construction is spread evenly over indexing. Co-authored-by: Adrien Grand <jpountz@gmail.com>
This commit is contained in:
parent
bd360f9b3e
commit
ba4bc04271
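The heart of the change is that each vector field now gets a per-field writer that inserts vectors into an on-heap HNSW graph as documents are indexed, so a flush only has to serialize a graph that already exists. A minimal sketch of that per-document step, modeled on the `FieldWriter.addValue` method added in this commit (builder setup and argument validation elided):

```java
// Sketch of the incremental path introduced by this commit: the graph grows
// one node per added vector instead of being built from scratch at flush time.
public void addValue(int docID, float[] vectorValue) throws IOException {
  docsWithField.add(docID); // remember which docs have a vector
  vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
  if (node > 0) {
    // node 0 is added implicitly when the HnswGraphBuilder is constructed
    hnswGraphBuilder.addGraphNode(node, vectorValue); // extend the HNSW graph right away
  }
  node++;
  lastDocID = docID;
}
```

At flush time the already-built graph is simply written out, so the expensive construction work no longer lands on a single flush.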
@@ -23,7 +23,7 @@ import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;

@@ -41,7 +41,7 @@ import org.apache.lucene.util.IOUtils;
 *
 * @lucene.experimental
 */
public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {

private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData, vectorIndex;

@@ -55,7 +55,6 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
this.maxConn = maxConn;
this.beamWidth = beamWidth;

assert state.fieldInfos.hasVectorValues();
segmentWriteState = state;

String metaFileName =

@@ -107,7 +106,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
}

@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);

@@ -23,7 +23,7 @@ import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;

@@ -43,11 +43,10 @@ import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
 *
 * @lucene.experimental
 */
public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {

private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData, vectorIndex;
private final int maxDoc;

private final int maxConn;
private final int beamWidth;

@@ -58,7 +57,6 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
this.maxConn = maxConn;
this.beamWidth = beamWidth;

assert state.fieldInfos.hasVectorValues();
segmentWriteState = state;

String metaFileName =

@@ -101,7 +99,6 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
Lucene91HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
maxDoc = state.segmentInfo.maxDoc();
success = true;
} finally {
if (success == false) {

@@ -111,7 +108,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
}

@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);

@@ -149,6 +146,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldInfo,
maxDoc,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,

@@ -186,6 +184,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {

private void writeMeta(
FieldInfo field,
int maxDoc,
long vectorDataOffset,
long vectorDataLength,
long vectorIndexOffset,

@@ -24,8 +24,8 @@ import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;

@@ -49,11 +49,10 @@ import org.apache.lucene.util.packed.DirectMonotonicWriter;
 *
 * @lucene.experimental
 */
public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {

private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData, vectorIndex;
private final int maxDoc;

private final int M;
private final int beamWidth;

@@ -63,7 +62,6 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
this.M = M;
this.beamWidth = beamWidth;

assert state.fieldInfos.hasVectorValues();
segmentWriteState = state;

String metaFileName =

@@ -106,7 +104,6 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
Lucene92HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
maxDoc = state.segmentInfo.maxDoc();
success = true;
} finally {
if (success == false) {

@@ -116,7 +113,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
}

@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);

@@ -155,6 +152,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldInfo,
maxDoc,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,

@@ -192,6 +190,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {

private void writeMeta(
FieldInfo field,
int maxDoc,
long vectorDataOffset,
long vectorDataLength,
long vectorIndexOffset,

@@ -24,7 +24,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentWriteState;

@@ -35,7 +35,7 @@ import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IOUtils;

/** Writes vector-valued fields in a plain text format */
public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter {
public class SimpleTextKnnVectorsWriter extends BufferingKnnVectorsWriter {

static final BytesRef FIELD_NUMBER = new BytesRef("field-number ");
static final BytesRef FIELD_NAME = new BytesRef("field-name ");

@@ -48,8 +48,6 @@ public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter {
private final BytesRefBuilder scratch = new BytesRefBuilder();

SimpleTextKnnVectorsWriter(SegmentWriteState state) throws IOException {
assert state.fieldInfos.hasVectorValues();

boolean success = false;
// exception handling to pass TestSimpleTextKnnVectorsFormat#testRandomExceptions
try {

@@ -75,7 +73,7 @@ public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter {
}

@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
throws IOException {
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
long vectorDataOffset = vectorData.getFilePointer();

@@ -0,0 +1,34 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.lucene.codecs;

import java.io.IOException;
import org.apache.lucene.util.Accountable;

/** Vectors' writer for a field */
public abstract class KnnFieldVectorsWriter implements Accountable {

  /** Sole constructor */
  protected KnnFieldVectorsWriter() {}

  /**
   * Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
   * increasing order.
   */
  public abstract void addValue(int docID, float[] vectorValue) throws IOException;
}
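This new per-field writer is obtained from `KnnVectorsWriter#addField` and then fed one vector per document, in increasing doc ID order. The default `mergeOneField` implementation added later in this commit drives it exactly that way; a condensed sketch of that caller side, using names from this diff (`vectorsWriter` here stands for any concrete `KnnVectorsWriter`):

```java
// Driving a KnnFieldVectorsWriter: one addValue call per document, in doc ID order.
KnnFieldVectorsWriter fieldWriter = vectorsWriter.addField(fieldInfo);
VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
for (int doc = mergedValues.nextDoc();
    doc != DocIdSetIterator.NO_MORE_DOCS;
    doc = mergedValues.nextDoc()) {
  fieldWriter.addValue(doc, mergedValues.vectorValue());
}
```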
@ -24,30 +24,44 @@ import java.util.List;
|
|||
import org.apache.lucene.index.DocIDMerger;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/** Writes vectors to an index. */
|
||||
public abstract class KnnVectorsWriter implements Closeable {
|
||||
public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
||||
|
||||
/** Sole constructor */
|
||||
protected KnnVectorsWriter() {}
|
||||
|
||||
/** Write all values contained in the provided reader */
|
||||
public abstract void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException;
|
||||
/** Add new field for indexing */
|
||||
public abstract KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException;
|
||||
|
||||
/** Flush all buffered data on disk * */
|
||||
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
|
||||
|
||||
/** Write field for merging */
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
KnnFieldVectorsWriter writer = addField(fieldInfo);
|
||||
VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
for (int doc = mergedValues.nextDoc();
|
||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||
doc = mergedValues.nextDoc()) {
|
||||
writer.addValue(doc, mergedValues.vectorValue());
|
||||
}
|
||||
}
|
||||
|
||||
/** Called once at the end before close */
|
||||
public abstract void finish() throws IOException;
|
||||
|
||||
/**
|
||||
* Merges the segment vectors for all fields. This default implementation delegates to {@link
|
||||
* #writeField}, passing a {@link KnnVectorsReader} that combines the vector values and ignores
|
||||
* #mergeOneField}, passing a {@link KnnVectorsReader} that combines the vector values and ignores
|
||||
* deleted documents.
|
||||
*/
|
||||
public void merge(MergeState mergeState) throws IOException {
|
||||
public final void merge(MergeState mergeState) throws IOException {
|
||||
for (int i = 0; i < mergeState.fieldInfos.length; i++) {
|
||||
KnnVectorsReader reader = mergeState.knnVectorsReaders[i];
|
||||
assert reader != null || mergeState.fieldInfos[i].hasVectorValues() == false;
|
||||
|
@ -62,35 +76,7 @@ public abstract class KnnVectorsWriter implements Closeable {
|
|||
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
|
||||
}
|
||||
|
||||
writeField(
|
||||
fieldInfo,
|
||||
new KnnVectorsReader() {
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
});
|
||||
mergeOneField(fieldInfo, mergeState);
|
||||
|
||||
if (mergeState.infoStream.isEnabled("VV")) {
|
||||
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
|
||||
|
@ -118,7 +104,7 @@ public abstract class KnnVectorsWriter implements Closeable {
|
|||
}
|
||||
|
||||
/** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */
|
||||
private static class MergedVectorValues extends VectorValues {
|
||||
protected static class MergedVectorValues extends VectorValues {
|
||||
private final List<VectorValuesSub> subs;
|
||||
private final DocIDMerger<VectorValuesSub> docIdMerger;
|
||||
private final int cost;
|
||||
|
@ -128,7 +114,7 @@ public abstract class KnnVectorsWriter implements Closeable {
|
|||
private VectorValuesSub current;
|
||||
|
||||
/** Returns a merged view over all the segment's {@link VectorValues}. */
|
||||
static MergedVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
|
||||
public static MergedVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
|
||||
throws IOException {
|
||||
assert fieldInfo != null && fieldInfo.hasVectorValues();
|
||||
|
||||
|
|
|
@ -21,23 +21,32 @@ import static org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat.DIRECT
|
|||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
|
@ -53,19 +62,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
|
||||
private final SegmentWriteState segmentWriteState;
|
||||
private final IndexOutput meta, vectorData, vectorIndex;
|
||||
private final int maxDoc;
|
||||
|
||||
private final int M;
|
||||
private final int beamWidth;
|
||||
|
||||
private final List<FieldWriter> fields = new ArrayList<>();
|
||||
private boolean finished;
|
||||
|
||||
Lucene94HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException {
|
||||
this.M = M;
|
||||
this.beamWidth = beamWidth;
|
||||
|
||||
assert state.fieldInfos.hasVectorValues();
|
||||
segmentWriteState = state;
|
||||
|
||||
String metaFileName =
|
||||
IndexFileNames.segmentFileName(
|
||||
state.segmentInfo.name, state.segmentSuffix, Lucene94HnswVectorsFormat.META_EXTENSION);
|
||||
|
@ -106,7 +112,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
Lucene94HnswVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix);
|
||||
maxDoc = state.segmentInfo.maxDoc();
|
||||
success = true;
|
||||
} finally {
|
||||
if (success == false) {
|
||||
|
@ -116,10 +121,231 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
||||
FieldWriter newField = new FieldWriter(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
|
||||
fields.add(newField);
|
||||
return newField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||
for (FieldWriter field : fields) {
|
||||
if (sortMap == null) {
|
||||
writeField(field, maxDoc);
|
||||
} else {
|
||||
writeSortingField(field, maxDoc, sortMap);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() throws IOException {
|
||||
if (finished) {
|
||||
throw new IllegalStateException("already finished");
|
||||
}
|
||||
finished = true;
|
||||
|
||||
if (meta != null) {
|
||||
// write end of fields marker
|
||||
meta.writeInt(-1);
|
||||
CodecUtil.writeFooter(meta);
|
||||
}
|
||||
if (vectorData != null) {
|
||||
CodecUtil.writeFooter(vectorData);
|
||||
CodecUtil.writeFooter(vectorIndex);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long total = 0;
|
||||
for (FieldWriter field : fields) {
|
||||
total += field.ramBytesUsed();
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
private void writeField(FieldWriter fieldData, int maxDoc) throws IOException {
|
||||
// write vector values
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
|
||||
final ByteBuffer buffer =
|
||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
final BytesRef binaryValue = new BytesRef(buffer.array());
|
||||
for (float[] vector : fieldData.vectors) {
|
||||
buffer.asFloatBuffer().put(vector);
|
||||
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
}
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
|
||||
// write graph
|
||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||
OnHeapHnswGraph graph = fieldData.getGraph();
|
||||
writeGraph(graph);
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
|
||||
writeMeta(
|
||||
fieldData.fieldInfo,
|
||||
maxDoc,
|
||||
vectorDataOffset,
|
||||
vectorDataLength,
|
||||
vectorIndexOffset,
|
||||
vectorIndexLength,
|
||||
fieldData.docsWithField,
|
||||
graph);
|
||||
}
|
||||
|
||||
private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
|
||||
throws IOException {
|
||||
final int[] docIdOffsets = new int[sortMap.size()];
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
|
||||
for (int docID = iterator.nextDoc();
|
||||
docID != DocIdSetIterator.NO_MORE_DOCS;
|
||||
docID = iterator.nextDoc()) {
|
||||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
|
||||
final int[] ordMap = new int[offset - 1]; // new ord to old ord
|
||||
final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord
|
||||
int ord = 0;
|
||||
int doc = 0;
|
||||
for (int docIdOffset : docIdOffsets) {
|
||||
if (docIdOffset != 0) {
|
||||
ordMap[ord] = docIdOffset - 1;
|
||||
oldOrdMap[docIdOffset - 1] = ord;
|
||||
newDocsWithField.add(doc);
|
||||
ord++;
|
||||
}
|
||||
doc++;
|
||||
}
|
||||
|
||||
// write vector values
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
final ByteBuffer buffer =
|
||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
final BytesRef binaryValue = new BytesRef(buffer.array());
|
||||
for (int ordinal : ordMap) {
|
||||
float[] vector = fieldData.vectors.get(ordinal);
|
||||
buffer.asFloatBuffer().put(vector);
|
||||
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
}
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
|
||||
// write graph
|
||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||
OnHeapHnswGraph graph = fieldData.getGraph();
|
||||
HnswGraph mockGraph = reconstructAndWriteGraph(graph, ordMap, oldOrdMap);
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
|
||||
writeMeta(
|
||||
fieldData.fieldInfo,
|
||||
maxDoc,
|
||||
vectorDataOffset,
|
||||
vectorDataLength,
|
||||
vectorIndexOffset,
|
||||
vectorIndexLength,
|
||||
newDocsWithField,
|
||||
mockGraph);
|
||||
}
|
||||
|
||||
// reconstruct graph substituting old ordinals with new ordinals
|
||||
private HnswGraph reconstructAndWriteGraph(
|
||||
OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap) throws IOException {
|
||||
if (graph == null) return null;
|
||||
|
||||
List<int[]> nodesByLevel = new ArrayList<>(graph.numLevels());
|
||||
nodesByLevel.add(null);
|
||||
|
||||
int maxOrd = graph.size();
|
||||
int maxConnOnLevel = M * 2;
|
||||
NodesIterator nodesOnLevel0 = graph.getNodesOnLevel(0);
|
||||
while (nodesOnLevel0.hasNext()) {
|
||||
int node = nodesOnLevel0.nextInt();
|
||||
NeighborArray neighbors = graph.getNeighbors(0, newToOldMap[node]);
|
||||
reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxConnOnLevel, maxOrd);
|
||||
}
|
||||
|
||||
maxConnOnLevel = M;
|
||||
for (int level = 1; level < graph.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
int[] newNodes = new int[nodesOnLevel.size()];
|
||||
int n = 0;
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()];
|
||||
}
|
||||
Arrays.sort(newNodes);
|
||||
nodesByLevel.add(newNodes);
|
||||
for (int node : newNodes) {
|
||||
NeighborArray neighbors = graph.getNeighbors(level, newToOldMap[node]);
|
||||
reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxConnOnLevel, maxOrd);
|
||||
}
|
||||
}
|
||||
return new HnswGraph() {
|
||||
@Override
|
||||
public int nextNeighbor() {
|
||||
throw new UnsupportedOperationException("Not supported on a mock graph");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void seek(int level, int target) {
|
||||
throw new UnsupportedOperationException("Not supported on a mock graph");
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return graph.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numLevels() {
|
||||
return graph.numLevels();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int entryNode() {
|
||||
throw new UnsupportedOperationException("Not supported on a mock graph");
|
||||
}
|
||||
|
||||
@Override
|
||||
public NodesIterator getNodesOnLevel(int level) {
|
||||
if (level == 0) {
|
||||
return graph.getNodesOnLevel(0);
|
||||
} else {
|
||||
return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private void reconstructAndWriteNeigbours(
|
||||
NeighborArray neighbors, int[] oldToNewMap, int maxConnOnLevel, int maxOrd)
|
||||
throws IOException {
|
||||
int size = neighbors.size();
|
||||
vectorIndex.writeInt(size);
|
||||
|
||||
// Destructively modify; it's ok we are discarding it after this
|
||||
int[] nnodes = neighbors.node();
|
||||
for (int i = 0; i < size; i++) {
|
||||
nnodes[i] = oldToNewMap[nnodes[i]];
|
||||
}
|
||||
Arrays.sort(nnodes, 0, size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
int nnode = nnodes[i];
|
||||
assert nnode < maxOrd : "node too large: " + nnode + ">=" + maxOrd;
|
||||
vectorIndex.writeInt(nnode);
|
||||
}
|
||||
// if number of connections < maxConn,
|
||||
// add bogus values up to maxConn to have predictable offsets
|
||||
for (int i = size; i < maxConnOnLevel; i++) {
|
||||
vectorIndex.writeInt(0);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
|
||||
IndexOutput tempVectorData =
|
||||
segmentWriteState.directory.createTempOutput(
|
||||
|
@ -148,13 +374,24 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
OffHeapVectorValues offHeapVectors =
|
||||
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
||||
vectors.dimension(), docsWithField.cardinality(), vectorDataInput);
|
||||
OnHeapHnswGraph graph =
|
||||
offHeapVectors.size() == 0
|
||||
? null
|
||||
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
|
||||
OnHeapHnswGraph graph = null;
|
||||
if (offHeapVectors.size() != 0) {
|
||||
// build graph
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
new HnswGraphBuilder(
|
||||
offHeapVectors,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
graph = hnswGraphBuilder.build(offHeapVectors.randomAccess());
|
||||
writeGraph(graph);
|
||||
}
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
writeMeta(
|
||||
fieldInfo,
|
||||
segmentWriteState.segmentInfo.maxDoc(),
|
||||
vectorDataOffset,
|
||||
vectorDataLength,
|
||||
vectorIndexOffset,
|
||||
|
@ -174,30 +411,44 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||
*/
|
||||
private static DocsWithFieldSet writeVectorData(IndexOutput output, VectorValues vectors)
|
||||
throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
||||
// write vector
|
||||
BytesRef binaryValue = vectors.binaryValue();
|
||||
assert binaryValue.length == vectors.dimension() * Float.BYTES;
|
||||
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
private void writeGraph(OnHeapHnswGraph graph) throws IOException {
|
||||
if (graph == null) return;
|
||||
// write vectors' neighbours on each level into the vectorIndex file
|
||||
int countOnLevel0 = graph.size();
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int maxConnOnLevel = level == 0 ? (M * 2) : M;
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
vectorIndex.writeInt(size);
|
||||
// Destructively modify; it's ok we are discarding it after this
|
||||
int[] nnodes = neighbors.node();
|
||||
Arrays.sort(nnodes, 0, size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
int nnode = nnodes[i];
|
||||
assert nnode < countOnLevel0 : "node too large: " + nnode + ">=" + countOnLevel0;
|
||||
vectorIndex.writeInt(nnode);
|
||||
}
|
||||
// if number of connections < maxConn, add bogus values up to maxConn to have predictable
|
||||
// offsets
|
||||
for (int i = size; i < maxConnOnLevel; i++) {
|
||||
vectorIndex.writeInt(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
return docsWithField;
|
||||
}
|
||||
|
||||
private void writeMeta(
|
||||
FieldInfo field,
|
||||
int maxDoc,
|
||||
long vectorDataOffset,
|
||||
long vectorDataLength,
|
||||
long vectorIndexOffset,
|
||||
long vectorIndexLength,
|
||||
DocsWithFieldSet docsWithField,
|
||||
OnHeapHnswGraph graph)
|
||||
HnswGraph graph)
|
||||
throws IOException {
|
||||
meta.writeInt(field.number);
|
||||
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
||||
|
@ -266,65 +517,129 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
private OnHeapHnswGraph writeGraph(
|
||||
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
/**
|
||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||
*/
|
||||
private static DocsWithFieldSet writeVectorData(IndexOutput output, VectorValues vectors)
|
||||
throws IOException {
|
||||
|
||||
// build graph
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
new HnswGraphBuilder(
|
||||
vectorValues, similarityFunction, M, beamWidth, HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
|
||||
// write vectors' neighbours on each level into the vectorIndex file
|
||||
int countOnLevel0 = graph.size();
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int maxConnOnLevel = level == 0 ? (M * 2) : M;
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
vectorIndex.writeInt(size);
|
||||
// Destructively modify; it's ok we are discarding it after this
|
||||
int[] nnodes = neighbors.node();
|
||||
Arrays.sort(nnodes, 0, size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
int nnode = nnodes[i];
|
||||
assert nnode < countOnLevel0 : "node too large: " + nnode + ">=" + countOnLevel0;
|
||||
vectorIndex.writeInt(nnode);
|
||||
}
|
||||
// if number of connections < maxConn, add bogus values up to maxConn to have predictable
|
||||
// offsets
|
||||
for (int i = size; i < maxConnOnLevel; i++) {
|
||||
vectorIndex.writeInt(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
return graph;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() throws IOException {
|
||||
if (finished) {
|
||||
throw new IllegalStateException("already finished");
|
||||
}
|
||||
finished = true;
|
||||
|
||||
if (meta != null) {
|
||||
// write end of fields marker
|
||||
meta.writeInt(-1);
|
||||
CodecUtil.writeFooter(meta);
|
||||
}
|
||||
if (vectorData != null) {
|
||||
CodecUtil.writeFooter(vectorData);
|
||||
CodecUtil.writeFooter(vectorIndex);
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
||||
// write vector
|
||||
BytesRef binaryValue = vectors.binaryValue();
|
||||
assert binaryValue.length == vectors.dimension() * Float.BYTES;
|
||||
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
}
|
||||
return docsWithField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
IOUtils.close(meta, vectorData, vectorIndex);
|
||||
}
|
||||
|
||||
private static class FieldWriter extends KnnFieldVectorsWriter {
|
||||
private final FieldInfo fieldInfo;
|
||||
private final int dim;
|
||||
private final DocsWithFieldSet docsWithField;
|
||||
private final List<float[]> vectors;
|
||||
private final RAVectorValues raVectorValues;
|
||||
private final HnswGraphBuilder hnswGraphBuilder;
|
||||
|
||||
private int lastDocID = -1;
|
||||
private int node = 0;
|
||||
|
||||
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
||||
throws IOException {
|
||||
this.fieldInfo = fieldInfo;
|
||||
this.dim = fieldInfo.getVectorDimension();
|
||||
this.docsWithField = new DocsWithFieldSet();
|
||||
vectors = new ArrayList<>();
|
||||
raVectorValues = new RAVectorValues(vectors, dim);
|
||||
hnswGraphBuilder =
|
||||
new HnswGraphBuilder(
|
||||
() -> raVectorValues,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(infoStream);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addValue(int docID, float[] vectorValue) throws IOException {
|
||||
if (docID == lastDocID) {
|
||||
throw new IllegalArgumentException(
|
||||
"VectorValuesField \""
|
||||
+ fieldInfo.name
|
||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||
}
|
||||
if (vectorValue.length != dim) {
|
||||
throw new IllegalArgumentException(
|
||||
"Attempt to index a vector of dimension "
|
||||
+ vectorValue.length
|
||||
+ " but \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has dimension "
|
||||
+ dim);
|
||||
}
|
||||
assert docID > lastDocID;
|
||||
docsWithField.add(docID);
|
||||
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
|
||||
if (node > 0) {
|
||||
// start at node 1! node 0 is added implicitly, in the constructor
|
||||
hnswGraphBuilder.addGraphNode(node, vectorValue);
|
||||
}
|
||||
node++;
|
||||
lastDocID = docID;
|
||||
}
|
||||
|
||||
OnHeapHnswGraph getGraph() {
|
||||
if (vectors.size() > 0) {
|
||||
return hnswGraphBuilder.getGraph();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
if (vectors.size() == 0) return 0;
|
||||
return docsWithField.ramBytesUsed()
|
||||
+ vectors.size()
|
||||
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||
+ vectors.size() * vectors.get(0).length * Float.BYTES
|
||||
+ hnswGraphBuilder.getGraph().ramBytesUsed();
|
||||
}
|
||||
}
|
||||
|
||||
private static class RAVectorValues implements RandomAccessVectorValues {
|
||||
private final List<float[]> vectors;
|
||||
private final int dim;
|
||||
|
||||
RAVectorValues(List<float[]> vectors, int dim) {
|
||||
this.vectors = vectors;
|
||||
this.dim = dim;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return vectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dim;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) throws IOException {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,13 +19,11 @@ package org.apache.lucene.codecs.perfield;
|
|||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.IdentityHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.ServiceLoader;
|
||||
import java.util.TreeMap;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
|
@ -33,6 +31,7 @@ import org.apache.lucene.index.FieldInfo;
|
|||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
|
@ -102,34 +101,21 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
getInstance(fieldInfo).writeField(fieldInfo, knnVectorsReader);
|
||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
||||
KnnVectorsWriter writer = getInstance(fieldInfo);
|
||||
return writer.addField(fieldInfo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void merge(MergeState mergeState) throws IOException {
|
||||
Map<KnnVectorsWriter, Collection<String>> writersToFields = new IdentityHashMap<>();
|
||||
|
||||
// Group each writer by the fields it handles
|
||||
for (FieldInfo fi : mergeState.mergeFieldInfos) {
|
||||
if (fi.hasVectorValues() == false) {
|
||||
continue;
|
||||
}
|
||||
KnnVectorsWriter writer = getInstance(fi);
|
||||
Collection<String> fields = writersToFields.computeIfAbsent(writer, k -> new ArrayList<>());
|
||||
fields.add(fi.name);
|
||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||
for (WriterAndSuffix was : formats.values()) {
|
||||
was.writer.flush(maxDoc, sortMap);
|
||||
}
|
||||
}
|
||||
|
||||
// Delegate the merge to the appropriate writer
|
||||
PerFieldMergeState pfMergeState = new PerFieldMergeState(mergeState);
|
||||
try {
|
||||
for (Map.Entry<KnnVectorsWriter, Collection<String>> e : writersToFields.entrySet()) {
|
||||
e.getKey().merge(pfMergeState.apply(e.getValue()));
|
||||
}
|
||||
} finally {
|
||||
pfMergeState.reset();
|
||||
}
|
||||
@Override
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
getInstance(fieldInfo).mergeOneField(fieldInfo, mergeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -180,10 +166,18 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
assert suffixes.containsKey(formatName);
|
||||
suffix = writerAndSuffix.suffix;
|
||||
}
|
||||
|
||||
field.putAttribute(PER_FIELD_SUFFIX_KEY, Integer.toString(suffix));
|
||||
return writerAndSuffix.writer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long total = 0;
|
||||
for (WriterAndSuffix was : formats.values()) {
|
||||
total += was.writer.ramBytesUsed();
|
||||
}
|
||||
return total;
|
||||
}
|
||||
}
|
||||
|
||||
/** VectorReader that can wrap multiple delegate readers, selected by field. */
|
||||
|
|
|
@ -0,0 +1,277 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
||||
/**
|
||||
* Buffers up pending vector value(s) per doc, then flushes when segment flushes. Used for {@code
|
||||
* SimpleTextKnnVectorsWriter} and for vectors writers before v 9.3 .
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||
private final List<FieldWriter> fields = new ArrayList<>();
|
||||
|
||||
/** Sole constructor */
|
||||
protected BufferingKnnVectorsWriter() {}
|
||||
|
||||
@Override
|
||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
||||
FieldWriter newField = new FieldWriter(fieldInfo);
|
||||
fields.add(newField);
|
||||
return newField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||
for (FieldWriter fieldData : fields) {
|
||||
KnnVectorsReader knnVectorsReader =
|
||||
new KnnVectorsReader() {
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
VectorValues vectorValues =
|
||||
new BufferedVectorValues(
|
||||
fieldData.docsWithField,
|
||||
fieldData.vectors,
|
||||
fieldData.fieldInfo.getVectorDimension());
|
||||
return sortMap != null
|
||||
? new VectorValues.SortingVectorValues(vectorValues, sortMap)
|
||||
: vectorValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
|
||||
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long total = 0;
|
||||
for (FieldWriter field : fields) {
|
||||
total += field.ramBytesUsed();
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
KnnVectorsReader knnVectorsReader =
|
||||
new KnnVectorsReader() {
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {}
|
||||
|
||||
@Override
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() {}
|
||||
};
|
||||
writeField(fieldInfo, knnVectorsReader, mergeState.segmentInfo.maxDoc());
|
||||
}
|
||||
|
||||
/** Write the provided field */
|
||||
protected abstract void writeField(
|
||||
FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc) throws IOException;
|
||||
|
||||
private static class FieldWriter extends KnnFieldVectorsWriter {
|
||||
private final FieldInfo fieldInfo;
|
||||
private final int dim;
|
||||
private final DocsWithFieldSet docsWithField;
|
||||
private final List<float[]> vectors;
|
||||
|
||||
private int lastDocID = -1;
|
||||
|
||||
public FieldWriter(FieldInfo fieldInfo) {
|
||||
this.fieldInfo = fieldInfo;
|
||||
this.dim = fieldInfo.getVectorDimension();
|
||||
this.docsWithField = new DocsWithFieldSet();
|
||||
vectors = new ArrayList<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addValue(int docID, float[] vectorValue) {
|
||||
if (docID == lastDocID) {
|
||||
throw new IllegalArgumentException(
|
||||
"VectorValuesField \""
|
||||
+ fieldInfo.name
|
||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||
}
|
||||
if (vectorValue.length != dim) {
|
||||
throw new IllegalArgumentException(
|
||||
"Attempt to index a vector of dimension "
|
||||
+ vectorValue.length
|
||||
+ " but \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has dimension "
|
||||
+ dim);
|
||||
}
|
||||
assert docID > lastDocID;
|
||||
docsWithField.add(docID);
|
||||
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
|
||||
lastDocID = docID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
if (vectors.size() == 0) return 0;
|
||||
return docsWithField.ramBytesUsed()
|
||||
+ vectors.size()
|
||||
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||
+ vectors.size() * vectors.get(0).length * Float.BYTES;
|
||||
}
|
||||
}
|
||||
|
||||
private static class BufferedVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||
|
||||
final DocsWithFieldSet docsWithField;
|
||||
|
||||
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
|
||||
final List<float[]> vectors;
|
||||
final int dimension;
|
||||
|
||||
final ByteBuffer buffer;
|
||||
final BytesRef binaryValue;
|
||||
final ByteBuffer raBuffer;
|
||||
final BytesRef raBinaryValue;
|
||||
|
||||
DocIdSetIterator docsWithFieldIter;
|
||||
int ord = -1;
|
||||
|
||||
BufferedVectorValues(DocsWithFieldSet docsWithField, List<float[]> vectors, int dimension) {
|
||||
this.docsWithField = docsWithField;
|
||||
this.vectors = vectors;
|
||||
this.dimension = dimension;
|
||||
buffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
binaryValue = new BytesRef(buffer.array());
|
||||
raBuffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
raBinaryValue = new BytesRef(raBuffer.array());
|
||||
docsWithFieldIter = docsWithField.iterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() {
|
||||
return new BufferedVectorValues(docsWithField, vectors, dimension);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return vectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue() {
|
||||
buffer.asFloatBuffer().put(vectorValue());
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) {
|
||||
raBuffer.asFloatBuffer().put(vectors.get(targetOrd));
|
||||
return raBinaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() {
|
||||
return vectors.get(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docsWithFieldIter.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
int docID = docsWithFieldIter.nextDoc();
|
||||
if (docID != NO_MORE_DOCS) {
|
||||
++ord;
|
||||
}
|
||||
return docID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return docsWithFieldIter.cost();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -30,8 +30,7 @@ import org.apache.lucene.analysis.Analyzer;
|
|||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.codecs.DocValuesConsumer;
|
||||
import org.apache.lucene.codecs.DocValuesFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.NormsConsumer;
|
||||
import org.apache.lucene.codecs.NormsFormat;
|
||||
import org.apache.lucene.codecs.NormsProducer;
|
||||
|
@ -68,6 +67,7 @@ final class IndexingChain implements Accountable {
|
|||
final ByteBlockPool docValuesBytePool;
|
||||
// Writes stored fields
|
||||
final StoredFieldsConsumer storedFieldsConsumer;
|
||||
final VectorValuesConsumer vectorValuesConsumer;
|
||||
final TermVectorsConsumer termVectorsWriter;
|
||||
|
||||
// NOTE: I tried using Hash Map<String,PerField>
|
||||
|
@ -104,6 +104,8 @@ final class IndexingChain implements Accountable {
|
|||
this.fieldInfos = fieldInfos;
|
||||
this.infoStream = indexWriterConfig.getInfoStream();
|
||||
this.abortingExceptionConsumer = abortingExceptionConsumer;
|
||||
this.vectorValuesConsumer =
|
||||
new VectorValuesConsumer(indexWriterConfig.getCodec(), directory, segmentInfo, infoStream);
|
||||
|
||||
if (segmentInfo.getIndexSort() == null) {
|
||||
storedFieldsConsumer =
|
||||
|
@ -262,7 +264,7 @@ final class IndexingChain implements Accountable {
|
|||
}
|
||||
|
||||
t0 = System.nanoTime();
|
||||
writeVectors(state, sortMap);
|
||||
vectorValuesConsumer.flush(state, sortMap);
|
||||
if (infoStream.isEnabled("IW")) {
|
||||
infoStream.message("IW", ((System.nanoTime() - t0) / 1000000) + " msec to write vectors");
|
||||
}
|
||||
|
@ -428,63 +430,6 @@ final class IndexingChain implements Accountable {
|
|||
}
|
||||
}
|
||||
|
||||
/** Writes all buffered vectors. */
|
||||
private void writeVectors(SegmentWriteState state, Sorter.DocMap sortMap) throws IOException {
|
||||
KnnVectorsWriter knnVectorsWriter = null;
|
||||
boolean success = false;
|
||||
try {
|
||||
for (int i = 0; i < fieldHash.length; i++) {
|
||||
PerField perField = fieldHash[i];
|
||||
while (perField != null) {
|
||||
if (perField.vectorValuesWriter != null) {
|
||||
if (perField.fieldInfo.getVectorDimension() == 0) {
|
||||
// BUG
|
||||
throw new AssertionError(
|
||||
"segment="
|
||||
+ state.segmentInfo
|
||||
+ ": field=\""
|
||||
+ perField.fieldInfo.name
|
||||
+ "\" has no vectors but wrote them");
|
||||
}
|
||||
if (knnVectorsWriter == null) {
|
||||
// lazy init
|
||||
KnnVectorsFormat fmt = state.segmentInfo.getCodec().knnVectorsFormat();
|
||||
if (fmt == null) {
|
||||
throw new IllegalStateException(
|
||||
"field=\""
|
||||
+ perField.fieldInfo.name
|
||||
+ "\" was indexed as vectors but codec does not support vectors");
|
||||
}
|
||||
knnVectorsWriter = fmt.fieldsWriter(state);
|
||||
}
|
||||
|
||||
perField.vectorValuesWriter.flush(sortMap, knnVectorsWriter);
|
||||
perField.vectorValuesWriter = null;
|
||||
} else if (perField.fieldInfo != null && perField.fieldInfo.getVectorDimension() != 0) {
|
||||
// BUG
|
||||
throw new AssertionError(
|
||||
"segment="
|
||||
+ state.segmentInfo
|
||||
+ ": field=\""
|
||||
+ perField.fieldInfo.name
|
||||
+ "\" has vectors but did not write them");
|
||||
}
|
||||
perField = perField.next;
|
||||
}
|
||||
}
|
||||
if (knnVectorsWriter != null) {
|
||||
knnVectorsWriter.finish();
|
||||
}
|
||||
success = true;
|
||||
} finally {
|
||||
if (success) {
|
||||
IOUtils.close(knnVectorsWriter);
|
||||
} else {
|
||||
IOUtils.closeWhileHandlingException(knnVectorsWriter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void writeNorms(SegmentWriteState state, Sorter.DocMap sortMap) throws IOException {
|
||||
boolean success = false;
|
||||
NormsConsumer normsConsumer = null;
|
||||
|
@ -522,6 +467,7 @@ final class IndexingChain implements Accountable {
|
|||
// finalizer will e.g. close any open files in the term vectors writer:
|
||||
try (Closeable finalizer = termsHash::abort) {
|
||||
storedFieldsConsumer.abort();
|
||||
vectorValuesConsumer.abort();
|
||||
} finally {
|
||||
Arrays.fill(fieldHash, null);
|
||||
}
|
||||
|
@ -714,7 +660,12 @@ final class IndexingChain implements Accountable {
|
|||
pf.pointValuesWriter = new PointValuesWriter(bytesUsed, fi);
|
||||
}
|
||||
if (fi.getVectorDimension() != 0) {
|
||||
pf.vectorValuesWriter = new VectorValuesWriter(fi, bytesUsed);
|
||||
try {
|
||||
pf.knnFieldVectorsWriter = vectorValuesConsumer.addField(fi);
|
||||
} catch (Throwable th) {
|
||||
onAbortingException(th);
|
||||
throw th;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -761,7 +712,7 @@ final class IndexingChain implements Accountable {
|
|||
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
|
||||
}
|
||||
if (fieldType.vectorDimension() != 0) {
|
||||
pf.vectorValuesWriter.addValue(docID, ((KnnVectorField) field).vectorValue());
|
||||
pf.knnFieldVectorsWriter.addValue(docID, ((KnnVectorField) field).vectorValue());
|
||||
}
|
||||
return indexedField;
|
||||
}
|
||||
|
@ -1006,12 +957,16 @@ final class IndexingChain implements Accountable {
|
|||
public long ramBytesUsed() {
|
||||
return bytesUsed.get()
|
||||
+ storedFieldsConsumer.accountable.ramBytesUsed()
|
||||
+ termVectorsWriter.accountable.ramBytesUsed();
|
||||
+ termVectorsWriter.accountable.ramBytesUsed()
|
||||
+ vectorValuesConsumer.getAccountable().ramBytesUsed();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<Accountable> getChildResources() {
|
||||
return List.of(storedFieldsConsumer.accountable, termVectorsWriter.accountable);
|
||||
return List.of(
|
||||
storedFieldsConsumer.accountable,
|
||||
termVectorsWriter.accountable,
|
||||
vectorValuesConsumer.getAccountable());
|
||||
}
|
||||
|
||||
/** NOTE: not static: accesses at least docState, termsHash. */
|
||||
|
@ -1032,8 +987,8 @@ final class IndexingChain implements Accountable {
|
|||
// Non-null if this field ever had points in this segment:
|
||||
PointValuesWriter pointValuesWriter;
|
||||
|
||||
// Non-null if this field ever had vector values in this segment:
|
||||
VectorValuesWriter vectorValuesWriter;
|
||||
// Non-null if this field had vectors in this segment
|
||||
KnnFieldVectorsWriter knnFieldVectorsWriter;
|
||||
|
||||
/** We use this to know when a PerField is seen for the first time in the current document. */
|
||||
long fieldGen = -1;
|
||||
|
|
|
@ -28,7 +28,7 @@ import org.apache.lucene.util.packed.PackedLongValues;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
final class Sorter {
|
||||
public final class Sorter {
|
||||
final Sort sort;
|
||||
|
||||
/** Creates a new Sorter to sort the index with {@code sort} */
|
||||
|
@ -44,20 +44,23 @@ final class Sorter {
|
|||
* A permutation of doc IDs. For every document ID between <code>0</code> and {@link
|
||||
* IndexReader#maxDoc()}, <code>oldToNew(newToOld(docID))</code> must return <code>docID</code>.
|
||||
*/
|
||||
abstract static class DocMap {
|
||||
public abstract static class DocMap {
|
||||
|
||||
/** Sole constructor. */
|
||||
protected DocMap() {}
|
||||
|
||||
/** Given a doc ID from the original index, return its ordinal in the sorted index. */
|
||||
abstract int oldToNew(int docID);
|
||||
public abstract int oldToNew(int docID);
|
||||
|
||||
/** Given the ordinal of a doc ID, return its doc ID in the original index. */
|
||||
abstract int newToOld(int docID);
|
||||
public abstract int newToOld(int docID);
|
||||
|
||||
/**
|
||||
* Return the number of documents in this map. This must be equal to the {@link
|
||||
* org.apache.lucene.index.LeafReader#maxDoc() number of documents} of the {@link
|
||||
* org.apache.lucene.index.LeafReader} which is sorted.
|
||||
*/
|
||||
abstract int size();
|
||||
public abstract int size();
|
||||
}
|
||||
|
||||
/** Check consistency of a {@link DocMap}, useful for assertions. */
|
||||
|
|
|
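Sorter and Sorter.DocMap become public here so that codec-level vectors writers, whose flush(maxDoc, sortMap) now receives a DocMap, can remap buffered doc IDs when a segment is flushed with an index sort. A tiny, hypothetical DocMap backed by an explicit permutation array, only to illustrate the contract that oldToNew(newToOld(docID)) == docID for every document:

// Hypothetical DocMap for illustration; not part of the patch.
import org.apache.lucene.index.Sorter;

final class ArrayDocMap extends Sorter.DocMap {
  private final int[] oldToNewIds; // index: old doc ID, value: new doc ID
  private final int[] newToOldIds; // inverse permutation

  ArrayDocMap(int[] oldToNewIds) {
    this.oldToNewIds = oldToNewIds;
    this.newToOldIds = new int[oldToNewIds.length];
    for (int oldDoc = 0; oldDoc < oldToNewIds.length; oldDoc++) {
      newToOldIds[oldToNewIds[oldDoc]] = oldDoc;
    }
  }

  @Override
  public int oldToNew(int docID) {
    return oldToNewIds[docID];
  }

  @Override
  public int newToOld(int docID) {
    return newToOldIds[docID];
  }

  @Override
  public int size() {
    return oldToNewIds.length;
  }
}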
@ -380,7 +380,7 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||
|
||||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
return new VectorValuesWriter.SortingVectorValues(delegate.getVectorValues(field), docMap);
|
||||
return new VectorValues.SortingVectorValues(delegate.getVectorValues(field), docMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -111,4 +111,117 @@ public abstract class VectorValues extends DocIdSetIterator {
|
|||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/** Sorting VectorValues that iterate over documents in the order of the provided sortMap */
|
||||
public static class SortingVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValuesProducer {
|
||||
|
||||
private final VectorValues delegate;
|
||||
private final RandomAccessVectorValues randomAccess;
|
||||
private final int[] docIdOffsets;
|
||||
private final int[] ordMap;
|
||||
private int docId = -1;
|
||||
|
||||
/** Sorting VectorValues */
|
||||
public SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
this.delegate = delegate;
|
||||
randomAccess = ((RandomAccessVectorValuesProducer) delegate).randomAccess();
|
||||
docIdOffsets = new int[sortMap.size()];
|
||||
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
int docID;
|
||||
while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) {
|
||||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
|
||||
// set up ordMap to map from new dense ordinal to old dense ordinal
|
||||
ordMap = new int[offset - 1];
|
||||
int ord = 0;
|
||||
for (int docIdOffset : docIdOffsets) {
|
||||
if (docIdOffset != 0) {
|
||||
ordMap[ord++] = docIdOffset - 1;
|
||||
}
|
||||
}
|
||||
assert ord == ordMap.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
while (docId < docIdOffsets.length - 1) {
|
||||
++docId;
|
||||
if (docIdOffsets[docId] != 0) {
|
||||
return docId;
|
||||
}
|
||||
}
|
||||
docId = NO_MORE_DOCS;
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue() throws IOException {
|
||||
return randomAccess.binaryValue(docIdOffsets[docId] - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return randomAccess.vectorValue(docIdOffsets[docId] - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return delegate.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return delegate.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
|
||||
// Must make a new delegate randomAccess so that we have our own distinct float[]
|
||||
final RandomAccessVectorValues delegateRA =
|
||||
((RandomAccessVectorValuesProducer) SortingVectorValues.this.delegate).randomAccess();
|
||||
|
||||
return new RandomAccessVectorValues() {
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return delegateRA.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return delegateRA.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) throws IOException {
|
||||
return delegateRA.vectorValue(ordMap[targetOrd]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
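To make SortingVectorValues' two arrays concrete: suppose a segment has four documents, only old docs 1 and 3 carry a vector, and the index sort reverses doc order (old to new: 0->3, 1->2, 2->1, 3->0). Iterating the delegate visits old docs 1 then 3, so docIdOffsets, indexed by new doc ID, ends up as [2, 0, 1, 0], where 0 still means "no vector for this document". Scanning that array in new-doc order gives ordMap = [1, 0]: the vector at new dense ordinal 0 is the delegate's ordinal 1 (old doc 3), and new ordinal 1 is the delegate's ordinal 0 (old doc 1). This is only a worked illustration of the constructor above, not additional behavior.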
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
|
||||
/**
|
||||
* Streams vector values for indexing to the given codec's vectors writer. The codec's vectors
|
||||
* writer is responsible for buffering and processing vectors.
|
||||
*/
|
||||
class VectorValuesConsumer {
|
||||
private final Codec codec;
|
||||
private final Directory directory;
|
||||
private final SegmentInfo segmentInfo;
|
||||
private final InfoStream infoStream;
|
||||
|
||||
private Accountable accountable = Accountable.NULL_ACCOUNTABLE;
|
||||
private KnnVectorsWriter writer;
|
||||
|
||||
VectorValuesConsumer(
|
||||
Codec codec, Directory directory, SegmentInfo segmentInfo, InfoStream infoStream) {
|
||||
this.codec = codec;
|
||||
this.directory = directory;
|
||||
this.segmentInfo = segmentInfo;
|
||||
this.infoStream = infoStream;
|
||||
}
|
||||
|
||||
private void initKnnVectorsWriter(String fieldName) throws IOException {
|
||||
if (writer == null) {
|
||||
KnnVectorsFormat fmt = codec.knnVectorsFormat();
|
||||
if (fmt == null) {
|
||||
throw new IllegalStateException(
|
||||
"field=\""
|
||||
+ fieldName
|
||||
+ "\" was indexed as vectors but codec does not support vectors");
|
||||
}
|
||||
SegmentWriteState initialWriteState =
|
||||
new SegmentWriteState(infoStream, directory, segmentInfo, null, null, IOContext.DEFAULT);
|
||||
writer = fmt.fieldsWriter(initialWriteState);
|
||||
accountable = writer;
|
||||
}
|
||||
}
|
||||
|
||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
||||
initKnnVectorsWriter(fieldInfo.name);
|
||||
return writer.addField(fieldInfo);
|
||||
}
|
||||
|
||||
void flush(SegmentWriteState state, Sorter.DocMap sortMap) throws IOException {
|
||||
if (writer == null) return;
|
||||
try {
|
||||
writer.flush(state.segmentInfo.maxDoc(), sortMap);
|
||||
writer.finish();
|
||||
} finally {
|
||||
IOUtils.close(writer);
|
||||
}
|
||||
}
|
||||
|
||||
void abort() {
|
||||
IOUtils.closeWhileHandlingException(writer);
|
||||
}
|
||||
|
||||
public Accountable getAccountable() {
|
||||
return accountable;
|
||||
}
|
||||
}
|
|
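VectorValuesConsumer creates the codec's KnnVectorsWriter lazily, on the first vector field it sees, so segments without vectors never open vector files or pay any of this cost. A simplified sketch of that lazy-initialization idea (the real class above also tracks an Accountable and builds the SegmentWriteState itself; this version takes one as an assumed parameter and skips the missing-format check):

// Simplified sketch of the lazy-init pattern in VectorValuesConsumer; illustrative only.
import java.io.IOException;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.SegmentWriteState;

final class LazyVectorsWriterSketch {
  private final Codec codec;
  private final SegmentWriteState writeState; // assumed to be prepared by the caller
  private KnnVectorsWriter writer;            // created only if a vector field shows up

  LazyVectorsWriterSketch(Codec codec, SegmentWriteState writeState) {
    this.codec = codec;
    this.writeState = writeState;
  }

  KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
    if (writer == null) {
      writer = codec.knnVectorsFormat().fieldsWriter(writeState);
    }
    return writer.addField(fieldInfo);
  }
}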
@ -1,348 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.Counter;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
||||
/**
|
||||
* Buffers up pending vector value(s) per doc, then flushes when segment flushes.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
class VectorValuesWriter {
|
||||
|
||||
private final FieldInfo fieldInfo;
|
||||
private final Counter iwBytesUsed;
|
||||
private final List<float[]> vectors = new ArrayList<>();
|
||||
private final DocsWithFieldSet docsWithField;
|
||||
|
||||
private int lastDocID = -1;
|
||||
|
||||
private long bytesUsed;
|
||||
|
||||
VectorValuesWriter(FieldInfo fieldInfo, Counter iwBytesUsed) {
|
||||
this.fieldInfo = fieldInfo;
|
||||
this.iwBytesUsed = iwBytesUsed;
|
||||
this.docsWithField = new DocsWithFieldSet();
|
||||
this.bytesUsed = docsWithField.ramBytesUsed();
|
||||
if (iwBytesUsed != null) {
|
||||
iwBytesUsed.addAndGet(bytesUsed);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a value for the given document. Only a single value may be added.
|
||||
*
|
||||
* @param docID the value is added to this document
|
||||
* @param vectorValue the value to add
|
||||
* @throws IllegalArgumentException if a value has already been added to the given document
|
||||
*/
|
||||
public void addValue(int docID, float[] vectorValue) {
|
||||
if (docID == lastDocID) {
|
||||
throw new IllegalArgumentException(
|
||||
"VectorValuesField \""
|
||||
+ fieldInfo.name
|
||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||
}
|
||||
if (vectorValue.length != fieldInfo.getVectorDimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Attempt to index a vector of dimension "
|
||||
+ vectorValue.length
|
||||
+ " but \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has dimension "
|
||||
+ fieldInfo.getVectorDimension());
|
||||
}
|
||||
assert docID > lastDocID;
|
||||
docsWithField.add(docID);
|
||||
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
|
||||
updateBytesUsed();
|
||||
lastDocID = docID;
|
||||
}
|
||||
|
||||
private void updateBytesUsed() {
|
||||
final long newBytesUsed =
|
||||
docsWithField.ramBytesUsed()
|
||||
+ vectors.size()
|
||||
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF
|
||||
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||
+ vectors.size() * vectors.get(0).length * Float.BYTES;
|
||||
if (iwBytesUsed != null) {
|
||||
iwBytesUsed.addAndGet(newBytesUsed - bytesUsed);
|
||||
}
|
||||
bytesUsed = newBytesUsed;
|
||||
}
|
||||
|
||||
/**
|
||||
* Flush this field's values to storage, sorting the values in accordance with sortMap
|
||||
*
|
||||
* @param sortMap specifies the order of documents being flushed, or null if they are to be
|
||||
* flushed in docid order
|
||||
* @param knnVectorsWriter the Codec's vector writer that handles the actual encoding and I/O
|
||||
* @throws IOException if there is an error writing the field and its values
|
||||
*/
|
||||
public void flush(Sorter.DocMap sortMap, KnnVectorsWriter knnVectorsWriter) throws IOException {
|
||||
KnnVectorsReader knnVectorsReader =
|
||||
new KnnVectorsReader() {
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
VectorValues vectorValues =
|
||||
new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
|
||||
return sortMap != null ? new SortingVectorValues(vectorValues, sortMap) : vectorValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
|
||||
knnVectorsWriter.writeField(fieldInfo, knnVectorsReader);
|
||||
}
|
||||
|
||||
static class SortingVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValuesProducer {
|
||||
|
||||
private final VectorValues delegate;
|
||||
private final RandomAccessVectorValues randomAccess;
|
||||
private final int[] docIdOffsets;
|
||||
private final int[] ordMap;
|
||||
private int docId = -1;
|
||||
|
||||
SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
this.delegate = delegate;
|
||||
randomAccess = ((RandomAccessVectorValuesProducer) delegate).randomAccess();
|
||||
docIdOffsets = new int[sortMap.size()];
|
||||
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
int docID;
|
||||
while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) {
|
||||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
|
||||
// set up ordMap to map from new dense ordinal to old dense ordinal
|
||||
ordMap = new int[offset - 1];
|
||||
int ord = 0;
|
||||
for (int docIdOffset : docIdOffsets) {
|
||||
if (docIdOffset != 0) {
|
||||
ordMap[ord++] = docIdOffset - 1;
|
||||
}
|
||||
}
|
||||
assert ord == ordMap.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
while (docId < docIdOffsets.length - 1) {
|
||||
++docId;
|
||||
if (docIdOffsets[docId] != 0) {
|
||||
return docId;
|
||||
}
|
||||
}
|
||||
docId = NO_MORE_DOCS;
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue() throws IOException {
|
||||
return randomAccess.binaryValue(docIdOffsets[docId] - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return randomAccess.vectorValue(docIdOffsets[docId] - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return delegate.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return delegate.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
|
||||
// Must make a new delegate randomAccess so that we have our own distinct float[]
|
||||
final RandomAccessVectorValues delegateRA =
|
||||
((RandomAccessVectorValuesProducer) SortingVectorValues.this.delegate).randomAccess();
|
||||
|
||||
return new RandomAccessVectorValues() {
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return delegateRA.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return delegateRA.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) throws IOException {
|
||||
return delegateRA.vectorValue(ordMap[targetOrd]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class BufferedVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||
|
||||
final DocsWithFieldSet docsWithField;
|
||||
|
||||
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
|
||||
final List<float[]> vectors;
|
||||
final int dimension;
|
||||
|
||||
final ByteBuffer buffer;
|
||||
final BytesRef binaryValue;
|
||||
final ByteBuffer raBuffer;
|
||||
final BytesRef raBinaryValue;
|
||||
|
||||
DocIdSetIterator docsWithFieldIter;
|
||||
int ord = -1;
|
||||
|
||||
BufferedVectorValues(DocsWithFieldSet docsWithField, List<float[]> vectors, int dimension) {
|
||||
this.docsWithField = docsWithField;
|
||||
this.vectors = vectors;
|
||||
this.dimension = dimension;
|
||||
buffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
binaryValue = new BytesRef(buffer.array());
|
||||
raBuffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
raBinaryValue = new BytesRef(raBuffer.array());
|
||||
docsWithFieldIter = docsWithField.iterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() {
|
||||
return new BufferedVectorValues(docsWithField, vectors, dimension);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return vectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue() {
|
||||
buffer.asFloatBuffer().put(vectorValue());
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) {
|
||||
raBuffer.asFloatBuffer().put(vectors.get(targetOrd));
|
||||
return raBinaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() {
|
||||
return vectors.get(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docsWithFieldIter.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
int docID = docsWithFieldIter.nextDoc();
|
||||
if (docID != NO_MORE_DOCS) {
|
||||
++ord;
|
||||
}
|
||||
return docID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return docsWithFieldIter.cost();
|
||||
}
|
||||
}
|
||||
}
|
|
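The deleted file above is the buffering VectorValuesWriter whose responsibilities move into the codec's vectors writer; its updateBytesUsed() charged the IndexWriter for the list of copied float[] vectors. A self-contained sketch of that arithmetic (vector count, dimension, and class name are made up for illustration, and the docsWithField overhead is ignored):

// Rough sketch of the buffered-vector RAM estimate from updateBytesUsed(); illustrative only.
import org.apache.lucene.util.RamUsageEstimator;

public class VectorRamEstimateSketch {
  public static void main(String[] args) {
    int numVectors = 100; // assumed number of buffered vectors
    int dimension = 128;  // assumed vector dimension
    long perVectorOverhead =
        RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
    long estimate =
        numVectors * perVectorOverhead                     // list refs + float[] headers
            + (long) numVectors * dimension * Float.BYTES; // the float data itself
    System.out.println("~" + estimate + " bytes, excluding docsWithField");
  }
}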
@ -137,8 +137,12 @@ public final class HnswGraphBuilder {
|
|||
this.infoStream = infoStream;
|
||||
}
|
||||
|
||||
public OnHeapHnswGraph getGraph() {
|
||||
return hnsw;
|
||||
}
|
||||
|
||||
/** Inserts a doc with vector value to the graph */
|
||||
void addGraphNode(int node, float[] value) throws IOException {
|
||||
public void addGraphNode(int node, float[] value) throws IOException {
|
||||
NeighborQueue candidates;
|
||||
final int nodeLevel = getRandomGraphLevel(ml, random);
|
||||
int curMaxLevel = hnsw.numLevels() - 1;
|
||||
|
|
|
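HnswGraphBuilder's addGraphNode and getGraph become public so that a codec's vectors writer can grow the graph one vector at a time as documents are indexed, instead of bulk-building it at flush. A minimal sketch of that incremental loop; constructing the builder (vector source, max connections, beam width, seed) is assumed to have happened elsewhere:

// Sketch of incremental HNSW construction using the now-public methods; illustrative only.
import java.io.IOException;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;

final class IncrementalHnswSketch {
  static OnHeapHnswGraph addAll(HnswGraphBuilder builder, float[][] vectors) throws IOException {
    for (int node = 0; node < vectors.length; node++) {
      builder.addGraphNode(node, vectors[node]); // one insertion per indexed vector
    }
    return builder.getGraph();
  }
}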
@ -24,6 +24,7 @@ import org.apache.lucene.index.RandomAccessVectorValues;
|
|||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.SparseFixedBitSet;
|
||||
|
||||
/**
|
||||
|
@ -38,7 +39,7 @@ public final class HnswGraphSearcher {
|
|||
*/
|
||||
private final NeighborQueue candidates;
|
||||
|
||||
private final BitSet visited;
|
||||
private BitSet visited;
|
||||
|
||||
/**
|
||||
* Creates a new graph searcher.
|
||||
|
@ -140,7 +141,7 @@ public final class HnswGraphSearcher {
|
|||
throws IOException {
|
||||
int size = graph.size();
|
||||
NeighborQueue results = new NeighborQueue(topK, false);
|
||||
clearScratchState();
|
||||
prepareScratchState(vectors.size());
|
||||
|
||||
int numVisited = 0;
|
||||
for (int ep : eps) {
|
||||
|
@ -203,8 +204,11 @@ public final class HnswGraphSearcher {
|
|||
return results;
|
||||
}
|
||||
|
||||
private void clearScratchState() {
|
||||
private void prepareScratchState(int capacity) {
|
||||
candidates.clear();
|
||||
if (visited.length() < capacity) {
|
||||
visited = FixedBitSet.ensureCapacity((FixedBitSet) visited, capacity);
|
||||
}
|
||||
visited.clear(0, visited.length());
|
||||
}
|
||||
}
|
||||
|
|
|
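Because the graph is now built incrementally, the vector count can grow between searches, so HnswGraphSearcher can no longer size its visited bitset once up front: prepareScratchState grows it on demand and clears it before each search. A minimal sketch of that grow-and-clear reuse pattern, with the initial capacity chosen arbitrarily for illustration:

// Sketch of the scratch-bitset reuse pattern from prepareScratchState(int); illustrative only.
import org.apache.lucene.util.FixedBitSet;

final class ScratchBitSetSketch {
  private FixedBitSet visited = new FixedBitSet(64); // initial capacity is an assumption

  FixedBitSet prepare(int capacity) {
    if (visited.length() < capacity) {
      visited = FixedBitSet.ensureCapacity(visited, capacity);
    }
    visited.clear(0, visited.length());
    return visited;
  }
}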
@ -22,13 +22,15 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
||||
/**
|
||||
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
|
||||
* construct the HNSW graph before it's written to the index.
|
||||
*/
|
||||
public final class OnHeapHnswGraph extends HnswGraph {
|
||||
public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
||||
|
||||
private int numLevels; // the current number of levels in the graph
|
||||
private int entryNode; // the current graph entry node on the top level
|
||||
|
@ -167,4 +169,28 @@ public final class OnHeapHnswGraph extends HnswGraph {
|
|||
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long neighborArrayBytes0 =
|
||||
nsize0 * (Integer.BYTES + Float.BYTES)
|
||||
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF;
|
||||
long neighborArrayBytes =
|
||||
nsize * (Integer.BYTES + Float.BYTES)
|
||||
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF;
|
||||
|
||||
long total = 0;
|
||||
for (int l = 0; l < numLevels; l++) {
|
||||
int numNodesOnLevel = graph.get(l).size();
|
||||
if (l == 0) {
|
||||
total += numNodesOnLevel * neighborArrayBytes0; // for graph;
|
||||
} else {
|
||||
total += numNodesOnLevel * Integer.BYTES; // for nodesByLevel
|
||||
total += numNodesOnLevel * neighborArrayBytes; // for graph;
|
||||
}
|
||||
}
|
||||
return total;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.HashSet;
|
|||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
|
@ -41,6 +42,7 @@ import org.apache.lucene.index.MergeState;
|
|||
import org.apache.lucene.index.NoMergePolicy;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||
|
@ -174,19 +176,22 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
|||
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||
KnnVectorsWriter writer = delegate.fieldsWriter(state);
|
||||
return new KnnVectorsWriter() {
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
||||
fieldsWritten.add(fieldInfo.name);
|
||||
writer.writeField(fieldInfo, knnVectorsReader);
|
||||
return writer.addField(fieldInfo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void merge(MergeState mergeState) throws IOException {
|
||||
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
|
||||
fieldsWritten.add(fieldInfo.name);
|
||||
}
|
||||
writer.merge(mergeState);
|
||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||
writer.flush(maxDoc, sortMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
fieldsWritten.add(fieldInfo.name);
|
||||
writer.mergeOneField(fieldInfo, mergeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -198,6 +203,11 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
|||
public void close() throws IOException {
|
||||
writer.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return writer.ramBytesUsed();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -20,8 +20,10 @@ package org.apache.lucene.util.hnsw;
|
|||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
|
@ -31,6 +33,7 @@ import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsReader;
|
|||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.document.NumericDocValuesField;
|
||||
import org.apache.lucene.document.StoredField;
|
||||
import org.apache.lucene.index.CodecReader;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
|
@ -42,6 +45,12 @@ import org.apache.lucene.index.RandomAccessVectorValues;
|
|||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnVectorQuery;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
|
@ -122,6 +131,94 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
// test that a sorted index returns the same search results as an unsorted one
|
||||
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
|
||||
int dim = random().nextInt(10) + 3;
|
||||
int nDoc = random().nextInt(500) + 1;
|
||||
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
|
||||
|
||||
int M = random().nextInt(10) + 5;
|
||||
int beamWidth = random().nextInt(10) + 5;
|
||||
VectorSimilarityFunction similarityFunction =
|
||||
VectorSimilarityFunction.values()[
|
||||
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
|
||||
long seed = random().nextLong();
|
||||
HnswGraphBuilder.randSeed = seed;
|
||||
IndexWriterConfig iwc =
|
||||
new IndexWriterConfig()
|
||||
.setCodec(
|
||||
new Lucene94Codec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene94HnswVectorsFormat(M, beamWidth);
|
||||
}
|
||||
});
|
||||
IndexWriterConfig iwc2 =
|
||||
new IndexWriterConfig()
|
||||
.setCodec(
|
||||
new Lucene94Codec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene94HnswVectorsFormat(M, beamWidth);
|
||||
}
|
||||
})
|
||||
.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.LONG)));
|
||||
;
|
||||
|
||||
try (Directory dir = newDirectory();
|
||||
Directory dir2 = newDirectory()) {
|
||||
int indexedDoc = 0;
|
||||
try (IndexWriter iw = new IndexWriter(dir, iwc);
|
||||
IndexWriter iw2 = new IndexWriter(dir2, iwc2)) {
|
||||
while (vectors.nextDoc() != NO_MORE_DOCS) {
|
||||
while (indexedDoc < vectors.docID()) {
|
||||
// increment docId in the index by adding empty documents
|
||||
iw.addDocument(new Document());
|
||||
indexedDoc++;
|
||||
}
|
||||
Document doc = new Document();
|
||||
doc.add(new KnnVectorField("vector", vectors.vectorValue(), similarityFunction));
|
||||
doc.add(new StoredField("id", vectors.docID()));
|
||||
doc.add(new NumericDocValuesField("sortkey", random().nextLong()));
|
||||
iw.addDocument(doc);
|
||||
iw2.addDocument(doc);
|
||||
indexedDoc++;
|
||||
}
|
||||
}
|
||||
try (IndexReader reader = DirectoryReader.open(dir);
|
||||
IndexReader reader2 = DirectoryReader.open(dir2)) {
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
IndexSearcher searcher2 = new IndexSearcher(reader2);
|
||||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
// ask to explore a lot of candidates to ensure the same returned hits,
|
||||
// as graphs of 2 indices are organized differently
|
||||
KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(random(), dim), 50);
|
||||
List<String> ids1 = new ArrayList<>();
|
||||
List<Integer> docs1 = new ArrayList<>();
|
||||
List<String> ids2 = new ArrayList<>();
|
||||
List<Integer> docs2 = new ArrayList<>();
|
||||
|
||||
TopDocs topDocs = searcher.search(query, 5);
|
||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||
Document doc = reader.document(scoreDoc.doc, Set.of("id"));
|
||||
ids1.add(doc.get("id"));
|
||||
docs1.add(scoreDoc.doc);
|
||||
}
|
||||
TopDocs topDocs2 = searcher2.search(query, 5);
|
||||
for (ScoreDoc scoreDoc : topDocs2.scoreDocs) {
|
||||
Document doc = reader2.document(scoreDoc.doc, Set.of("id"));
|
||||
ids2.add(doc.get("id"));
|
||||
docs2.add(scoreDoc.doc);
|
||||
}
|
||||
assertEquals(ids1, ids2);
|
||||
// doc IDs are not equal, as in the second sorted index docs are organized differently
|
||||
assertNotEquals(docs1, docs2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
|
||||
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
|
||||
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.lucene.tests.codecs.asserting;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
|
@ -26,6 +27,7 @@ import org.apache.lucene.index.FieldInfos;
|
|||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
|
@ -59,20 +61,20 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
assert fieldInfo != null;
|
||||
assert knnVectorsReader != null;
|
||||
// assert that knnVectorsReader#getVectorValues returns different instances upon repeated
|
||||
// calls
|
||||
assert knnVectorsReader.getVectorValues(fieldInfo.name)
|
||||
!= knnVectorsReader.getVectorValues(fieldInfo.name);
|
||||
delegate.writeField(fieldInfo, knnVectorsReader);
|
||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
||||
return delegate.addField(fieldInfo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void merge(MergeState mergeState) throws IOException {
|
||||
delegate.merge(mergeState);
|
||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||
delegate.flush(maxDoc, sortMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
assert fieldInfo != null;
|
||||
assert mergeState != null;
|
||||
delegate.mergeOneField(fieldInfo, mergeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -84,6 +86,11 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||
public void close() throws IOException {
|
||||
delegate.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return delegate.ramBytesUsed();
|
||||
}
|
||||
}
|
||||
|
||||
static class AssertingKnnVectorsReader extends KnnVectorsReader {
|
||||
|
|