Extract the hnsw graph merging from being part of the vector writer (#12657)

While working on the quantization codec & thinking about how merging will evolve, it became clearer that having merging attached directly to the vector writer is weird.

I extracted it out to its own class and removed the "initializedNodes" logic from the base class builder.

Also, there was one other refactoring around grabbing sorted nodes from the nodes iterator; I just moved that static method so it's not attached to the writer (as all the bwc writers need it, and all future HNSW writers will as well).
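The net effect on the write path is a two-step "collect candidate readers, then build" flow. Condensed from the Lucene95HnswVectorsWriter hunk below (same names and calls as in the diff):

IncrementalHnswGraphMerger merger =
    new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
  // each reader is only kept as an initialization candidate if it has no deletes
  // and can provide an HNSW graph
  merger.addReader(
      mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
}
HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
OnHeapHnswGraph graph = hnswGraphBuilder.build(docsWithField.cardinality());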
Benjamin Trent 2023-10-17 13:45:25 -04:00 committed by GitHub
parent 218eddec70
commit ea272d0eda
12 changed files with 511 additions and 357 deletions

View File

@@ -226,7 +226,8 @@ Build
 Other
 ---------------------
-(No changes)
+* GITHUB#12657: Internal refactor of HNSW graph merging (Ben Trent).
 ======================== Lucene 9.8.0 =======================

View File

@@ -25,7 +25,6 @@ import java.nio.ByteOrder;
 import java.util.Arrays;
 import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
 import org.apache.lucene.codecs.CodecUtil;
-import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
 import org.apache.lucene.index.ByteVectorValues;
 import org.apache.lucene.index.DocsWithFieldSet;
 import org.apache.lucene.index.FieldInfo;
@@ -37,6 +36,7 @@ import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.hnsw.HnswGraph;
 import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

 /**
@@ -227,7 +227,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
     } else {
       meta.writeInt(graph.numLevels());
       for (int level = 0; level < graph.numLevels(); level++) {
-        int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
+        int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
         meta.writeInt(sortedNodes.length); // number of nodes on a level
         if (level > 0) {
           for (int node : sortedNodes) {
@@ -256,7 +256,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
     // write vectors' neighbours on each level into the vectorIndex file
     int countOnLevel0 = graph.size();
     for (int level = 0; level < graph.numLevels(); level++) {
-      int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
+      int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
       for (int node : sortedNodes) {
         Lucene91NeighborArray neighbors = graph.getNeighbors(level, node);
         int size = neighbors.size();

View File

@@ -27,7 +27,6 @@ import java.util.Arrays;
 import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.lucene90.IndexedDISI;
-import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
 import org.apache.lucene.index.ByteVectorValues;
 import org.apache.lucene.index.DocsWithFieldSet;
 import org.apache.lucene.index.FieldInfo;
@@ -39,6 +38,7 @@ import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.hnsw.HnswGraph;
 import org.apache.lucene.util.hnsw.HnswGraphBuilder;
 import org.apache.lucene.util.hnsw.NeighborArray;
 import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
@@ -261,7 +261,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
     } else {
       meta.writeInt(graph.numLevels());
       for (int level = 0; level < graph.numLevels(); level++) {
-        int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
+        int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
         meta.writeInt(sortedNodes.length); // number of nodes on a level
         if (level > 0) {
           for (int node : sortedNodes) {
@@ -289,7 +289,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
     int countOnLevel0 = graph.size();
     for (int level = 0; level < graph.numLevels(); level++) {
       int maxConnOnLevel = level == 0 ? (M * 2) : M;
-      int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
+      int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
       for (int node : sortedNodes) {
         NeighborArray neighbors = graph.getNeighbors(level, node);
         int size = neighbors.size();

View File

@@ -30,7 +30,6 @@ import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.KnnFieldVectorsWriter;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.codecs.lucene90.IndexedDISI;
-import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
 import org.apache.lucene.index.ByteVectorValues;
 import org.apache.lucene.index.DocsWithFieldSet;
 import org.apache.lucene.index.FieldInfo;
@@ -477,7 +476,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
     int countOnLevel0 = graph.size();
     for (int level = 0; level < graph.numLevels(); level++) {
       int maxConnOnLevel = level == 0 ? (M * 2) : M;
-      int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
+      int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
       for (int node : sortedNodes) {
         NeighborArray neighbors = graph.getNeighbors(level, node);
         int size = neighbors.size();
@@ -565,7 +564,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
     } else {
       meta.writeInt(graph.numLevels());
       for (int level = 0; level < graph.numLevels(); level++) {
-        int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
+        int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
         meta.writeInt(sortedNodes.length); // number of nodes on a level
         if (level > 0) {
           for (int node : sortedNodes) {

View File

@@ -25,16 +25,10 @@ import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import org.apache.lucene.codecs.CodecUtil;
-import org.apache.lucene.codecs.HnswGraphProvider;
 import org.apache.lucene.codecs.KnnFieldVectorsWriter;
-import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
-import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.index.*;
 import org.apache.lucene.index.Sorter;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -423,51 +417,49 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
     OnHeapHnswGraph graph = null;
     int[][] vectorIndexNodeOffsets = null;
     if (docsWithField.cardinality() != 0) {
-      // build graph
-      int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo);
-      graph =
-          switch (fieldInfo.getVectorEncoding()) {
-            case BYTE -> {
-              OffHeapByteVectorValues.DenseOffHeapVectorValues vectorValues =
-                  new OffHeapByteVectorValues.DenseOffHeapVectorValues(
-                      fieldInfo.getVectorDimension(),
-                      docsWithField.cardinality(),
-                      vectorDataInput,
-                      byteSize);
-              RandomVectorScorerSupplier scorerSupplier =
-                  RandomVectorScorerSupplier.createBytes(
-                      vectorValues, fieldInfo.getVectorSimilarityFunction());
-              HnswGraphBuilder hnswGraphBuilder =
-                  createHnswGraphBuilder(
-                      mergeState,
-                      fieldInfo,
-                      scorerSupplier,
-                      initializerIndex,
-                      vectorValues.size());
-              hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
-              yield hnswGraphBuilder.build(vectorValues.size());
-            }
-            case FLOAT32 -> {
-              OffHeapFloatVectorValues.DenseOffHeapVectorValues vectorValues =
-                  new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
-                      fieldInfo.getVectorDimension(),
-                      docsWithField.cardinality(),
-                      vectorDataInput,
-                      byteSize);
-              RandomVectorScorerSupplier scorerSupplier =
-                  RandomVectorScorerSupplier.createFloats(
-                      vectorValues, fieldInfo.getVectorSimilarityFunction());
-              HnswGraphBuilder hnswGraphBuilder =
-                  createHnswGraphBuilder(
-                      mergeState,
-                      fieldInfo,
-                      scorerSupplier,
-                      initializerIndex,
-                      vectorValues.size());
-              hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
-              yield hnswGraphBuilder.build(vectorValues.size());
-            }
-          };
+      final RandomVectorScorerSupplier scorerSupplier;
+      switch (fieldInfo.getVectorEncoding()) {
+        case BYTE:
+          scorerSupplier =
+              RandomVectorScorerSupplier.createBytes(
+                  new OffHeapByteVectorValues.DenseOffHeapVectorValues(
+                      fieldInfo.getVectorDimension(),
+                      docsWithField.cardinality(),
+                      vectorDataInput,
+                      byteSize),
+                  fieldInfo.getVectorSimilarityFunction());
+          break;
+        case FLOAT32:
+          scorerSupplier =
+              RandomVectorScorerSupplier.createFloats(
+                  new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
+                      fieldInfo.getVectorDimension(),
+                      docsWithField.cardinality(),
+                      vectorDataInput,
+                      byteSize),
+                  fieldInfo.getVectorSimilarityFunction());
+          break;
+        default:
+          throw new IllegalArgumentException(
+              "Unsupported vector encoding: " + fieldInfo.getVectorEncoding());
+      }
+      // build graph
+      IncrementalHnswGraphMerger merger =
+          new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
+      for (int i = 0; i < mergeState.liveDocs.length; i++) {
+        merger.addReader(
+            mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
+      }
+      DocIdSetIterator mergedVectorIterator = null;
+      switch (fieldInfo.getVectorEncoding()) {
+        case BYTE -> mergedVectorIterator =
+            KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
+        case FLOAT32 -> mergedVectorIterator =
+            KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+      }
+      HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
+      hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
+      graph = hnswGraphBuilder.build(docsWithField.cardinality());
       vectorIndexNodeOffsets = writeGraph(graph);
     }
     long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@@ -494,185 +486,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
     }
   }

-  private HnswGraphBuilder createHnswGraphBuilder(
-      MergeState mergeState,
-      FieldInfo fieldInfo,
-      RandomVectorScorerSupplier scorerSupplier,
-      int initializerIndex,
-      int graphSize)
-      throws IOException {
-    if (initializerIndex == -1) {
-      return HnswGraphBuilder.create(
-          scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, graphSize);
-    }
-
-    HnswGraph initializerGraph =
-        getHnswGraphFromReader(fieldInfo.name, mergeState.knnVectorsReaders[initializerIndex]);
-    Map<Integer, Integer> ordinalMapper =
-        getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
-    return HnswGraphBuilder.create(
-        scorerSupplier,
-        M,
-        beamWidth,
-        HnswGraphBuilder.randSeed,
-        initializerGraph,
-        ordinalMapper,
-        graphSize);
-  }
-
-  private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)
-      throws IOException {
-    // Find the KnnVectorReader with the most docs that meets the following criteria:
-    //  1. Does not contain any deleted docs
-    //  2. Is a HnswGraphProvider/PerFieldKnnVectorReader
-    // If no readers exist that meet this criteria, return -1. If they do, return their index in
-    // merge state
-    int maxCandidateVectorCount = 0;
-    int initializerIndex = -1;
-
-    for (int i = 0; i < mergeState.liveDocs.length; i++) {
-      KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i];
-      if (mergeState.knnVectorsReaders[i]
-          instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
-        currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
-      }
-
-      if (!allMatch(mergeState.liveDocs[i])
-          || !(currKnnVectorsReader instanceof HnswGraphProvider)) {
-        continue;
-      }
-
-      int candidateVectorCount = 0;
-      switch (fieldInfo.getVectorEncoding()) {
-        case BYTE -> {
-          ByteVectorValues byteVectorValues =
-              currKnnVectorsReader.getByteVectorValues(fieldInfo.name);
-          if (byteVectorValues == null) {
-            continue;
-          }
-          candidateVectorCount = byteVectorValues.size();
-        }
-        case FLOAT32 -> {
-          FloatVectorValues vectorValues =
-              currKnnVectorsReader.getFloatVectorValues(fieldInfo.name);
-          if (vectorValues == null) {
-            continue;
-          }
-          candidateVectorCount = vectorValues.size();
-        }
-      }
-
-      if (candidateVectorCount > maxCandidateVectorCount) {
-        maxCandidateVectorCount = candidateVectorCount;
-        initializerIndex = i;
-      }
-    }
-    return initializerIndex;
-  }
-
-  private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader)
-      throws IOException {
-    if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader
-        && perFieldReader.getFieldReader(fieldName) instanceof HnswGraphProvider fieldReader) {
-      return fieldReader.getGraph(fieldName);
-    }
-
-    if (knnVectorsReader instanceof HnswGraphProvider provider) {
-      return provider.getGraph(fieldName);
-    }
-
-    // We should not reach here because knnVectorsReader's type is checked in
-    // selectGraphForInitialization
-    throw new IllegalArgumentException(
-        "Invalid KnnVectorsReader type for field: "
-            + fieldName
-            + ". Must be Lucene95HnswVectorsReader or newer");
-  }
-
-  private Map<Integer, Integer> getOldToNewOrdinalMap(
-      MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws IOException {
-
-    DocIdSetIterator initializerIterator = null;
-
-    switch (fieldInfo.getVectorEncoding()) {
-      case BYTE -> initializerIterator =
-          mergeState.knnVectorsReaders[initializerIndex].getByteVectorValues(fieldInfo.name);
-      case FLOAT32 -> initializerIterator =
-          mergeState.knnVectorsReaders[initializerIndex].getFloatVectorValues(fieldInfo.name);
-    }
-
-    MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex];
-
-    Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>();
-    int oldOrd = 0;
-    int maxNewDocID = -1;
-    for (int oldId = initializerIterator.nextDoc();
-        oldId != NO_MORE_DOCS;
-        oldId = initializerIterator.nextDoc()) {
-      if (isCurrentVectorNull(initializerIterator)) {
-        continue;
-      }
-      int newId = initializerDocMap.get(oldId);
-      maxNewDocID = Math.max(newId, maxNewDocID);
-      newIdToOldOrdinal.put(newId, oldOrd);
-      oldOrd++;
-    }
-
-    if (maxNewDocID == -1) {
-      return Collections.emptyMap();
-    }
-
-    Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>();
-
-    DocIdSetIterator vectorIterator = null;
-    switch (fieldInfo.getVectorEncoding()) {
-      case BYTE -> vectorIterator = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
-      case FLOAT32 -> vectorIterator =
-          MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
-    }
-
-    int newOrd = 0;
-    for (int newDocId = vectorIterator.nextDoc();
-        newDocId <= maxNewDocID;
-        newDocId = vectorIterator.nextDoc()) {
-      if (isCurrentVectorNull(vectorIterator)) {
-        continue;
-      }
-      if (newIdToOldOrdinal.containsKey(newDocId)) {
-        oldToNewOrdinalMap.put(newIdToOldOrdinal.get(newDocId), newOrd);
-      }
-      newOrd++;
-    }
-
-    return oldToNewOrdinalMap;
-  }
-
-  private boolean isCurrentVectorNull(DocIdSetIterator docIdSetIterator) throws IOException {
-    if (docIdSetIterator instanceof FloatVectorValues) {
-      return ((FloatVectorValues) docIdSetIterator).vectorValue() == null;
-    }
-
-    if (docIdSetIterator instanceof ByteVectorValues) {
-      return ((ByteVectorValues) docIdSetIterator).vectorValue() == null;
-    }
-
-    return true;
-  }
-
-  private boolean allMatch(Bits bits) {
-    if (bits == null) {
-      return true;
-    }
-
-    for (int i = 0; i < bits.length(); i++) {
-      if (!bits.get(i)) {
-        return false;
-      }
-    }
-    return true;
-  }
-
   /**
    * @param graph Write the graph in a compressed format
    * @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets.
@@ -684,7 +497,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
     int countOnLevel0 = graph.size();
     int[][] offsets = new int[graph.numLevels()][];
     for (int level = 0; level < graph.numLevels(); level++) {
-      int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level));
+      int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
       offsets[level] = new int[sortedNodes.length];
       int nodeOffsetId = 0;
       for (int node : sortedNodes) {
@@ -712,15 +525,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
     return offsets;
   }

-  public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
-    int[] sortedNodes = new int[nodesOnLevel.size()];
-    for (int n = 0; nodesOnLevel.hasNext(); n++) {
-      sortedNodes[n] = nodesOnLevel.nextInt();
-    }
-    Arrays.sort(sortedNodes);
-    return sortedNodes;
-  }
-
   private void writeMeta(
       FieldInfo field,
       int maxDoc,

View File

@@ -20,6 +20,7 @@ package org.apache.lucene.util.hnsw;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.NoSuchElementException;
@@ -152,6 +153,15 @@ public abstract class HnswGraph {
      * @return The number of integers written to `dest`
      */
     public abstract int consume(int[] dest);
+
+    public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
+      int[] sortedNodes = new int[nodesOnLevel.size()];
+      for (int n = 0; nodesOnLevel.hasNext(); n++) {
+        sortedNodes[n] = nodesOnLevel.nextInt();
+      }
+      Arrays.sort(sortedNodes);
+      return sortedNodes;
+    }
   }

   /** NodesIterator that accepts nodes as an integer array. */

View File

@@ -18,14 +18,10 @@
 package org.apache.lucene.util.hnsw;

 import static java.lang.Math.log;
-import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

 import java.io.IOException;
-import java.util.HashSet;
 import java.util.Locale;
-import java.util.Map;
 import java.util.Objects;
-import java.util.Set;
 import java.util.SplittableRandom;
 import java.util.concurrent.TimeUnit;
 import org.apache.lucene.search.KnnCollector;
@@ -37,7 +33,7 @@ import org.apache.lucene.util.InfoStream;
  * Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
  * hyper-parameters.
  */
-public final class HnswGraphBuilder {
+public class HnswGraphBuilder {

   /** Default number of maximum connections per node */
   public static final int DEFAULT_MAX_CONN = 16;
@@ -67,12 +63,10 @@ public class HnswGraphBuilder {
   private final GraphBuilderKnnCollector
       beamCandidates; // for levels of graph where we add the node

-  final OnHeapHnswGraph hnsw;
+  protected final OnHeapHnswGraph hnsw;

   private InfoStream infoStream = InfoStream.getDefault();

-  private final Set<Integer> initializedNodes;
-
   public static HnswGraphBuilder create(
       RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
       throws IOException {
@@ -80,23 +74,9 @@ public class HnswGraphBuilder {
   }

   public static HnswGraphBuilder create(
-      RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
-    return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
-  }
-
-  public static HnswGraphBuilder create(
-      RandomVectorScorerSupplier scorerSupplier,
-      int M,
-      int beamWidth,
-      long seed,
-      HnswGraph initializerGraph,
-      Map<Integer, Integer> oldToNewOrdinalMap,
-      int graphSize)
+      RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize)
       throws IOException {
-    HnswGraphBuilder hnswGraphBuilder =
-        new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
-    hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
-    return hnswGraphBuilder;
+    return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
   }

   /**
@@ -111,8 +91,31 @@ public class HnswGraphBuilder {
    *     to ensure repeatable construction.
    * @param graphSize size of graph, if unknown, pass in -1
    */
-  private HnswGraphBuilder(
-      RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
+  protected HnswGraphBuilder(
+      RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize)
+      throws IOException {
+    this(scorerSupplier, M, beamWidth, seed, new OnHeapHnswGraph(M, graphSize));
+  }
+
+  /**
+   * Reads all the vectors from vector values, builds a graph connecting them by their dense
+   * ordinals, using the given hyperparameter settings, and returns the resulting graph.
+   *
+   * @param scorerSupplier a supplier to create vector scorer from ordinals.
+   * @param M graph fanout parameter used to calculate the maximum number of connections a node
+   *     can have M on upper layers, and M * 2 on the lowest level.
+   * @param beamWidth the size of the beam search to use when finding nearest neighbors.
+   * @param seed the seed for a random number generator used during graph construction. Provide
+   *     this to ensure repeatable construction.
+   * @param hnsw the graph to build, can be previously initialized
+   */
+  protected HnswGraphBuilder(
+      RandomVectorScorerSupplier scorerSupplier,
+      int M,
+      int beamWidth,
+      long seed,
+      OnHeapHnswGraph hnsw)
+      throws IOException {
     if (M <= 0) {
       throw new IllegalArgumentException("maxConn must be positive");
     }
@@ -125,7 +128,7 @@ public class HnswGraphBuilder {
     // normalization factor for level generation; currently not configurable
     this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
     this.random = new SplittableRandom(seed);
-    this.hnsw = new OnHeapHnswGraph(M, graphSize);
+    this.hnsw = hnsw;
     this.graphSearcher =
         new HnswGraphSearcher(
             new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
@@ -133,7 +136,6 @@ public class HnswGraphBuilder {
     scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
     entryCandidates = new GraphBuilderKnnCollector(1);
     beamCandidates = new GraphBuilderKnnCollector(beamWidth);
-    this.initializedNodes = new HashSet<>();
   }

   /**
@@ -149,45 +151,6 @@ public class HnswGraphBuilder {
     return hnsw;
   }

-  /**
-   * Initializes the graph of this builder. Transfers the nodes and their neighbors from the
-   * initializer graph into the graph being produced by this builder, mapping ordinals from the
-   * initializer graph to their new ordinals in this builder's graph. The builder's graph must be
-   * empty before calling this method.
-   *
-   * @param initializerGraph graph used for initialization
-   * @param oldToNewOrdinalMap map for converting from ordinals in the initializerGraph to this
-   *     builder's graph
-   */
-  private void initializeFromGraph(
-      HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
-    assert hnsw.size() == 0;
-    for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
-      HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
-      while (it.hasNext()) {
-        int oldOrd = it.nextInt();
-        int newOrd = oldToNewOrdinalMap.get(oldOrd);
-        hnsw.addNode(level, newOrd);
-        if (level == 0) {
-          initializedNodes.add(newOrd);
-        }
-        NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd);
-        initializerGraph.seek(level, oldOrd);
-        for (int oldNeighbor = initializerGraph.nextNeighbor();
-            oldNeighbor != NO_MORE_DOCS;
-            oldNeighbor = initializerGraph.nextNeighbor()) {
-          int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor);
-          // we will compute these scores later when we need to pop out the non-diverse nodes
-          newNeighbors.addOutOfOrder(newNeighbor, Float.NaN);
-        }
-      }
-    }
-  }
-
   /** Set info-stream to output debugging information * */
   public void setInfoStream(InfoStream infoStream) {
     this.infoStream = infoStream;
@@ -200,9 +163,6 @@ public class HnswGraphBuilder {
   private void addVectors(int maxOrd) throws IOException {
     long start = System.nanoTime(), t = start;
     for (int node = 0; node < maxOrd; node++) {
-      if (initializedNodes.contains(node)) {
-        continue;
-      }
       addGraphNode(node);
       if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
         t = printGraphBuildStatus(node, start, t);

View File

@@ -0,0 +1,199 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Map;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.CollectionUtil;
import org.apache.lucene.util.FixedBitSet;
/**
* This selects the biggest Hnsw graph from the provided merge state and initializes a new
* HnswGraphBuilder with that graph as a starting point.
*
* @lucene.experimental
*/
public class IncrementalHnswGraphMerger {
private KnnVectorsReader initReader;
private MergeState.DocMap initDocMap;
private int initGraphSize;
private final FieldInfo fieldInfo;
private final RandomVectorScorerSupplier scorerSupplier;
private final int M;
private final int beamWidth;
/**
* @param fieldInfo FieldInfo for the field being merged
*/
public IncrementalHnswGraphMerger(
FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth) {
this.fieldInfo = fieldInfo;
this.scorerSupplier = scorerSupplier;
this.M = M;
this.beamWidth = beamWidth;
}
/**
* Adds a reader to the graph merger if it meets the following criteria: 1. Does not contain any
* deleted docs 2. Is a HnswGraphProvider/PerFieldKnnVectorReader 3. Has the most docs of any
* previous reader that met the above criteria
*
* @param reader KnnVectorsReader to add to the merger
* @param docMap MergeState.DocMap for the reader
* @param liveDocs Bits representing live docs, can be null
* @return this
* @throws IOException If an error occurs while reading from the merge state
*/
public IncrementalHnswGraphMerger addReader(
KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs) throws IOException {
KnnVectorsReader currKnnVectorsReader = reader;
if (reader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
}
if (!(currKnnVectorsReader instanceof HnswGraphProvider) || !noDeletes(liveDocs)) {
return this;
}
int candidateVectorCount = 0;
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> {
ByteVectorValues byteVectorValues =
currKnnVectorsReader.getByteVectorValues(fieldInfo.name);
if (byteVectorValues == null) {
return this;
}
candidateVectorCount = byteVectorValues.size();
}
case FLOAT32 -> {
FloatVectorValues vectorValues = currKnnVectorsReader.getFloatVectorValues(fieldInfo.name);
if (vectorValues == null) {
return this;
}
candidateVectorCount = vectorValues.size();
}
}
if (candidateVectorCount > initGraphSize) {
initReader = currKnnVectorsReader;
initDocMap = docMap;
initGraphSize = candidateVectorCount;
}
return this;
}
/**
* Builds a new HnswGraphBuilder using the biggest graph from the merge state as a starting point.
* If no valid readers were added to the merge state, a new graph is created.
*
* @param mergedVectorIterator iterator over the vectors in the merged segment
* @return HnswGraphBuilder
* @throws IOException If an error occurs while reading from the merge state
*/
public HnswGraphBuilder createBuilder(DocIdSetIterator mergedVectorIterator) throws IOException {
if (initReader == null) {
return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
}
HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
final int numVectors = Math.toIntExact(mergedVectorIterator.cost());
BitSet initializedNodes = new FixedBitSet(numVectors + 1);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
return InitializedHnswGraphBuilder.fromGraph(
scorerSupplier,
M,
beamWidth,
HnswGraphBuilder.randSeed,
initializerGraph,
oldToNewOrdinalMap,
initializedNodes,
numVectors);
}
/**
* Creates a new mapping from the initializer graph's old ordinals to their new ordinals in the
* newly merged segment.
*
* @param mergedVectorIterator iterator over the vectors in the merged segment
* @param initializedNodes track what nodes have been initialized
* @return the mapping from old ordinals to new ordinals
* @throws IOException If an error occurs while reading from the merge state
*/
private int[] getNewOrdMapping(DocIdSetIterator mergedVectorIterator, BitSet initializedNodes)
throws IOException {
DocIdSetIterator initializerIterator = null;
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name);
case FLOAT32 -> initializerIterator = initReader.getFloatVectorValues(fieldInfo.name);
}
Map<Integer, Integer> newIdToOldOrdinal = CollectionUtil.newHashMap(initGraphSize);
int oldOrd = 0;
int maxNewDocID = -1;
for (int oldId = initializerIterator.nextDoc();
oldId != NO_MORE_DOCS;
oldId = initializerIterator.nextDoc()) {
int newId = initDocMap.get(oldId);
maxNewDocID = Math.max(newId, maxNewDocID);
newIdToOldOrdinal.put(newId, oldOrd);
oldOrd++;
}
if (maxNewDocID == -1) {
return new int[0];
}
final int[] oldToNewOrdinalMap = new int[initGraphSize];
int newOrd = 0;
for (int newDocId = mergedVectorIterator.nextDoc();
newDocId <= maxNewDocID;
newDocId = mergedVectorIterator.nextDoc()) {
if (newIdToOldOrdinal.containsKey(newDocId)) {
initializedNodes.set(newOrd);
oldToNewOrdinalMap[newIdToOldOrdinal.get(newDocId)] = newOrd;
}
newOrd++;
}
return oldToNewOrdinalMap;
}
private static boolean noDeletes(Bits liveDocs) {
if (liveDocs == null) {
return true;
}
for (int i = 0; i < liveDocs.length(); i++) {
if (!liveDocs.get(i)) {
return false;
}
}
return true;
}
}
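getNewOrdMapping above makes two passes: one over the initializer segment's vectors to learn which merged doc IDs came from it, and one over the merged vectors to assign new ordinals. A self-contained toy sketch of those two passes, with hypothetical doc IDs and plain arrays standing in for Lucene's iterators and DocMap:

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class OrdMappingSketch {
  public static void main(String[] args) {
    // Hypothetical inputs: docs 0, 2, 3 hold vectors in the initializer segment,
    // and the merge maps old doc i to new doc docMap[i].
    int[] initializerDocIds = {0, 2, 3};
    int[] docMap = {5, 6, 7, 8};
    // Pass 1: merged doc ID -> old ordinal (mirrors newIdToOldOrdinal above).
    Map<Integer, Integer> newIdToOldOrd = new HashMap<>();
    for (int oldOrd = 0; oldOrd < initializerDocIds.length; oldOrd++) {
      newIdToOldOrd.put(docMap[initializerDocIds[oldOrd]], oldOrd);
    }
    // Pass 2: walk the merged segment's vector docs in order; each doc that came
    // from the initializer gets its old ordinal mapped to its new ordinal.
    int[] mergedDocIds = {1, 5, 7, 8, 9};
    int[] oldToNewOrd = new int[initializerDocIds.length];
    for (int newOrd = 0; newOrd < mergedDocIds.length; newOrd++) {
      Integer oldOrd = newIdToOldOrd.get(mergedDocIds[newOrd]);
      if (oldOrd != null) {
        oldToNewOrd[oldOrd] = newOrd;
      }
    }
    System.out.println(Arrays.toString(oldToNewOrd)); // prints [1, 2, 3]
  }
}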

View File

@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import org.apache.lucene.util.BitSet;
/**
* This creates a graph builder that is initialized with the provided HnswGraph. This is useful for
* merging HnswGraphs from multiple segments.
*
* @lucene.experimental
*/
public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {
/**
* Create a new HnswGraphBuilder that is initialized with the provided HnswGraph.
*
* @param scorerSupplier the scorer to use for vectors
* @param M the number of connections to keep per node
* @param beamWidth the number of nodes to explore in the search
* @param seed the seed for the random number generator
* @param initializerGraph the graph to initialize the new graph builder
* @param newOrdMap a mapping from the old node ordinal to the new node ordinal
* @param initializedNodes a bitset of nodes that are already initialized in the initializerGraph
* @param totalNumberOfVectors the total number of vectors in the new graph, this should include
* all vectors expected to be added to the graph in the future
* @return a new HnswGraphBuilder that is initialized with the provided HnswGraph
* @throws IOException when reading the graph fails
*/
public static InitializedHnswGraphBuilder fromGraph(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
long seed,
HnswGraph initializerGraph,
int[] newOrdMap,
BitSet initializedNodes,
int totalNumberOfVectors)
throws IOException {
OnHeapHnswGraph hnsw = new OnHeapHnswGraph(M, totalNumberOfVectors);
for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
while (it.hasNext()) {
int oldOrd = it.nextInt();
int newOrd = newOrdMap[oldOrd];
hnsw.addNode(level, newOrd);
NeighborArray newNeighbors = hnsw.getNeighbors(level, newOrd);
initializerGraph.seek(level, oldOrd);
for (int oldNeighbor = initializerGraph.nextNeighbor();
oldNeighbor != NO_MORE_DOCS;
oldNeighbor = initializerGraph.nextNeighbor()) {
int newNeighbor = newOrdMap[oldNeighbor];
// we will compute these scores later when we need to pop out the non-diverse nodes
newNeighbors.addOutOfOrder(newNeighbor, Float.NaN);
}
}
}
return new InitializedHnswGraphBuilder(
scorerSupplier, M, beamWidth, seed, hnsw, initializedNodes);
}
private final BitSet initializedNodes;
public InitializedHnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
long seed,
OnHeapHnswGraph initializedGraph,
BitSet initializedNodes)
throws IOException {
super(scorerSupplier, M, beamWidth, seed, initializedGraph);
this.initializedNodes = initializedNodes;
}
@Override
public void addGraphNode(int node) throws IOException {
if (initializedNodes.get(node)) {
return;
}
super.addGraphNode(node);
}
}
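Taken together with the merger, the intended flow is roughly the following sketch (variable names echo the parameters documented above; nothing beyond what this file and the merger define is assumed):

HnswGraphBuilder builder =
    InitializedHnswGraphBuilder.fromGraph(
        scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed,
        initializerGraph, newOrdMap, initializedNodes, totalNumberOfVectors);
// build() visits every ordinal in [0, totalNumberOfVectors); the overridden
// addGraphNode turns ordinals already set in initializedNodes into no-ops, so
// only vectors that were not copied from the initializer graph get inserted.
OnHeapHnswGraph merged = builder.build(totalNumberOfVectors);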

View File

@@ -26,11 +26,9 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
-import java.util.Map;
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
@@ -49,6 +47,7 @@ import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.NumericDocValuesField;
 import org.apache.lucene.document.StoredField;
+import org.apache.lucene.document.StringField;
 import org.apache.lucene.index.ByteVectorValues;
 import org.apache.lucene.index.CodecReader;
 import org.apache.lucene.index.DirectoryReader;
@@ -59,8 +58,10 @@ import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.StoredFields;
+import org.apache.lucene.index.Term;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.KnnCollector;
 import org.apache.lucene.search.Query;
@@ -132,6 +133,73 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     };
   }

+  // Tests writing segments of various sizes and merging to ensure there are no errors
+  // in the HNSW graph merging logic.
+  public void testRandomReadWriteAndMerge() throws IOException {
+    int dim = random().nextInt(100) + 1;
+    int[] segmentSizes =
+        new int[] {random().nextInt(20) + 1, random().nextInt(10) + 30, random().nextInt(10) + 20};
+    // Randomly delete vector documents
+    boolean[] addDeletes =
+        new boolean[] {random().nextBoolean(), random().nextBoolean(), random().nextBoolean()};
+    // randomly index other documents besides vector docs
+    boolean[] isSparse =
+        new boolean[] {random().nextBoolean(), random().nextBoolean(), random().nextBoolean()};
+    int numVectors = segmentSizes[0] + segmentSizes[1] + segmentSizes[2];
+    int M = random().nextInt(4) + 2;
+    int beamWidth = random().nextInt(10) + 5;
+    long seed = random().nextLong();
+    AbstractMockVectorValues<T> vectors = vectorValues(numVectors, dim);
+    HnswGraphBuilder.randSeed = seed;
+    try (Directory dir = newDirectory()) {
+      IndexWriterConfig iwc =
+          new IndexWriterConfig()
+              .setCodec(
+                  new Lucene95Codec() {
+                    @Override
+                    public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+                      return new Lucene95HnswVectorsFormat(M, beamWidth);
+                    }
+                  })
+              // set a random merge policy
+              .setMergePolicy(newMergePolicy(random()));
+      try (IndexWriter iw = new IndexWriter(dir, iwc)) {
+        for (int i = 0; i < segmentSizes.length; i++) {
+          int size = segmentSizes[i];
+          while (vectors.nextDoc() < size) {
+            if (isSparse[i] && random().nextBoolean()) {
+              int d = random().nextInt(10) + 1;
+              for (int j = 0; j < d; j++) {
+                Document doc = new Document();
+                iw.addDocument(doc);
+              }
+            }
+            Document doc = new Document();
+            doc.add(knnVectorField("field", vectors.vectorValue(), similarityFunction));
+            doc.add(new StringField("id", Integer.toString(vectors.docID()), Field.Store.NO));
+            iw.addDocument(doc);
+          }
+          iw.commit();
+          if (addDeletes[i] && size > 1) {
+            for (int d = 0; d < size; d += random().nextInt(5) + 1) {
+              iw.deleteDocuments(new Term("id", Integer.toString(d)));
+            }
+            iw.commit();
+          }
+        }
+        iw.commit();
+        iw.forceMerge(1);
+      }
+      try (IndexReader reader = DirectoryReader.open(dir)) {
+        for (LeafReaderContext ctx : reader.leaves()) {
+          AbstractMockVectorValues<T> values = vectorValues(ctx.reader(), "field");
+          assertEquals(dim, values.dimension());
+        }
+      }
+    }
+  }
+
   // test writing out and reading in a graph gives the expected graph
   public void testReadWrite() throws IOException {
     int dim = random().nextInt(100) + 1;
@@ -469,20 +537,21 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
     AbstractMockVectorValues<T> finalVectorValues =
         vectorValues(totalSize, dim, initializerVectors, docIdOffset);
-    Map<Integer, Integer> initializerOrdMap =
+    int[] initializerOrdMap =
         createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);

     RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
     HnswGraphBuilder finalBuilder =
-        HnswGraphBuilder.create(
+        InitializedHnswGraphBuilder.fromGraph(
             finalscorerSupplier,
             10,
             30,
             seed,
             initializerGraph,
             initializerOrdMap,
-            finalVectorValues.size());
+            BitSet.of(
+                DocIdSetIterator.range(docIdOffset, initializerSize + docIdOffset), totalSize + 1),
+            totalSize);

     // When offset is 0, the graphs should be identical before vectors are added
     assertGraphEqual(initializerGraph, finalBuilder.getGraph());
@@ -506,19 +575,21 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
     AbstractMockVectorValues<T> finalVectorValues =
         vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset);
-    Map<Integer, Integer> initializerOrdMap =
-        createOffsetOrdinalMap(initializerSize, finalVectorValues.copy(), docIdOffset);
+    int[] initializerOrdMap =
+        createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);

     RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
     HnswGraphBuilder finalBuilder =
-        HnswGraphBuilder.create(
+        InitializedHnswGraphBuilder.fromGraph(
             finalscorerSupplier,
             10,
             30,
             seed,
             initializerGraph,
             initializerOrdMap,
-            finalVectorValues.size());
+            BitSet.of(
+                DocIdSetIterator.range(docIdOffset, initializerSize + docIdOffset), totalSize + 1),
+            totalSize);

     assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap);
@@ -528,19 +599,19 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
   }

-  private void assertGraphContainsGraph(
-      HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
-    for (int i = 0; i < h.numLevels(); i++) {
+  private void assertGraphContainsGraph(HnswGraph g, HnswGraph initializer, int[] newOrdinals)
+      throws IOException {
+    for (int i = 0; i < initializer.numLevels(); i++) {
       int[] finalGraphNodesOnLevel = nodesIteratorToArray(g.getNodesOnLevel(i));
       int[] initializerGraphNodesOnLevel =
-          mapArrayAndSort(nodesIteratorToArray(h.getNodesOnLevel(i)), oldToNewOrdMap);
+          mapArrayAndSort(nodesIteratorToArray(initializer.getNodesOnLevel(i)), newOrdinals);
       int overlap = computeOverlap(finalGraphNodesOnLevel, initializerGraphNodesOnLevel);
       assertEquals(initializerGraphNodesOnLevel.length, overlap);
     }
   }

   private void assertGraphInitializedFromGraph(
-      HnswGraph g, HnswGraph initializer, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
+      HnswGraph g, HnswGraph initializer, int[] newOrdinals) throws IOException {
     assertEquals(
         "the number of levels in the graphs are different!",
         initializer.numLevels(),
@@ -550,44 +621,23 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     // new ordinal mapping
     assertEquals("the number of nodes in the graphs are different!", initializer.size(), g.size());

-    // assert that all the node from initializer graph can be found in the new graph and
-    // the neighbors from the old graph are successfully transferred to the new graph
+    // assert that the neighbors from the old graph are successfully transferred to the new graph
     for (int level = 0; level < g.numLevels(); level++) {
       NodesIterator nodesOnLevel = initializer.getNodesOnLevel(level);
       while (nodesOnLevel.hasNext()) {
         int node = nodesOnLevel.nextInt();
-        g.seek(level, oldToNewOrdMap.get(node));
+        g.seek(level, newOrdinals[node]);
         initializer.seek(level, node);
         assertEquals(
             "arcs differ for node " + node,
             getNeighborNodes(g),
             getNeighborNodes(initializer).stream()
-                .map(oldToNewOrdMap::get)
+                .map(n -> newOrdinals[n])
                 .collect(Collectors.toSet()));
       }
     }
   }

-  private Map<Integer, Integer> createOffsetOrdinalMap(
-      int docIdSize, AbstractMockVectorValues<T> totalVectorValues, int docIdOffset) {
-    // Compute the offset for the ordinal map to be the number of non-null vectors in the total
-    // vector values
-    // before the docIdOffset
-    int ordinalOffset = 0;
-    while (totalVectorValues.nextDoc() < docIdOffset) {
-      ordinalOffset++;
-    }
-    Map<Integer, Integer> offsetOrdinalMap = new HashMap<>();
-
-    for (int curr = 0;
-        totalVectorValues.docID() < docIdOffset + docIdSize;
-        totalVectorValues.nextDoc()) {
-      offsetOrdinalMap.put(curr, ordinalOffset + curr++);
-    }
-
-    return offsetOrdinalMap;
-  }
-
   private int[] nodesIteratorToArray(NodesIterator nodesIterator) {
     int[] arr = new int[nodesIterator.size()];
     int i = 0;
@@ -597,15 +647,35 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     return arr;
   }

-  private int[] mapArrayAndSort(int[] arr, Map<Integer, Integer> map) {
+  private int[] mapArrayAndSort(int[] arr, int[] offset) {
     int[] mappedA = new int[arr.length];
     for (int i = 0; i < arr.length; i++) {
-      mappedA[i] = map.get(arr[i]);
+      mappedA[i] = offset[arr[i]];
     }
     Arrays.sort(mappedA);
     return mappedA;
   }

+  private int[] createOffsetOrdinalMap(
+      int docIdSize, AbstractMockVectorValues<T> totalVectorValues, int docIdOffset) {
+    // Compute the offset for the ordinal map to be the number of non-null vectors in the total
+    // vector values
+    // before the docIdOffset
+    int ordinalOffset = 0;
+    while (totalVectorValues.nextDoc() < docIdOffset) {
+      ordinalOffset++;
+    }
+    int[] offsetOrdinalMap = new int[docIdSize];
+
+    for (int curr = 0;
+        totalVectorValues.docID() < docIdOffset + docIdSize;
+        totalVectorValues.nextDoc()) {
+      offsetOrdinalMap[curr] = ordinalOffset + curr++;
+    }
+
+    return offsetOrdinalMap;
+  }
+
   @SuppressWarnings("unchecked")
   public void testVisitedLimit() throws IOException {
     int nDoc = 500;
@@ -1083,7 +1153,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
   static float[][] createRandomFloatVectors(int size, int dimension, Random random) {
     float[][] vectors = new float[size][];
-    for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
+    for (int offset = 0; offset < size; offset++) {
       vectors[offset] = randomVector(random, dimension);
     }
     return vectors;
@@ -1091,7 +1161,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
   static byte[][] createRandomByteVectors(int size, int dimension, Random random) {
     byte[][] vectors = new byte[size][];
-    for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
+    for (int offset = 0; offset < size; offset++) {
       vectors[offset] = randomVector8(random, dimension);
     }
     return vectors;

View File

@@ -24,7 +24,12 @@ class MockByteVectorValues extends AbstractMockVectorValues<byte[]> {
   private final byte[] scratch;

   static MockByteVectorValues fromValues(byte[][] values) {
-    int dimension = values[0].length;
+    byte[] firstNonNull = null;
+    int j = 0;
+    while (firstNonNull == null && j < values.length) {
+      firstNonNull = values[j++];
+    }
+    int dimension = firstNonNull.length;
     int maxDoc = values.length;
     byte[][] denseValues = new byte[maxDoc][];
     int count = 0;

View File

@@ -24,7 +24,12 @@ class MockVectorValues extends AbstractMockVectorValues<float[]> {
   private final float[] scratch;

   static MockVectorValues fromValues(float[][] values) {
-    int dimension = values[0].length;
+    float[] firstNonNull = null;
+    int j = 0;
+    while (firstNonNull == null && j < values.length) {
+      firstNonNull = values[j++];
+    }
+    int dimension = firstNonNull.length;
     int maxDoc = values.length;
     float[][] denseValues = new float[maxDoc][];
     int count = 0;