mirror of https://github.com/apache/lucene.git
Optimize OnHeapHnswGraph's data structure (#12651)
make the internal graph representation a 2d array
parent 2482f7688b
commit a1cf22e6a9
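The core idea of the change: instead of keeping level 0 in a List and every upper level in its own Map, OnHeapHnswGraph now stores all neighbour lists in one ragged two-dimensional array indexed first by node id and then by level. Below is a minimal illustrative sketch of that layout, not the Lucene class itself; the class name TwoDimGraphSketch is made up and int[] stands in for Lucene's NeighborArray.

// Illustrative sketch only (hypothetical class): shows the [node][level] layout
// that this commit introduces in OnHeapHnswGraph.
class TwoDimGraphSketch {
  // first dimension = node id, second dimension = level,
  // so graph[1][2] holds the neighbours of node 1 at level 2
  private int[][][] graph;

  TwoDimGraphSketch(int numNodes) {
    graph = new int[numNodes][][];
  }

  void addNode(int level, int node) {
    if (graph[node] == null) {
      // a node is always inserted starting from its top level,
      // so level + 1 slots cover all of its levels
      graph[node] = new int[level + 1][];
    }
    graph[node][level] = new int[0]; // empty neighbour list for now
  }

  int[] neighbors(int level, int node) {
    // two array dereferences replace the old per-level map lookup
    return graph[node][level];
  }
}

Compared with the previous List/HashMap mix, this keeps neighbour lookups allocation-free and makes memory accounting simpler, at the cost of needing either an upfront size or occasional array growth, as the hunks below show.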
@@ -195,6 +195,8 @@ Optimizations
 * GITHUB#12668: ImpactsEnums now decode frequencies lazily like PostingsEnums.
   (Adrien Grand)
 
+* GITHUB#12651: Use 2d array for OnHeapHnswGraph representation. (Patrick Zhai)
+
 Changes in runtime behavior
 ---------------------
 
@@ -56,7 +56,11 @@ public class Word2VecSynonymProvider {
         RandomVectorScorerSupplier.createFloats(word2VecModel, SIMILARITY_FUNCTION);
     HnswGraphBuilder builder =
         HnswGraphBuilder.create(
-            scorerSupplier, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, HnswGraphBuilder.randSeed);
+            scorerSupplier,
+            DEFAULT_MAX_CONN,
+            DEFAULT_BEAM_WIDTH,
+            HnswGraphBuilder.randSeed,
+            word2VecModel.size());
     this.hnswGraph = builder.build(word2VecModel.size());
   }
 
@@ -438,7 +438,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
             RandomVectorScorerSupplier.createBytes(
                 vectorValues, fieldInfo.getVectorSimilarityFunction());
         HnswGraphBuilder hnswGraphBuilder =
-            createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex);
+            createHnswGraphBuilder(
+                mergeState,
+                fieldInfo,
+                scorerSupplier,
+                initializerIndex,
+                vectorValues.size());
         hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
         yield hnswGraphBuilder.build(vectorValues.size());
       }
@@ -453,7 +458,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
             RandomVectorScorerSupplier.createFloats(
                 vectorValues, fieldInfo.getVectorSimilarityFunction());
         HnswGraphBuilder hnswGraphBuilder =
-            createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex);
+            createHnswGraphBuilder(
+                mergeState,
+                fieldInfo,
+                scorerSupplier,
+                initializerIndex,
+                vectorValues.size());
         hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
         yield hnswGraphBuilder.build(vectorValues.size());
       }
@@ -488,10 +498,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
       MergeState mergeState,
       FieldInfo fieldInfo,
       RandomVectorScorerSupplier scorerSupplier,
-      int initializerIndex)
+      int initializerIndex,
+      int graphSize)
       throws IOException {
     if (initializerIndex == -1) {
-      return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
+      return HnswGraphBuilder.create(
+          scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, graphSize);
     }
 
     HnswGraph initializerGraph =
@@ -499,7 +511,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
     Map<Integer, Integer> ordinalMapper =
         getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
     return HnswGraphBuilder.create(
-        scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, initializerGraph, ordinalMapper);
+        scorerSupplier,
+        M,
+        beamWidth,
+        HnswGraphBuilder.randSeed,
+        initializerGraph,
+        ordinalMapper,
+        graphSize);
   }
 
   private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)
@@ -66,6 +66,11 @@ public abstract class HnswGraph {
   /** Returns the number of nodes in the graph */
   public abstract int size();
 
+  /** Returns max node id, inclusive, normally this value will be size - 1 */
+  public int maxNodeId() {
+    return size() - 1;
+  }
+
   /**
    * Iterates over the neighbor list. It is illegal to call this method after it returns
    * NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
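Because nodes may now be added out of order (for example while initializing a merged graph from an existing one), size() and maxNodeId() can disagree until the build is complete. A hedged usage sketch of the distinction, using only the addNode/size/maxNodeId semantics shown in this commit (the snippet assumes package-private access to OnHeapHnswGraph, as the new TestOnHeapHnswGraph test has):

// Sketch: size() counts inserted nodes, maxNodeId() tracks the largest ordinal seen.
OnHeapHnswGraph graph = new OnHeapHnswGraph(16, -1); // M = 16, unbounded size
graph.addNode(0, 5);            // insert node 5 on level 0, out of order
assert graph.size() == 1;       // only one node has been added so far
assert graph.maxNodeId() == 5;  // but ordinals up to 5 are already claimed
// HnswGraphSearcher therefore sizes its visited bitset from maxNodeId() + 1, not size().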
@@ -76,7 +76,12 @@ public final class HnswGraphBuilder {
   public static HnswGraphBuilder create(
       RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
       throws IOException {
-    return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed);
+    return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, -1);
+  }
+
+  public static HnswGraphBuilder create(
+      RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
+    return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
   }
 
   public static HnswGraphBuilder create(
@@ -85,9 +90,11 @@ public final class HnswGraphBuilder {
       int beamWidth,
       long seed,
       HnswGraph initializerGraph,
-      Map<Integer, Integer> oldToNewOrdinalMap)
+      Map<Integer, Integer> oldToNewOrdinalMap,
+      int graphSize)
       throws IOException {
-    HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed);
+    HnswGraphBuilder hnswGraphBuilder =
+        new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
     hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
     return hnswGraphBuilder;
   }
@@ -102,10 +109,10 @@ public final class HnswGraphBuilder {
    * @param beamWidth the size of the beam search to use when finding nearest neighbors.
    * @param seed the seed for a random number generator used during graph construction. Provide this
    *     to ensure repeatable construction.
+   * @param graphSize size of graph, if unknown, pass in -1
    */
   private HnswGraphBuilder(
-      RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
-      throws IOException {
+      RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
     if (M <= 0) {
       throw new IllegalArgumentException("maxConn must be positive");
     }
@@ -118,7 +125,7 @@ public final class HnswGraphBuilder {
     // normalization factor for level generation; currently not configurable
     this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
     this.random = new SplittableRandom(seed);
-    this.hnsw = new OnHeapHnswGraph(M);
+    this.hnsw = new OnHeapHnswGraph(M, graphSize);
     this.graphSearcher =
         new HnswGraphSearcher(
             new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
@@ -155,7 +162,7 @@ public final class HnswGraphBuilder {
   private void initializeFromGraph(
       HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
     assert hnsw.size() == 0;
-    for (int level = 0; level < initializerGraph.numLevels(); level++) {
+    for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
       HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
 
       while (it.hasNext()) {
@@ -288,7 +295,7 @@ public final class HnswGraphBuilder {
       // only adding it if it is closer to the target than to any of the other selected neighbors
       int cNode = candidates.node[i];
       float cScore = candidates.score[i];
-      assert cNode < hnsw.size();
+      assert cNode <= hnsw.maxNodeId();
       if (diversityCheck(cNode, cScore, neighbors)) {
         neighbors.addInOrder(cNode, cScore);
       }
@@ -66,7 +66,7 @@ public class HnswGraphSearcher {
       throws IOException {
     HnswGraphSearcher graphSearcher =
         new HnswGraphSearcher(
-            new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(graph.size()));
+            new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph)));
     search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
   }
 
@@ -88,7 +88,7 @@ public class HnswGraphSearcher {
     KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit);
     OnHeapHnswGraphSearcher graphSearcher =
         new OnHeapHnswGraphSearcher(
-            new NeighborQueue(topK, true), new SparseFixedBitSet(graph.size()));
+            new NeighborQueue(topK, true), new SparseFixedBitSet(getGraphSize(graph)));
     search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
     return knnCollector;
   }
@@ -150,9 +150,9 @@ public class HnswGraphSearcher {
    */
   private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit)
       throws IOException {
-    int size = graph.size();
+    int size = getGraphSize(graph);
     int visitedCount = 1;
-    prepareScratchState(graph.size());
+    prepareScratchState(size);
     int currentEp = graph.entryNode();
     float currentScore = scorer.score(currentEp);
     boolean foundBetter;
@@ -201,8 +201,9 @@ public class HnswGraphSearcher {
       Bits acceptOrds)
       throws IOException {
 
-    int size = graph.size();
-    prepareScratchState(graph.size());
+    int size = getGraphSize(graph);
+
+    prepareScratchState(size);
 
     for (int ep : eps) {
       if (visited.getAndSet(ep) == false) {
@@ -284,6 +285,10 @@ public class HnswGraphSearcher {
     return graph.nextNeighbor();
   }
 
+  private static int getGraphSize(HnswGraph graph) {
+    return graph.maxNodeId() + 1;
+  }
+
   /**
    * This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner by avoiding
    * the unsafe methods (seek and nextNeighbor, which maintain state in the graph object) and
@@ -20,10 +20,9 @@ package org.apache.lucene.util.hnsw;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import org.apache.lucene.util.Accountable;
+import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.RamUsageEstimator;
 
 /**
@@ -32,39 +31,56 @@ import org.apache.lucene.util.RamUsageEstimator;
  */
 public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
 
+  private static final int INIT_SIZE = 128;
+
   private int numLevels; // the current number of levels in the graph
   private int entryNode; // the current graph entry node on the top level. -1 if not set
 
-  // Level 0 is represented as List<NeighborArray> – nodes' connections on level 0.
-  // Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond
-  // to vectors
-  // added to HnswBuilder, and the node values are the ordinals of those vectors.
-  // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
-  private final List<NeighborArray> graphLevel0;
-  // Represents levels 1-N. Each level is represented with a Map that maps a levels level 0
-  // ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain
-  // it in this list. However, to avoid changing list indexing, we always will make the first
-  // element
-  // null.
-  private final List<Map<Integer, NeighborArray>> graphUpperLevels;
-  private final int nsize;
-  private final int nsize0;
+  // the internal graph representation where the first dimension is node id and second dimension is
+  // level
+  // e.g. graph[1][2] is all the neighbours of node 1 at level 2
+  private NeighborArray[][] graph;
+  // essentially another 2d map which the first dimension is level and second dimension is node id,
+  // this is only
+  // generated on demand when there's someone calling getNodeOnLevel on a non-zero level
+  private List<Integer>[] levelToNodes;
+  private int
+      lastFreezeSize; // remember the size we are at last time to freeze the graph and generate
+  // levelToNodes
+  private int size; // graph size, which is number of nodes in level 0
+  private int
+      nonZeroLevelSize; // total number of NeighborArrays created that is not on level 0, for now it
+  // is only used to account memory usage
+  private int maxNodeId;
+  private final int nsize; // neighbour array size at non-zero level
+  private final int nsize0; // neighbour array size at zero level
+  private final boolean
+      noGrowth; // if an initial size is passed in, we don't expect the graph to grow itself
 
   // KnnGraphValues iterator members
   private int upto;
   private NeighborArray cur;
 
-  OnHeapHnswGraph(int M) {
+  /**
+   * ctor
+   *
+   * @param numNodes number of nodes that will be added to this graph, passing in -1 means unbounded
+   *     while passing in a non-negative value will lock the whole graph and disable the graph from
+   *     growing itself (you cannot add a node with has id >= numNodes)
+   */
+  OnHeapHnswGraph(int M, int numNodes) {
     this.numLevels = 1; // Implicitly start the graph with a single level
-    this.graphLevel0 = new ArrayList<>();
     this.entryNode = -1; // Entry node should be negative until a node is added
     // Neighbours' size on upper levels (nsize) and level 0 (nsize0)
     // We allocate extra space for neighbours, but then prune them to keep allowed maximum
+    this.maxNodeId = -1;
     this.nsize = M + 1;
     this.nsize0 = (M * 2 + 1);
-
-    this.graphUpperLevels = new ArrayList<>(numLevels);
-    graphUpperLevels.add(null); // we don't need this for 0th level, as it contains all nodes
+    noGrowth = numNodes != -1;
+    if (noGrowth == false) {
+      numNodes = INIT_SIZE;
+    }
+    this.graph = new NeighborArray[numNodes][];
   }
 
   /**
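With the array-backed representation, the constructor can pre-allocate when the final node count is known and then refuses to grow, while passing -1 keeps grow-as-needed behaviour starting from INIT_SIZE. A small sketch of the two modes, mirroring the constructor above and the testNoGrowth/testGraphGrowth cases added later in this commit (again assuming package-private access):

// Pre-sized graph: slots for node ids 0..99 are allocated up front.
OnHeapHnswGraph fixed = new OnHeapHnswGraph(10, 100);
fixed.addNode(0, 99);       // fine, id 99 is within the declared capacity
// fixed.addNode(0, 100);   // would throw IllegalStateException: no growth allowed

// Unbounded graph: starts with INIT_SIZE (128) slots and grows via ArrayUtil.grow.
OnHeapHnswGraph growable = new OnHeapHnswGraph(10, -1);
growable.addNode(0, 1000);  // backing array is grown to cover node id 1000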
@@ -74,23 +90,32 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
    * @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
    */
   public NeighborArray getNeighbors(int level, int node) {
-    if (level == 0) {
-      return graphLevel0.get(node);
-    }
-    Map<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
-    assert levelMap.containsKey(node);
-    return levelMap.get(node);
+    assert graph[node][level] != null;
+    return graph[node][level];
   }
 
   @Override
   public int size() {
-    return graphLevel0.size(); // all nodes are located on the 0th level
+    return size;
+  }
+
+  /**
+   * When we initialize from another graph, the max node id is different from {@link #size()},
+   * because we will add nodes out of order, such that we need two method for each
+   *
+   * @return max node id (inclusive)
+   */
+  @Override
+  public int maxNodeId() {
+    return maxNodeId;
   }
 
   /**
    * Add node on the given level. Nodes can be inserted out of order, but it requires that the nodes
    * preceded by the node inserted out of order are eventually added.
    *
+   * <p>NOTE: You must add a node starting from the node's top level
+   *
    * @param level level to add a node on
    * @param node the node to add, represented as an ordinal on the level 0.
    */
@@ -99,28 +124,33 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
       entryNode = node;
     }
 
-    if (level > 0) {
-      // if the new node introduces a new level, add more levels to the graph,
-      // and make this node the graph's new entry point
-      if (level >= numLevels) {
-        for (int i = numLevels; i <= level; i++) {
-          graphUpperLevels.add(new HashMap<>());
-        }
-        numLevels = level + 1;
-        entryNode = node;
-      }
-
-      graphUpperLevels.get(level).put(node, new NeighborArray(nsize, true));
-    } else {
-      // Add nodes all the way up to and including "node" in the new graph on level 0. This will
-      // cause the size of the
-      // graph to differ from the number of nodes added to the graph. The size of the graph and the
-      // number of nodes
-      // added will only be in sync once all nodes from 0...last_node are added into the graph.
-      while (node >= graphLevel0.size()) {
-        graphLevel0.add(new NeighborArray(nsize0, true));
-      }
-    }
+    if (node >= graph.length) {
+      if (noGrowth) {
+        throw new IllegalStateException(
+            "The graph does not expect to grow when an initial size is given");
+      }
+      graph = ArrayUtil.grow(graph, node + 1);
+    }
+
+    if (level >= numLevels) {
+      numLevels = level + 1;
+      entryNode = node;
+    }
+
+    assert graph[node] == null || graph[node].length > level
+        : "node must be inserted from the top level";
+    if (graph[node] == null) {
+      graph[node] =
+          new NeighborArray[level + 1]; // assumption: we always call this function from top level
+      size++;
+    }
+    if (level == 0) {
+      graph[node][level] = new NeighborArray(nsize0, true);
+    } else {
+      graph[node][level] = new NeighborArray(nsize, true);
+      nonZeroLevelSize++;
+    }
+    maxNodeId = Math.max(maxNodeId, node);
   }
 
   @Override
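The new addNode contract above requires each node to be registered starting from its top level, so the per-node array can be sized once. A small hedged sketch of valid and invalid insertion orders (assuming package-private access, as in the tests):

// Sketch: each node must be registered from its top level downwards.
OnHeapHnswGraph g = new OnHeapHnswGraph(10, -1);
g.addNode(2, 7);   // node 7's top level is 2: a NeighborArray[3] is allocated for it
g.addNode(1, 7);   // lower levels of the same node can follow
g.addNode(0, 7);
// g.addNode(1, 8); g.addNode(2, 8);  // wrong order: trips the
//                                    // "node must be inserted from the top level" assertion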
@@ -158,50 +188,83 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
     return entryNode;
   }
 
+  /**
+   * WARN: calling this method will essentially iterate through all nodes at level 0 (even if you're
+   * not getting node at level 0), we have built some caching mechanism such that if graph is not
+   * changed only the first non-zero level call will pay the cost. So it is highly NOT recommended
+   * to call this method while the graph is still building.
+   *
+   * <p>NOTE: calling this method while the graph is still building is prohibited
+   */
   @Override
   public NodesIterator getNodesOnLevel(int level) {
+    if (size() != maxNodeId() + 1) {
+      throw new IllegalStateException(
+          "graph build not complete, size=" + size() + " maxNodeId=" + maxNodeId());
+    }
     if (level == 0) {
       return new ArrayNodesIterator(size());
     } else {
-      return new CollectionNodesIterator(graphUpperLevels.get(level).keySet());
+      generateLevelToNodes();
+      return new CollectionNodesIterator(levelToNodes[level]);
     }
   }
 
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  private void generateLevelToNodes() {
+    if (lastFreezeSize == size) {
+      return;
+    }
+
+    levelToNodes = new List[numLevels];
+    for (int i = 1; i < numLevels; i++) {
+      levelToNodes[i] = new ArrayList<>();
+    }
+    int nonNullNode = 0;
+    for (int node = 0; node < graph.length; node++) {
+      // when we init from another graph, we could have holes where some slot is null
+      if (graph[node] == null) {
+        continue;
+      }
+      nonNullNode++;
+      for (int i = 1; i < graph[node].length; i++) {
+        levelToNodes[i].add(node);
+      }
+      if (nonNullNode == size) {
+        break;
+      }
+    }
+    lastFreezeSize = size;
+  }
+
   @Override
   public long ramBytesUsed() {
     long neighborArrayBytes0 =
-        nsize0 * (Integer.BYTES + Float.BYTES)
+        (long) nsize0 * (Integer.BYTES + Float.BYTES)
             + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
-            + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
+            + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
             + Integer.BYTES * 3;
     long neighborArrayBytes =
-        nsize * (Integer.BYTES + Float.BYTES)
+        (long) nsize * (Integer.BYTES + Float.BYTES)
             + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
-            + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
+            + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
             + Integer.BYTES * 3;
     long total = 0;
-    for (int l = 0; l < numLevels; l++) {
-      if (l == 0) {
-        total +=
-            graphLevel0.size() * neighborArrayBytes0
-                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
-      } else {
-        long numNodesOnLevel = graphUpperLevels.get(l).size();
-
-        // For levels > 0, we represent the graph structure with a tree map.
-        // A single node in the tree contains 3 references (left root, right root, value) as well
-        // as an Integer for the key and 1 extra byte for the color of the node (this is actually 1
-        // bit, but
-        // because we do not have that granularity, we set to 1 byte). In addition, we include 1
-        // more reference for
-        // the tree map itself.
-        total +=
-            numNodesOnLevel * (3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + Integer.BYTES + 1)
-                + RamUsageEstimator.NUM_BYTES_OBJECT_REF;
-
-        // Add the size neighbor of each node
-        total += numNodesOnLevel * neighborArrayBytes;
-      }
-    }
+    total +=
+        size * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // for graph and level 0;
+    total += nonZeroLevelSize * neighborArrayBytes; // for non-zero level
+    total += 8 * Integer.BYTES; // all int fields
+    total += RamUsageEstimator.NUM_BYTES_OBJECT_REF; // field: cur
+    total += RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // field: levelToNodes
+    if (levelToNodes != null) {
+      total +=
+          (long) (numLevels - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0
+      total +=
+          (long) nonZeroLevelSize
+              * (RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+                  + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+                  + Integer.BYTES);
+    }
     return total;
   }
@@ -195,7 +195,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     }
   }
 
-  // test that sorted index returns the same search results are unsorted
+  // test that sorted index returns the same search results as unsorted
   public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
     int dim = random().nextInt(10) + 3;
     int nDoc = random().nextInt(200) + 100;
@@ -454,77 +454,6 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     }
   }
 
-  public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException {
-    int maxNumLevels = randomIntBetween(2, 10);
-    int nodeCount = randomIntBetween(1, 100);
-
-    List<List<Integer>> nodesPerLevel = new ArrayList<>();
-    for (int i = 0; i < maxNumLevels; i++) {
-      nodesPerLevel.add(new ArrayList<>());
-    }
-
-    int numLevels = 0;
-    for (int currNode = 0; currNode < nodeCount; currNode++) {
-      int nodeMaxLevel = random().nextInt(1, maxNumLevels + 1);
-      numLevels = Math.max(numLevels, nodeMaxLevel);
-      for (int currLevel = 0; currLevel < nodeMaxLevel; currLevel++) {
-        nodesPerLevel.get(currLevel).add(currNode);
-      }
-    }
-
-    OnHeapHnswGraph topDownOrderReversedHnsw = new OnHeapHnswGraph(10);
-    for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
-      List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
-      int currLevelNodesSize = currLevelNodes.size();
-      for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
-        topDownOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
-      }
-    }
-
-    OnHeapHnswGraph bottomUpOrderReversedHnsw = new OnHeapHnswGraph(10);
-    for (int currLevel = 0; currLevel < numLevels; currLevel++) {
-      List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
-      int currLevelNodesSize = currLevelNodes.size();
-      for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
-        bottomUpOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
-      }
-    }
-
-    OnHeapHnswGraph topDownOrderRandomHnsw = new OnHeapHnswGraph(10);
-    for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
-      List<Integer> currLevelNodes = new ArrayList<>(nodesPerLevel.get(currLevel));
-      Collections.shuffle(currLevelNodes, random());
-      for (Integer currNode : currLevelNodes) {
-        topDownOrderRandomHnsw.addNode(currLevel, currNode);
-      }
-    }
-
-    OnHeapHnswGraph bottomUpExpectedHnsw = new OnHeapHnswGraph(10);
-    for (int currLevel = 0; currLevel < numLevels; currLevel++) {
-      for (Integer currNode : nodesPerLevel.get(currLevel)) {
-        bottomUpExpectedHnsw.addNode(currLevel, currNode);
-      }
-    }
-
-    assertEquals(nodeCount, bottomUpExpectedHnsw.getNodesOnLevel(0).size());
-    for (Integer node : nodesPerLevel.get(0)) {
-      assertEquals(0, bottomUpExpectedHnsw.getNeighbors(0, node).size());
-    }
-
-    for (int currLevel = 1; currLevel < numLevels; currLevel++) {
-      List<Integer> expectedNodesOnLevel = nodesPerLevel.get(currLevel);
-      List<Integer> sortedNodes = sortedNodesOnLevel(bottomUpExpectedHnsw, currLevel);
-      assertEquals(
-          String.format(Locale.ROOT, "Nodes on level %d do not match", currLevel),
-          expectedNodesOnLevel,
-          sortedNodes);
-    }
-
-    assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw);
-    assertGraphEqual(bottomUpExpectedHnsw, bottomUpOrderReversedHnsw);
-    assertGraphEqual(bottomUpExpectedHnsw, topDownOrderRandomHnsw);
-  }
-
   public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws IOException {
     int totalSize = atLeast(100);
     int initializerSize = random().nextInt(5, totalSize);
@@ -547,7 +476,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
     HnswGraphBuilder finalBuilder =
         HnswGraphBuilder.create(
-            finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap);
+            finalscorerSupplier,
+            10,
+            30,
+            seed,
+            initializerGraph,
+            initializerOrdMap,
+            finalVectorValues.size());
 
     // When offset is 0, the graphs should be identical before vectors are added
     assertGraphEqual(initializerGraph, finalBuilder.getGraph());
@@ -577,7 +512,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
     HnswGraphBuilder finalBuilder =
         HnswGraphBuilder.create(
-            finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap);
+            finalscorerSupplier,
+            10,
+            30,
+            seed,
+            initializerGraph,
+            initializerOrdMap,
+            finalVectorValues.size());
 
     assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap);
 
@@ -599,35 +540,30 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
   }
 
   private void assertGraphInitializedFromGraph(
-      HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
-    assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
+      HnswGraph g, HnswGraph initializer, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
+    assertEquals(
+        "the number of levels in the graphs are different!",
+        initializer.numLevels(),
+        g.numLevels());
     // Confirm that the size of the new graph includes all nodes up to an including the max new
     // ordinal in the old to
     // new ordinal mapping
-    assertEquals(
-        "the number of nodes in the graphs are different!",
-        g.size(),
-        Collections.max(oldToNewOrdMap.values()) + 1);
+    assertEquals("the number of nodes in the graphs are different!", initializer.size(), g.size());
 
-    // assert the nodes from the previous graph are successfully to levels > 0 in the new graph
-    for (int level = 1; level < g.numLevels(); level++) {
-      List<Integer> nodesOnLevel = sortedNodesOnLevel(g, level);
-      List<Integer> nodesOnLevel2 =
-          sortedNodesOnLevel(h, level).stream().map(oldToNewOrdMap::get).toList();
-      assertEquals(nodesOnLevel, nodesOnLevel2);
-    }
-
-    // assert that the neighbors from the old graph are successfully transferred to the new graph
+    // assert that all the node from initializer graph can be found in the new graph and
+    // the neighbors from the old graph are successfully transferred to the new graph
     for (int level = 0; level < g.numLevels(); level++) {
-      NodesIterator nodesOnLevel = h.getNodesOnLevel(level);
+      NodesIterator nodesOnLevel = initializer.getNodesOnLevel(level);
       while (nodesOnLevel.hasNext()) {
         int node = nodesOnLevel.nextInt();
         g.seek(level, oldToNewOrdMap.get(node));
-        h.seek(level, node);
+        initializer.seek(level, node);
         assertEquals(
             "arcs differ for node " + node,
             getNeighborNodes(g),
-            getNeighborNodes(h).stream().map(oldToNewOrdMap::get).collect(Collectors.toSet()));
+            getNeighborNodes(initializer).stream()
+                .map(oldToNewOrdMap::get)
+                .collect(Collectors.toSet()));
       }
     }
   }
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.lucene.tests.util.LuceneTestCase;
+
+/**
+ * Test OnHeapHnswGraph's behavior specifically, for more complex test, see {@link
+ * HnswGraphTestCase}
+ */
+public class TestOnHeapHnswGraph extends LuceneTestCase {
+
+  /* assert exception will be thrown when we add out of bound node to a fixed size graph */
+  public void testNoGrowth() {
+    OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 100);
+    expectThrows(IllegalStateException.class, () -> graph.addNode(1, 100));
+  }
+
+  /* AssertionError will be thrown if we add a node not from top most level,
+  (likely NPE will be thrown in prod) */
+  public void testAddLevelOutOfOrder() {
+    OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
+    graph.addNode(0, 0);
+    expectThrows(AssertionError.class, () -> graph.addNode(1, 0));
+  }
+
+  /* assert exception will be thrown when we call getNodeOnLevel for an incomplete graph */
+  public void testIncompleteGraphThrow() {
+    OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 10);
+    graph.addNode(1, 0);
+    graph.addNode(0, 0);
+    assertEquals(1, graph.getNodesOnLevel(1).size());
+    graph.addNode(0, 5);
+    expectThrows(IllegalStateException.class, () -> graph.getNodesOnLevel(0));
+  }
+
+  public void testGraphGrowth() {
+    OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
+    List<List<Integer>> levelToNodes = new ArrayList<>();
+    int maxLevel = 5;
+    for (int i = 0; i < maxLevel; i++) {
+      levelToNodes.add(new ArrayList<>());
+    }
+    for (int i = 0; i < 101; i++) {
+      int level = random().nextInt(maxLevel);
+      for (int l = level; l >= 0; l--) {
+        graph.addNode(l, i);
+        levelToNodes.get(l).add(i);
+      }
+    }
+    assertGraphEquals(graph, levelToNodes);
+  }
+
+  public void testGraphBuildOutOfOrder() {
+    OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
+    List<List<Integer>> levelToNodes = new ArrayList<>();
+    int maxLevel = 5;
+    int numNodes = 100;
+    for (int i = 0; i < maxLevel; i++) {
+      levelToNodes.add(new ArrayList<>());
+    }
+    int[] insertions = new int[numNodes];
+    for (int i = 0; i < numNodes; i++) {
+      insertions[i] = i;
+    }
+    // shuffle the insertion order
+    for (int i = 0; i < 40; i++) {
+      int pos1 = random().nextInt(numNodes);
+      int pos2 = random().nextInt(numNodes);
+      int tmp = insertions[pos1];
+      insertions[pos1] = insertions[pos2];
+      insertions[pos2] = tmp;
+    }
+
+    for (int i : insertions) {
+      int level = random().nextInt(maxLevel);
+      for (int l = level; l >= 0; l--) {
+        graph.addNode(l, i);
+        levelToNodes.get(l).add(i);
+      }
+    }
+
+    for (int i = 0; i < maxLevel; i++) {
+      levelToNodes.get(i).sort(Integer::compare);
+    }
+
+    assertGraphEquals(graph, levelToNodes);
+  }
+
+  private static void assertGraphEquals(OnHeapHnswGraph graph, List<List<Integer>> levelToNodes) {
+    for (int l = 0; l < graph.numLevels(); l++) {
+      HnswGraph.NodesIterator nodesIterator = graph.getNodesOnLevel(l);
+      assertEquals(levelToNodes.get(l).size(), nodesIterator.size());
+      int idx = 0;
+      while (nodesIterator.hasNext()) {
+        assertEquals(levelToNodes.get(l).get(idx++), nodesIterator.next());
+      }
+    }
+  }
+}