mirror of https://github.com/apache/lucene.git
Extract the hnsw graph merging from being part of the vector writer (#12657)
While working on the quantization codec & thinking about how merging will evolve, it became clearer that having merging attached directly to the vector writer is weird. I extracted it out to its own class and removed the "initializedNodes" logic from the base class builder. Also, there was one other refactoring around grabbing sorted nodes from the neighbor iterator: I just moved that static method so it's not attached to the writer (as all bwc writers need it and all future HNSW writers will as well).
This commit is contained in:
parent
218eddec70
commit
ea272d0eda
|
@ -226,7 +226,8 @@ Build
|
|||
|
||||
Other
|
||||
---------------------
|
||||
(No changes)
|
||||
|
||||
* GITHUB#12657: Internal refactor of HNSW graph merging (Ben Trent).
|
||||
|
||||
======================== Lucene 9.8.0 =======================
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ import java.nio.ByteOrder;
|
|||
import java.util.Arrays;
|
||||
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
|
@ -37,6 +36,7 @@ import org.apache.lucene.search.DocIdSetIterator;
|
|||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/**
|
||||
|
@ -227,7 +227,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
} else {
|
||||
meta.writeInt(graph.numLevels());
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
meta.writeInt(sortedNodes.length); // number of nodes on a level
|
||||
if (level > 0) {
|
||||
for (int node : sortedNodes) {
|
||||
|
@ -256,7 +256,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
// write vectors' neighbours on each level into the vectorIndex file
|
||||
int countOnLevel0 = graph.size();
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
for (int node : sortedNodes) {
|
||||
Lucene91NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
|
|
|
@ -27,7 +27,6 @@ import java.util.Arrays;
|
|||
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
|
@ -39,6 +38,7 @@ import org.apache.lucene.search.DocIdSetIterator;
|
|||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||
|
@ -261,7 +261,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
} else {
|
||||
meta.writeInt(graph.numLevels());
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
meta.writeInt(sortedNodes.length); // number of nodes on a level
|
||||
if (level > 0) {
|
||||
for (int node : sortedNodes) {
|
||||
|
@ -289,7 +289,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
int countOnLevel0 = graph.size();
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int maxConnOnLevel = level == 0 ? (M * 2) : M;
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
for (int node : sortedNodes) {
|
||||
NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
|
|
|
@ -30,7 +30,6 @@ import org.apache.lucene.codecs.CodecUtil;
|
|||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
|
@ -477,7 +476,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
int countOnLevel0 = graph.size();
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int maxConnOnLevel = level == 0 ? (M * 2) : M;
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
for (int node : sortedNodes) {
|
||||
NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
|
@ -565,7 +564,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
} else {
|
||||
meta.writeInt(graph.numLevels());
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
meta.writeInt(sortedNodes.length); // number of nodes on a level
|
||||
if (level > 0) {
|
||||
for (int node : sortedNodes) {
|
||||
|
|
|
@ -25,16 +25,10 @@ import java.nio.ByteBuffer;
|
|||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.HnswGraphProvider;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.index.*;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
|
@ -423,51 +417,49 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
OnHeapHnswGraph graph = null;
|
||||
int[][] vectorIndexNodeOffsets = null;
|
||||
if (docsWithField.cardinality() != 0) {
|
||||
// build graph
|
||||
int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo);
|
||||
graph =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> {
|
||||
OffHeapByteVectorValues.DenseOffHeapVectorValues vectorValues =
|
||||
final RandomVectorScorerSupplier scorerSupplier;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE:
|
||||
scorerSupplier =
|
||||
RandomVectorScorerSupplier.createBytes(
|
||||
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
byteSize);
|
||||
RandomVectorScorerSupplier scorerSupplier =
|
||||
RandomVectorScorerSupplier.createBytes(
|
||||
vectorValues, fieldInfo.getVectorSimilarityFunction());
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
createHnswGraphBuilder(
|
||||
mergeState,
|
||||
fieldInfo,
|
||||
scorerSupplier,
|
||||
initializerIndex,
|
||||
vectorValues.size());
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
yield hnswGraphBuilder.build(vectorValues.size());
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
OffHeapFloatVectorValues.DenseOffHeapVectorValues vectorValues =
|
||||
byteSize),
|
||||
fieldInfo.getVectorSimilarityFunction());
|
||||
break;
|
||||
case FLOAT32:
|
||||
scorerSupplier =
|
||||
RandomVectorScorerSupplier.createFloats(
|
||||
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
|
||||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
byteSize);
|
||||
RandomVectorScorerSupplier scorerSupplier =
|
||||
RandomVectorScorerSupplier.createFloats(
|
||||
vectorValues, fieldInfo.getVectorSimilarityFunction());
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
createHnswGraphBuilder(
|
||||
mergeState,
|
||||
fieldInfo,
|
||||
scorerSupplier,
|
||||
initializerIndex,
|
||||
vectorValues.size());
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
yield hnswGraphBuilder.build(vectorValues.size());
|
||||
}
|
||||
};
|
||||
byteSize),
|
||||
fieldInfo.getVectorSimilarityFunction());
|
||||
break;
|
||||
default:
|
||||
throw new IllegalArgumentException(
|
||||
"Unsupported vector encoding: " + fieldInfo.getVectorEncoding());
|
||||
}
|
||||
// build graph
|
||||
IncrementalHnswGraphMerger merger =
|
||||
new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
|
||||
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
||||
merger.addReader(
|
||||
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
|
||||
}
|
||||
DocIdSetIterator mergedVectorIterator = null;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> mergedVectorIterator =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
|
||||
case FLOAT32 -> mergedVectorIterator =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||
}
|
||||
HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
graph = hnswGraphBuilder.build(docsWithField.cardinality());
|
||||
vectorIndexNodeOffsets = writeGraph(graph);
|
||||
}
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
|
@ -494,185 +486,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
private HnswGraphBuilder createHnswGraphBuilder(
|
||||
MergeState mergeState,
|
||||
FieldInfo fieldInfo,
|
||||
RandomVectorScorerSupplier scorerSupplier,
|
||||
int initializerIndex,
|
||||
int graphSize)
|
||||
throws IOException {
|
||||
if (initializerIndex == -1) {
|
||||
return HnswGraphBuilder.create(
|
||||
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, graphSize);
|
||||
}
|
||||
|
||||
HnswGraph initializerGraph =
|
||||
getHnswGraphFromReader(fieldInfo.name, mergeState.knnVectorsReaders[initializerIndex]);
|
||||
Map<Integer, Integer> ordinalMapper =
|
||||
getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
|
||||
return HnswGraphBuilder.create(
|
||||
scorerSupplier,
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed,
|
||||
initializerGraph,
|
||||
ordinalMapper,
|
||||
graphSize);
|
||||
}
|
||||
|
||||
private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)
|
||||
throws IOException {
|
||||
// Find the KnnVectorReader with the most docs that meets the following criteria:
|
||||
// 1. Does not contain any deleted docs
|
||||
// 2. Is a HnswGraphProvider/PerFieldKnnVectorReader
|
||||
// If no readers exist that meet this criteria, return -1. If they do, return their index in
|
||||
// merge state
|
||||
int maxCandidateVectorCount = 0;
|
||||
int initializerIndex = -1;
|
||||
|
||||
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
||||
KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i];
|
||||
if (mergeState.knnVectorsReaders[i]
|
||||
instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
|
||||
currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
|
||||
}
|
||||
|
||||
if (!allMatch(mergeState.liveDocs[i])
|
||||
|| !(currKnnVectorsReader instanceof HnswGraphProvider)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int candidateVectorCount = 0;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> {
|
||||
ByteVectorValues byteVectorValues =
|
||||
currKnnVectorsReader.getByteVectorValues(fieldInfo.name);
|
||||
if (byteVectorValues == null) {
|
||||
continue;
|
||||
}
|
||||
candidateVectorCount = byteVectorValues.size();
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
FloatVectorValues vectorValues =
|
||||
currKnnVectorsReader.getFloatVectorValues(fieldInfo.name);
|
||||
if (vectorValues == null) {
|
||||
continue;
|
||||
}
|
||||
candidateVectorCount = vectorValues.size();
|
||||
}
|
||||
}
|
||||
|
||||
if (candidateVectorCount > maxCandidateVectorCount) {
|
||||
maxCandidateVectorCount = candidateVectorCount;
|
||||
initializerIndex = i;
|
||||
}
|
||||
}
|
||||
return initializerIndex;
|
||||
}
|
||||
|
||||
private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader
|
||||
&& perFieldReader.getFieldReader(fieldName) instanceof HnswGraphProvider fieldReader) {
|
||||
return fieldReader.getGraph(fieldName);
|
||||
}
|
||||
|
||||
if (knnVectorsReader instanceof HnswGraphProvider provider) {
|
||||
return provider.getGraph(fieldName);
|
||||
}
|
||||
|
||||
// We should not reach here because knnVectorsReader's type is checked in
|
||||
// selectGraphForInitialization
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid KnnVectorsReader type for field: "
|
||||
+ fieldName
|
||||
+ ". Must be Lucene95HnswVectorsReader or newer");
|
||||
}
|
||||
|
||||
private Map<Integer, Integer> getOldToNewOrdinalMap(
|
||||
MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws IOException {
|
||||
|
||||
DocIdSetIterator initializerIterator = null;
|
||||
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> initializerIterator =
|
||||
mergeState.knnVectorsReaders[initializerIndex].getByteVectorValues(fieldInfo.name);
|
||||
case FLOAT32 -> initializerIterator =
|
||||
mergeState.knnVectorsReaders[initializerIndex].getFloatVectorValues(fieldInfo.name);
|
||||
}
|
||||
|
||||
MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex];
|
||||
|
||||
Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>();
|
||||
int oldOrd = 0;
|
||||
int maxNewDocID = -1;
|
||||
for (int oldId = initializerIterator.nextDoc();
|
||||
oldId != NO_MORE_DOCS;
|
||||
oldId = initializerIterator.nextDoc()) {
|
||||
if (isCurrentVectorNull(initializerIterator)) {
|
||||
continue;
|
||||
}
|
||||
int newId = initializerDocMap.get(oldId);
|
||||
maxNewDocID = Math.max(newId, maxNewDocID);
|
||||
newIdToOldOrdinal.put(newId, oldOrd);
|
||||
oldOrd++;
|
||||
}
|
||||
|
||||
if (maxNewDocID == -1) {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>();
|
||||
|
||||
DocIdSetIterator vectorIterator = null;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> vectorIterator = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
|
||||
case FLOAT32 -> vectorIterator =
|
||||
MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||
}
|
||||
|
||||
int newOrd = 0;
|
||||
for (int newDocId = vectorIterator.nextDoc();
|
||||
newDocId <= maxNewDocID;
|
||||
newDocId = vectorIterator.nextDoc()) {
|
||||
if (isCurrentVectorNull(vectorIterator)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (newIdToOldOrdinal.containsKey(newDocId)) {
|
||||
oldToNewOrdinalMap.put(newIdToOldOrdinal.get(newDocId), newOrd);
|
||||
}
|
||||
newOrd++;
|
||||
}
|
||||
|
||||
return oldToNewOrdinalMap;
|
||||
}
|
||||
|
||||
private boolean isCurrentVectorNull(DocIdSetIterator docIdSetIterator) throws IOException {
|
||||
if (docIdSetIterator instanceof FloatVectorValues) {
|
||||
return ((FloatVectorValues) docIdSetIterator).vectorValue() == null;
|
||||
}
|
||||
|
||||
if (docIdSetIterator instanceof ByteVectorValues) {
|
||||
return ((ByteVectorValues) docIdSetIterator).vectorValue() == null;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private boolean allMatch(Bits bits) {
|
||||
if (bits == null) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (int i = 0; i < bits.length(); i++) {
|
||||
if (!bits.get(i)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param graph Write the graph in a compressed format
|
||||
* @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets.
|
||||
|
@ -684,7 +497,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
int countOnLevel0 = graph.size();
|
||||
int[][] offsets = new int[graph.numLevels()][];
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level));
|
||||
int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
offsets[level] = new int[sortedNodes.length];
|
||||
int nodeOffsetId = 0;
|
||||
for (int node : sortedNodes) {
|
||||
|
@ -712,15 +525,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
return offsets;
|
||||
}
|
||||
|
||||
public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
|
||||
int[] sortedNodes = new int[nodesOnLevel.size()];
|
||||
for (int n = 0; nodesOnLevel.hasNext(); n++) {
|
||||
sortedNodes[n] = nodesOnLevel.nextInt();
|
||||
}
|
||||
Arrays.sort(sortedNodes);
|
||||
return sortedNodes;
|
||||
}
|
||||
|
||||
private void writeMeta(
|
||||
FieldInfo field,
|
||||
int maxDoc,
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.lucene.util.hnsw;
|
|||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Iterator;
|
||||
import java.util.NoSuchElementException;
|
||||
|
@ -152,6 +153,15 @@ public abstract class HnswGraph {
|
|||
* @return The number of integers written to `dest`
|
||||
*/
|
||||
public abstract int consume(int[] dest);
|
||||
|
||||
public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
|
||||
int[] sortedNodes = new int[nodesOnLevel.size()];
|
||||
for (int n = 0; nodesOnLevel.hasNext(); n++) {
|
||||
sortedNodes[n] = nodesOnLevel.nextInt();
|
||||
}
|
||||
Arrays.sort(sortedNodes);
|
||||
return sortedNodes;
|
||||
}
|
||||
}
|
||||
|
||||
/** NodesIterator that accepts nodes as an integer array. */
|
||||
|
|
|
@ -18,14 +18,10 @@
|
|||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static java.lang.Math.log;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashSet;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.SplittableRandom;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
|
@ -37,7 +33,7 @@ import org.apache.lucene.util.InfoStream;
|
|||
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
|
||||
* hyper-parameters.
|
||||
*/
|
||||
public final class HnswGraphBuilder {
|
||||
public class HnswGraphBuilder {
|
||||
|
||||
/** Default number of maximum connections per node */
|
||||
public static final int DEFAULT_MAX_CONN = 16;
|
||||
|
@ -67,12 +63,10 @@ public final class HnswGraphBuilder {
|
|||
private final GraphBuilderKnnCollector
|
||||
beamCandidates; // for levels of graph where we add the node
|
||||
|
||||
final OnHeapHnswGraph hnsw;
|
||||
protected final OnHeapHnswGraph hnsw;
|
||||
|
||||
private InfoStream infoStream = InfoStream.getDefault();
|
||||
|
||||
private final Set<Integer> initializedNodes;
|
||||
|
||||
public static HnswGraphBuilder create(
|
||||
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
|
||||
throws IOException {
|
||||
|
@ -80,23 +74,9 @@ public final class HnswGraphBuilder {
|
|||
}
|
||||
|
||||
public static HnswGraphBuilder create(
|
||||
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
|
||||
return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
|
||||
}
|
||||
|
||||
public static HnswGraphBuilder create(
|
||||
RandomVectorScorerSupplier scorerSupplier,
|
||||
int M,
|
||||
int beamWidth,
|
||||
long seed,
|
||||
HnswGraph initializerGraph,
|
||||
Map<Integer, Integer> oldToNewOrdinalMap,
|
||||
int graphSize)
|
||||
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize)
|
||||
throws IOException {
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
|
||||
hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
|
||||
return hnswGraphBuilder;
|
||||
return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -111,8 +91,31 @@ public final class HnswGraphBuilder {
|
|||
* to ensure repeatable construction.
|
||||
* @param graphSize size of graph, if unknown, pass in -1
|
||||
*/
|
||||
private HnswGraphBuilder(
|
||||
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
|
||||
protected HnswGraphBuilder(
|
||||
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize)
|
||||
throws IOException {
|
||||
this(scorerSupplier, M, beamWidth, seed, new OnHeapHnswGraph(M, graphSize));
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads all the vectors from vector values, builds a graph connecting them by their dense
|
||||
* ordinals, using the given hyperparameter settings, and returns the resulting graph.
|
||||
*
|
||||
* @param scorerSupplier a supplier to create vector scorer from ordinals.
|
||||
* @param M – graph fanout parameter used to calculate the maximum number of connections a node
|
||||
* can have – M on upper layers, and M * 2 on the lowest level.
|
||||
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
|
||||
* @param seed the seed for a random number generator used during graph construction. Provide this
|
||||
* to ensure repeatable construction.
|
||||
* @param hnsw the graph to build, can be previously initialized
|
||||
*/
|
||||
protected HnswGraphBuilder(
|
||||
RandomVectorScorerSupplier scorerSupplier,
|
||||
int M,
|
||||
int beamWidth,
|
||||
long seed,
|
||||
OnHeapHnswGraph hnsw)
|
||||
throws IOException {
|
||||
if (M <= 0) {
|
||||
throw new IllegalArgumentException("maxConn must be positive");
|
||||
}
|
||||
|
@ -125,7 +128,7 @@ public final class HnswGraphBuilder {
|
|||
// normalization factor for level generation; currently not configurable
|
||||
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
|
||||
this.random = new SplittableRandom(seed);
|
||||
this.hnsw = new OnHeapHnswGraph(M, graphSize);
|
||||
this.hnsw = hnsw;
|
||||
this.graphSearcher =
|
||||
new HnswGraphSearcher(
|
||||
new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
|
||||
|
@ -133,7 +136,6 @@ public final class HnswGraphBuilder {
|
|||
scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
|
||||
entryCandidates = new GraphBuilderKnnCollector(1);
|
||||
beamCandidates = new GraphBuilderKnnCollector(beamWidth);
|
||||
this.initializedNodes = new HashSet<>();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -149,45 +151,6 @@ public final class HnswGraphBuilder {
|
|||
return hnsw;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the graph of this builder. Transfers the nodes and their neighbors from the
|
||||
* initializer graph into the graph being produced by this builder, mapping ordinals from the
|
||||
* initializer graph to their new ordinals in this builder's graph. The builder's graph must be
|
||||
* empty before calling this method.
|
||||
*
|
||||
* @param initializerGraph graph used for initialization
|
||||
* @param oldToNewOrdinalMap map for converting from ordinals in the initializerGraph to this
|
||||
* builder's graph
|
||||
*/
|
||||
private void initializeFromGraph(
|
||||
HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
|
||||
assert hnsw.size() == 0;
|
||||
for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
|
||||
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
|
||||
|
||||
while (it.hasNext()) {
|
||||
int oldOrd = it.nextInt();
|
||||
int newOrd = oldToNewOrdinalMap.get(oldOrd);
|
||||
|
||||
hnsw.addNode(level, newOrd);
|
||||
|
||||
if (level == 0) {
|
||||
initializedNodes.add(newOrd);
|
||||
}
|
||||
|
||||
NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd);
|
||||
initializerGraph.seek(level, oldOrd);
|
||||
for (int oldNeighbor = initializerGraph.nextNeighbor();
|
||||
oldNeighbor != NO_MORE_DOCS;
|
||||
oldNeighbor = initializerGraph.nextNeighbor()) {
|
||||
int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor);
|
||||
// we will compute these scores later when we need to pop out the non-diverse nodes
|
||||
newNeighbors.addOutOfOrder(newNeighbor, Float.NaN);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Set info-stream to output debugging information * */
|
||||
public void setInfoStream(InfoStream infoStream) {
|
||||
this.infoStream = infoStream;
|
||||
|
@ -200,9 +163,6 @@ public final class HnswGraphBuilder {
|
|||
private void addVectors(int maxOrd) throws IOException {
|
||||
long start = System.nanoTime(), t = start;
|
||||
for (int node = 0; node < maxOrd; node++) {
|
||||
if (initializedNodes.contains(node)) {
|
||||
continue;
|
||||
}
|
||||
addGraphNode(node);
|
||||
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||
t = printGraphBuildStatus(node, start, t);
|
||||
|
|
|
@ -0,0 +1,199 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.HnswGraphProvider;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.CollectionUtil;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
|
||||
/**
|
||||
* This selects the biggest Hnsw graph from the provided merge state and initializes a new
|
||||
* HnswGraphBuilder with that graph as a starting point.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class IncrementalHnswGraphMerger {
|
||||
|
||||
private KnnVectorsReader initReader;
|
||||
private MergeState.DocMap initDocMap;
|
||||
private int initGraphSize;
|
||||
private final FieldInfo fieldInfo;
|
||||
private final RandomVectorScorerSupplier scorerSupplier;
|
||||
private final int M;
|
||||
private final int beamWidth;
|
||||
|
||||
/**
|
||||
* @param fieldInfo FieldInfo for the field being merged
|
||||
*/
|
||||
public IncrementalHnswGraphMerger(
|
||||
FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth) {
|
||||
this.fieldInfo = fieldInfo;
|
||||
this.scorerSupplier = scorerSupplier;
|
||||
this.M = M;
|
||||
this.beamWidth = beamWidth;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a reader to the graph merger if it meets the following criteria: 1. Does not contain any
|
||||
* deleted docs 2. Is a HnswGraphProvider/PerFieldKnnVectorReader 3. Has the most docs of any
|
||||
* previous reader that met the above criteria
|
||||
*
|
||||
* @param reader KnnVectorsReader to add to the merger
|
||||
* @param docMap MergeState.DocMap for the reader
|
||||
* @param liveDocs Bits representing live docs, can be null
|
||||
* @return this
|
||||
* @throws IOException If an error occurs while reading from the merge state
|
||||
*/
|
||||
public IncrementalHnswGraphMerger addReader(
|
||||
KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs) throws IOException {
|
||||
KnnVectorsReader currKnnVectorsReader = reader;
|
||||
if (reader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
|
||||
currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
|
||||
}
|
||||
|
||||
if (!(currKnnVectorsReader instanceof HnswGraphProvider) || !noDeletes(liveDocs)) {
|
||||
return this;
|
||||
}
|
||||
|
||||
int candidateVectorCount = 0;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> {
|
||||
ByteVectorValues byteVectorValues =
|
||||
currKnnVectorsReader.getByteVectorValues(fieldInfo.name);
|
||||
if (byteVectorValues == null) {
|
||||
return this;
|
||||
}
|
||||
candidateVectorCount = byteVectorValues.size();
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
FloatVectorValues vectorValues = currKnnVectorsReader.getFloatVectorValues(fieldInfo.name);
|
||||
if (vectorValues == null) {
|
||||
return this;
|
||||
}
|
||||
candidateVectorCount = vectorValues.size();
|
||||
}
|
||||
}
|
||||
if (candidateVectorCount > initGraphSize) {
|
||||
initReader = currKnnVectorsReader;
|
||||
initDocMap = docMap;
|
||||
initGraphSize = candidateVectorCount;
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
  /**
   * Builds a new HnswGraphBuilder using the biggest graph from the merge state as a starting point.
   * If no valid readers were added to the merge state, a new graph is created.
   *
   * @param mergedVectorIterator iterator over the vectors in the merged segment
   * @return HnswGraphBuilder
   * @throws IOException If an error occurs while reading from the merge state
   */
  public HnswGraphBuilder createBuilder(DocIdSetIterator mergedVectorIterator) throws IOException {
    // No suitable segment graph was registered via addReader: build a graph from scratch.
    if (initReader == null) {
      return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
    }

    HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
    final int numVectors = Math.toIntExact(mergedVectorIterator.cost());

    // Tracks which new ordinals are seeded from the initializer graph; getNewOrdMapping fills it
    // and the builder later skips re-adding those nodes.
    BitSet initializedNodes = new FixedBitSet(numVectors + 1);
    int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
    return InitializedHnswGraphBuilder.fromGraph(
        scorerSupplier,
        M,
        beamWidth,
        HnswGraphBuilder.randSeed,
        initializerGraph,
        oldToNewOrdinalMap,
        initializedNodes,
        numVectors);
  }
|
||||
|
||||
  /**
   * Creates a new mapping from old ordinals to new ordinals and returns the total number of vectors
   * in the newly merged segment.
   *
   * @param mergedVectorIterator iterator over the vectors in the merged segment
   * @param initializedNodes track what nodes have been initialized
   * @return the mapping from old ordinals to new ordinals
   * @throws IOException If an error occurs while reading from the merge state
   */
  private int[] getNewOrdMapping(DocIdSetIterator mergedVectorIterator, BitSet initializedNodes)
      throws IOException {
    DocIdSetIterator initializerIterator = null;

    // VectorEncoding has exactly these two constants, so initializerIterator is always assigned.
    switch (fieldInfo.getVectorEncoding()) {
      case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name);
      case FLOAT32 -> initializerIterator = initReader.getFloatVectorValues(fieldInfo.name);
    }

    // First pass: for each doc of the initializer reader, record its doc id in the merged
    // segment (via initDocMap) keyed back to its old ordinal (the iteration index).
    Map<Integer, Integer> newIdToOldOrdinal = CollectionUtil.newHashMap(initGraphSize);
    int oldOrd = 0;
    int maxNewDocID = -1;
    for (int oldId = initializerIterator.nextDoc();
        oldId != NO_MORE_DOCS;
        oldId = initializerIterator.nextDoc()) {
      int newId = initDocMap.get(oldId);
      maxNewDocID = Math.max(newId, maxNewDocID);
      newIdToOldOrdinal.put(newId, oldOrd);
      oldOrd++;
    }

    // The initializer reader contributed no vectors; there is nothing to map.
    if (maxNewDocID == -1) {
      return new int[0];
    }
    // Second pass: walk the merged iterator in doc-id order assigning new ordinals. When a merged
    // doc came from the initializer, record old-ordinal -> new-ordinal and mark the new ordinal as
    // initialized. Iteration stops once maxNewDocID is passed (NO_MORE_DOCS also exceeds it).
    final int[] oldToNewOrdinalMap = new int[initGraphSize];
    int newOrd = 0;
    for (int newDocId = mergedVectorIterator.nextDoc();
        newDocId <= maxNewDocID;
        newDocId = mergedVectorIterator.nextDoc()) {
      if (newIdToOldOrdinal.containsKey(newDocId)) {
        initializedNodes.set(newOrd);
        oldToNewOrdinalMap[newIdToOldOrdinal.get(newDocId)] = newOrd;
      }
      newOrd++;
    }
    return oldToNewOrdinalMap;
  }
|
||||
|
||||
private static boolean noDeletes(Bits liveDocs) {
|
||||
if (liveDocs == null) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (int i = 0; i < liveDocs.length(); i++) {
|
||||
if (!liveDocs.get(i)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
|
||||
/**
 * This creates a graph builder that is initialized with the provided HnswGraph. This is useful for
 * merging HnswGraphs from multiple segments.
 *
 * @lucene.experimental
 */
public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {

  /**
   * Create a new HnswGraphBuilder that is initialized with the provided HnswGraph.
   *
   * @param scorerSupplier the scorer to use for vectors
   * @param M the number of connections to keep per node
   * @param beamWidth the number of nodes to explore in the search
   * @param seed the seed for the random number generator
   * @param initializerGraph the graph to initialize the new graph builder
   * @param newOrdMap a mapping from the old node ordinal to the new node ordinal
   * @param initializedNodes a bitset of nodes that are already initialized in the initializerGraph
   * @param totalNumberOfVectors the total number of vectors in the new graph, this should include
   *     all vectors expected to be added to the graph in the future
   * @return a new HnswGraphBuilder that is initialized with the provided HnswGraph
   * @throws IOException when reading the graph fails
   */
  public static InitializedHnswGraphBuilder fromGraph(
      RandomVectorScorerSupplier scorerSupplier,
      int M,
      int beamWidth,
      long seed,
      HnswGraph initializerGraph,
      int[] newOrdMap,
      BitSet initializedNodes,
      int totalNumberOfVectors)
      throws IOException {
    OnHeapHnswGraph hnsw = new OnHeapHnswGraph(M, totalNumberOfVectors);
    // Copy the initializer graph top level first, translating every old ordinal to its ordinal in
    // the merged segment via newOrdMap.
    for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
      HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
      while (it.hasNext()) {
        int oldOrd = it.nextInt();
        int newOrd = newOrdMap[oldOrd];
        hnsw.addNode(level, newOrd);
        NeighborArray newNeighbors = hnsw.getNeighbors(level, newOrd);
        initializerGraph.seek(level, oldOrd);
        // Copy this node's neighbor list, remapped to new ordinals.
        for (int oldNeighbor = initializerGraph.nextNeighbor();
            oldNeighbor != NO_MORE_DOCS;
            oldNeighbor = initializerGraph.nextNeighbor()) {
          int newNeighbor = newOrdMap[oldNeighbor];
          // we will compute these scores later when we need to pop out the non-diverse nodes
          newNeighbors.addOutOfOrder(newNeighbor, Float.NaN);
        }
      }
    }
    return new InitializedHnswGraphBuilder(
        scorerSupplier, M, beamWidth, seed, hnsw, initializedNodes);
  }

  // Nodes already present in the pre-populated graph; addGraphNode skips these.
  private final BitSet initializedNodes;

  /**
   * Create a builder over an already-populated {@link OnHeapHnswGraph}.
   *
   * @param scorerSupplier the scorer to use for vectors
   * @param M the number of connections to keep per node
   * @param beamWidth the number of nodes to explore in the search
   * @param seed the seed for the random number generator
   * @param initializedGraph a graph that already contains the copied nodes
   * @param initializedNodes a bitset of node ordinals already present in {@code initializedGraph}
   * @throws IOException when creating the underlying builder fails
   */
  public InitializedHnswGraphBuilder(
      RandomVectorScorerSupplier scorerSupplier,
      int M,
      int beamWidth,
      long seed,
      OnHeapHnswGraph initializedGraph,
      BitSet initializedNodes)
      throws IOException {
    super(scorerSupplier, M, beamWidth, seed, initializedGraph);
    this.initializedNodes = initializedNodes;
  }

  @Override
  public void addGraphNode(int node) throws IOException {
    // Nodes copied from the initializer graph must not be added (and connected) twice.
    if (initializedNodes.get(node)) {
      return;
    }
    super.addGraphNode(node);
  }
}
|
|
@ -26,11 +26,9 @@ import java.io.IOException;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
@ -49,6 +47,7 @@ import org.apache.lucene.document.Document;
|
|||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.NumericDocValuesField;
|
||||
import org.apache.lucene.document.StoredField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CodecReader;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
|
@ -59,8 +58,10 @@ import org.apache.lucene.index.IndexWriterConfig;
|
|||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.StoredFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.Query;
|
||||
|
@ -132,6 +133,73 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
};
|
||||
}
|
||||
|
||||
  // Tests writing segments of various sizes and merging to ensure there are no errors
  // in the HNSW graph merging logic.
  public void testRandomReadWriteAndMerge() throws IOException {
    int dim = random().nextInt(100) + 1;
    int[] segmentSizes =
        new int[] {random().nextInt(20) + 1, random().nextInt(10) + 30, random().nextInt(10) + 20};
    // Randomly delete vector documents
    boolean[] addDeletes =
        new boolean[] {random().nextBoolean(), random().nextBoolean(), random().nextBoolean()};
    // randomly index other documents besides vector docs
    boolean[] isSparse =
        new boolean[] {random().nextBoolean(), random().nextBoolean(), random().nextBoolean()};
    int numVectors = segmentSizes[0] + segmentSizes[1] + segmentSizes[2];
    int M = random().nextInt(4) + 2;
    int beamWidth = random().nextInt(10) + 5;
    long seed = random().nextLong();
    AbstractMockVectorValues<T> vectors = vectorValues(numVectors, dim);
    HnswGraphBuilder.randSeed = seed;

    try (Directory dir = newDirectory()) {
      IndexWriterConfig iwc =
          new IndexWriterConfig()
              .setCodec(
                  new Lucene95Codec() {
                    @Override
                    public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                      return new Lucene95HnswVectorsFormat(M, beamWidth);
                    }
                  })
              // set a random merge policy
              .setMergePolicy(newMergePolicy(random()));
      try (IndexWriter iw = new IndexWriter(dir, iwc)) {
        for (int i = 0; i < segmentSizes.length; i++) {
          int size = segmentSizes[i];
          // NOTE(review): the shared `vectors` iterator is never reset, so `nextDoc() < size`
          // bounds by absolute doc id rather than a per-segment count — confirm this is the
          // intended per-segment sizing.
          while (vectors.nextDoc() < size) {
            if (isSparse[i] && random().nextBoolean()) {
              // Interleave a few vector-less documents so the vector field is sparse here.
              int d = random().nextInt(10) + 1;
              for (int j = 0; j < d; j++) {
                Document doc = new Document();
                iw.addDocument(doc);
              }
            }
            Document doc = new Document();
            doc.add(knnVectorField("field", vectors.vectorValue(), similarityFunction));
            doc.add(new StringField("id", Integer.toString(vectors.docID()), Field.Store.NO));
            iw.addDocument(doc);
          }
          iw.commit();
          if (addDeletes[i] && size > 1) {
            // Delete a random subset of the just-committed docs by their id field.
            for (int d = 0; d < size; d += random().nextInt(5) + 1) {
              iw.deleteDocuments(new Term("id", Integer.toString(d)));
            }
            iw.commit();
          }
        }
        iw.commit();
        // Merging down to one segment exercises the HNSW graph-merge path.
        iw.forceMerge(1);
      }
      try (IndexReader reader = DirectoryReader.open(dir)) {
        for (LeafReaderContext ctx : reader.leaves()) {
          AbstractMockVectorValues<T> values = vectorValues(ctx.reader(), "field");
          assertEquals(dim, values.dimension());
        }
      }
    }
  }
|
||||
|
||||
// test writing out and reading in a graph gives the expected graph
|
||||
public void testReadWrite() throws IOException {
|
||||
int dim = random().nextInt(100) + 1;
|
||||
|
@ -469,20 +537,21 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
|
||||
AbstractMockVectorValues<T> finalVectorValues =
|
||||
vectorValues(totalSize, dim, initializerVectors, docIdOffset);
|
||||
|
||||
Map<Integer, Integer> initializerOrdMap =
|
||||
int[] initializerOrdMap =
|
||||
createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
|
||||
|
||||
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
|
||||
HnswGraphBuilder finalBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
InitializedHnswGraphBuilder.fromGraph(
|
||||
finalscorerSupplier,
|
||||
10,
|
||||
30,
|
||||
seed,
|
||||
initializerGraph,
|
||||
initializerOrdMap,
|
||||
finalVectorValues.size());
|
||||
BitSet.of(
|
||||
DocIdSetIterator.range(docIdOffset, initializerSize + docIdOffset), totalSize + 1),
|
||||
totalSize);
|
||||
|
||||
// When offset is 0, the graphs should be identical before vectors are added
|
||||
assertGraphEqual(initializerGraph, finalBuilder.getGraph());
|
||||
|
@ -506,19 +575,21 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
|
||||
AbstractMockVectorValues<T> finalVectorValues =
|
||||
vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset);
|
||||
Map<Integer, Integer> initializerOrdMap =
|
||||
createOffsetOrdinalMap(initializerSize, finalVectorValues.copy(), docIdOffset);
|
||||
int[] initializerOrdMap =
|
||||
createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
|
||||
|
||||
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
|
||||
HnswGraphBuilder finalBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
InitializedHnswGraphBuilder.fromGraph(
|
||||
finalscorerSupplier,
|
||||
10,
|
||||
30,
|
||||
seed,
|
||||
initializerGraph,
|
||||
initializerOrdMap,
|
||||
finalVectorValues.size());
|
||||
BitSet.of(
|
||||
DocIdSetIterator.range(docIdOffset, initializerSize + docIdOffset), totalSize + 1),
|
||||
totalSize);
|
||||
|
||||
assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap);
|
||||
|
||||
|
@ -528,19 +599,19 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
|
||||
}
|
||||
|
||||
  // Asserts that, on every level, each node of the initializer graph (mapped through
  // newOrdinals) also appears on the corresponding level of g.
  private void assertGraphContainsGraph(HnswGraph g, HnswGraph initializer, int[] newOrdinals)
      throws IOException {
    for (int i = 0; i < initializer.numLevels(); i++) {
      int[] finalGraphNodesOnLevel = nodesIteratorToArray(g.getNodesOnLevel(i));
      int[] initializerGraphNodesOnLevel =
          mapArrayAndSort(nodesIteratorToArray(initializer.getNodesOnLevel(i)), newOrdinals);
      // Full containment: the overlap must cover every initializer node on this level.
      int overlap = computeOverlap(finalGraphNodesOnLevel, initializerGraphNodesOnLevel);
      assertEquals(initializerGraphNodesOnLevel.length, overlap);
    }
  }
|
||||
|
||||
private void assertGraphInitializedFromGraph(
|
||||
HnswGraph g, HnswGraph initializer, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
|
||||
HnswGraph g, HnswGraph initializer, int[] newOrdinals) throws IOException {
|
||||
assertEquals(
|
||||
"the number of levels in the graphs are different!",
|
||||
initializer.numLevels(),
|
||||
|
@ -550,44 +621,23 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
// new ordinal mapping
|
||||
assertEquals("the number of nodes in the graphs are different!", initializer.size(), g.size());
|
||||
|
||||
// assert that all the node from initializer graph can be found in the new graph and
|
||||
// the neighbors from the old graph are successfully transferred to the new graph
|
||||
// assert that the neighbors from the old graph are successfully transferred to the new graph
|
||||
for (int level = 0; level < g.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = initializer.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
g.seek(level, oldToNewOrdMap.get(node));
|
||||
g.seek(level, newOrdinals[node]);
|
||||
initializer.seek(level, node);
|
||||
assertEquals(
|
||||
"arcs differ for node " + node,
|
||||
getNeighborNodes(g),
|
||||
getNeighborNodes(initializer).stream()
|
||||
.map(oldToNewOrdMap::get)
|
||||
.map(n -> newOrdinals[n])
|
||||
.collect(Collectors.toSet()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Map<Integer, Integer> createOffsetOrdinalMap(
|
||||
int docIdSize, AbstractMockVectorValues<T> totalVectorValues, int docIdOffset) {
|
||||
// Compute the offset for the ordinal map to be the number of non-null vectors in the total
|
||||
// vector values
|
||||
// before the docIdOffset
|
||||
int ordinalOffset = 0;
|
||||
while (totalVectorValues.nextDoc() < docIdOffset) {
|
||||
ordinalOffset++;
|
||||
}
|
||||
|
||||
Map<Integer, Integer> offsetOrdinalMap = new HashMap<>();
|
||||
for (int curr = 0;
|
||||
totalVectorValues.docID() < docIdOffset + docIdSize;
|
||||
totalVectorValues.nextDoc()) {
|
||||
offsetOrdinalMap.put(curr, ordinalOffset + curr++);
|
||||
}
|
||||
|
||||
return offsetOrdinalMap;
|
||||
}
|
||||
|
||||
private int[] nodesIteratorToArray(NodesIterator nodesIterator) {
|
||||
int[] arr = new int[nodesIterator.size()];
|
||||
int i = 0;
|
||||
|
@ -597,15 +647,35 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
return arr;
|
||||
}
|
||||
|
||||
private int[] mapArrayAndSort(int[] arr, Map<Integer, Integer> map) {
|
||||
private int[] mapArrayAndSort(int[] arr, int[] offset) {
|
||||
int[] mappedA = new int[arr.length];
|
||||
for (int i = 0; i < arr.length; i++) {
|
||||
mappedA[i] = map.get(arr[i]);
|
||||
mappedA[i] = offset[arr[i]];
|
||||
}
|
||||
Arrays.sort(mappedA);
|
||||
return mappedA;
|
||||
}
|
||||
|
||||
  // Builds an old-ordinal -> new-ordinal map where each new ordinal is shifted by the number of
  // vectors that precede docIdOffset in totalVectorValues.
  private int[] createOffsetOrdinalMap(
      int docIdSize, AbstractMockVectorValues<T> totalVectorValues, int docIdOffset) {
    // Compute the offset for the ordinal map to be the number of non-null vectors in the total
    // vector values
    // before the docIdOffset
    int ordinalOffset = 0;
    while (totalVectorValues.nextDoc() < docIdOffset) {
      ordinalOffset++;
    }
    int[] offsetOrdinalMap = new int[docIdSize];

    // Each of the docIdSize vectors starting at docIdOffset maps to ordinalOffset + its index.
    for (int curr = 0;
        totalVectorValues.docID() < docIdOffset + docIdSize;
        totalVectorValues.nextDoc()) {
      offsetOrdinalMap[curr] = ordinalOffset + curr++;
    }

    return offsetOrdinalMap;
  }
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testVisitedLimit() throws IOException {
|
||||
int nDoc = 500;
|
||||
|
@ -1083,7 +1153,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
|
||||
static float[][] createRandomFloatVectors(int size, int dimension, Random random) {
|
||||
float[][] vectors = new float[size][];
|
||||
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
||||
for (int offset = 0; offset < size; offset++) {
|
||||
vectors[offset] = randomVector(random, dimension);
|
||||
}
|
||||
return vectors;
|
||||
|
@ -1091,7 +1161,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
|
||||
static byte[][] createRandomByteVectors(int size, int dimension, Random random) {
|
||||
byte[][] vectors = new byte[size][];
|
||||
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
||||
for (int offset = 0; offset < size; offset++) {
|
||||
vectors[offset] = randomVector8(random, dimension);
|
||||
}
|
||||
return vectors;
|
||||
|
|
|
@ -24,7 +24,12 @@ class MockByteVectorValues extends AbstractMockVectorValues<byte[]> {
|
|||
private final byte[] scratch;
|
||||
|
||||
static MockByteVectorValues fromValues(byte[][] values) {
|
||||
int dimension = values[0].length;
|
||||
byte[] firstNonNull = null;
|
||||
int j = 0;
|
||||
while (firstNonNull == null && j < values.length) {
|
||||
firstNonNull = values[j++];
|
||||
}
|
||||
int dimension = firstNonNull.length;
|
||||
int maxDoc = values.length;
|
||||
byte[][] denseValues = new byte[maxDoc][];
|
||||
int count = 0;
|
||||
|
|
|
@ -24,7 +24,12 @@ class MockVectorValues extends AbstractMockVectorValues<float[]> {
|
|||
private final float[] scratch;
|
||||
|
||||
static MockVectorValues fromValues(float[][] values) {
|
||||
int dimension = values[0].length;
|
||||
float[] firstNonNull = null;
|
||||
int j = 0;
|
||||
while (firstNonNull == null && j < values.length) {
|
||||
firstNonNull = values[j++];
|
||||
}
|
||||
int dimension = firstNonNull.length;
|
||||
int maxDoc = values.length;
|
||||
float[][] denseValues = new float[maxDoc][];
|
||||
int count = 0;
|
||||
|
|
Loading…
Reference in New Issue