Optimize OnHeapHnswGraph's data structure (#12651)

make the internal graph representation a 2d array
This commit is contained in:
Patrick Zhai 2023-10-16 13:13:37 -07:00 committed by GitHub
parent 2482f7688b
commit a1cf22e6a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 344 additions and 187 deletions

View File

@ -196,6 +196,8 @@ Optimizations
* GITHUB#12668: ImpactsEnums now decode frequencies lazily like PostingsEnums. * GITHUB#12668: ImpactsEnums now decode frequencies lazily like PostingsEnums.
(Adrien Grand) (Adrien Grand)
* GITHUB#12651: Use 2d array for OnHeapHnswGraph representation. (Patrick Zhai)
Changes in runtime behavior Changes in runtime behavior
--------------------- ---------------------

View File

@ -56,7 +56,11 @@ public class Word2VecSynonymProvider {
RandomVectorScorerSupplier.createFloats(word2VecModel, SIMILARITY_FUNCTION); RandomVectorScorerSupplier.createFloats(word2VecModel, SIMILARITY_FUNCTION);
HnswGraphBuilder builder = HnswGraphBuilder builder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
scorerSupplier, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, HnswGraphBuilder.randSeed); scorerSupplier,
DEFAULT_MAX_CONN,
DEFAULT_BEAM_WIDTH,
HnswGraphBuilder.randSeed,
word2VecModel.size());
this.hnswGraph = builder.build(word2VecModel.size()); this.hnswGraph = builder.build(word2VecModel.size());
} }

View File

@ -438,7 +438,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
RandomVectorScorerSupplier.createBytes( RandomVectorScorerSupplier.createBytes(
vectorValues, fieldInfo.getVectorSimilarityFunction()); vectorValues, fieldInfo.getVectorSimilarityFunction());
HnswGraphBuilder hnswGraphBuilder = HnswGraphBuilder hnswGraphBuilder =
createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex); createHnswGraphBuilder(
mergeState,
fieldInfo,
scorerSupplier,
initializerIndex,
vectorValues.size());
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.size()); yield hnswGraphBuilder.build(vectorValues.size());
} }
@ -453,7 +458,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
RandomVectorScorerSupplier.createFloats( RandomVectorScorerSupplier.createFloats(
vectorValues, fieldInfo.getVectorSimilarityFunction()); vectorValues, fieldInfo.getVectorSimilarityFunction());
HnswGraphBuilder hnswGraphBuilder = HnswGraphBuilder hnswGraphBuilder =
createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex); createHnswGraphBuilder(
mergeState,
fieldInfo,
scorerSupplier,
initializerIndex,
vectorValues.size());
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.size()); yield hnswGraphBuilder.build(vectorValues.size());
} }
@ -488,10 +498,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
MergeState mergeState, MergeState mergeState,
FieldInfo fieldInfo, FieldInfo fieldInfo,
RandomVectorScorerSupplier scorerSupplier, RandomVectorScorerSupplier scorerSupplier,
int initializerIndex) int initializerIndex,
int graphSize)
throws IOException { throws IOException {
if (initializerIndex == -1) { if (initializerIndex == -1) {
return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); return HnswGraphBuilder.create(
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, graphSize);
} }
HnswGraph initializerGraph = HnswGraph initializerGraph =
@ -499,7 +511,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
Map<Integer, Integer> ordinalMapper = Map<Integer, Integer> ordinalMapper =
getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex); getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
return HnswGraphBuilder.create( return HnswGraphBuilder.create(
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, initializerGraph, ordinalMapper); scorerSupplier,
M,
beamWidth,
HnswGraphBuilder.randSeed,
initializerGraph,
ordinalMapper,
graphSize);
} }
private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo) private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)

View File

@ -66,6 +66,11 @@ public abstract class HnswGraph {
/** Returns the number of nodes in the graph */ /** Returns the number of nodes in the graph */
public abstract int size(); public abstract int size();
/** Returns max node id, inclusive, normally this value will be size - 1 */
public int maxNodeId() {
return size() - 1;
}
/** /**
* Iterates over the neighbor list. It is illegal to call this method after it returns * Iterates over the neighbor list. It is illegal to call this method after it returns
* NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator. * NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.

View File

@ -76,7 +76,12 @@ public final class HnswGraphBuilder {
public static HnswGraphBuilder create( public static HnswGraphBuilder create(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
throws IOException { throws IOException {
return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed); return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, -1);
}
public static HnswGraphBuilder create(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
} }
public static HnswGraphBuilder create( public static HnswGraphBuilder create(
@ -85,9 +90,11 @@ public final class HnswGraphBuilder {
int beamWidth, int beamWidth,
long seed, long seed,
HnswGraph initializerGraph, HnswGraph initializerGraph,
Map<Integer, Integer> oldToNewOrdinalMap) Map<Integer, Integer> oldToNewOrdinalMap,
int graphSize)
throws IOException { throws IOException {
HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed); HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap); hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
return hnswGraphBuilder; return hnswGraphBuilder;
} }
@ -102,10 +109,10 @@ public final class HnswGraphBuilder {
* @param beamWidth the size of the beam search to use when finding nearest neighbors. * @param beamWidth the size of the beam search to use when finding nearest neighbors.
* @param seed the seed for a random number generator used during graph construction. Provide this * @param seed the seed for a random number generator used during graph construction. Provide this
* to ensure repeatable construction. * to ensure repeatable construction.
* @param graphSize size of graph, if unknown, pass in -1
*/ */
private HnswGraphBuilder( private HnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
throws IOException {
if (M <= 0) { if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive"); throw new IllegalArgumentException("maxConn must be positive");
} }
@ -118,7 +125,7 @@ public final class HnswGraphBuilder {
// normalization factor for level generation; currently not configurable // normalization factor for level generation; currently not configurable
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed); this.random = new SplittableRandom(seed);
this.hnsw = new OnHeapHnswGraph(M); this.hnsw = new OnHeapHnswGraph(M, graphSize);
this.graphSearcher = this.graphSearcher =
new HnswGraphSearcher( new HnswGraphSearcher(
new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size())); new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
@ -155,7 +162,7 @@ public final class HnswGraphBuilder {
private void initializeFromGraph( private void initializeFromGraph(
HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException { HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
assert hnsw.size() == 0; assert hnsw.size() == 0;
for (int level = 0; level < initializerGraph.numLevels(); level++) { for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level); HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
while (it.hasNext()) { while (it.hasNext()) {
@ -288,7 +295,7 @@ public final class HnswGraphBuilder {
// only adding it if it is closer to the target than to any of the other selected neighbors // only adding it if it is closer to the target than to any of the other selected neighbors
int cNode = candidates.node[i]; int cNode = candidates.node[i];
float cScore = candidates.score[i]; float cScore = candidates.score[i];
assert cNode < hnsw.size(); assert cNode <= hnsw.maxNodeId();
if (diversityCheck(cNode, cScore, neighbors)) { if (diversityCheck(cNode, cScore, neighbors)) {
neighbors.addInOrder(cNode, cScore); neighbors.addInOrder(cNode, cScore);
} }

View File

@ -66,7 +66,7 @@ public class HnswGraphSearcher {
throws IOException { throws IOException {
HnswGraphSearcher graphSearcher = HnswGraphSearcher graphSearcher =
new HnswGraphSearcher( new HnswGraphSearcher(
new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(graph.size())); new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph)));
search(scorer, knnCollector, graph, graphSearcher, acceptOrds); search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
} }
@ -88,7 +88,7 @@ public class HnswGraphSearcher {
KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit); KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit);
OnHeapHnswGraphSearcher graphSearcher = OnHeapHnswGraphSearcher graphSearcher =
new OnHeapHnswGraphSearcher( new OnHeapHnswGraphSearcher(
new NeighborQueue(topK, true), new SparseFixedBitSet(graph.size())); new NeighborQueue(topK, true), new SparseFixedBitSet(getGraphSize(graph)));
search(scorer, knnCollector, graph, graphSearcher, acceptOrds); search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
return knnCollector; return knnCollector;
} }
@ -150,9 +150,9 @@ public class HnswGraphSearcher {
*/ */
private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit) private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit)
throws IOException { throws IOException {
int size = graph.size(); int size = getGraphSize(graph);
int visitedCount = 1; int visitedCount = 1;
prepareScratchState(graph.size()); prepareScratchState(size);
int currentEp = graph.entryNode(); int currentEp = graph.entryNode();
float currentScore = scorer.score(currentEp); float currentScore = scorer.score(currentEp);
boolean foundBetter; boolean foundBetter;
@ -201,8 +201,9 @@ public class HnswGraphSearcher {
Bits acceptOrds) Bits acceptOrds)
throws IOException { throws IOException {
int size = graph.size(); int size = getGraphSize(graph);
prepareScratchState(graph.size());
prepareScratchState(size);
for (int ep : eps) { for (int ep : eps) {
if (visited.getAndSet(ep) == false) { if (visited.getAndSet(ep) == false) {
@ -284,6 +285,10 @@ public class HnswGraphSearcher {
return graph.nextNeighbor(); return graph.nextNeighbor();
} }
private static int getGraphSize(HnswGraph graph) {
return graph.maxNodeId() + 1;
}
/** /**
* This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner by avoiding * This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner by avoiding
* the unsafe methods (seek and nextNeighbor, which maintain state in the graph object) and * the unsafe methods (seek and nextNeighbor, which maintain state in the graph object) and

View File

@ -20,10 +20,9 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.RamUsageEstimator;
/** /**
@ -32,39 +31,56 @@ import org.apache.lucene.util.RamUsageEstimator;
*/ */
public final class OnHeapHnswGraph extends HnswGraph implements Accountable { public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
private static final int INIT_SIZE = 128;
private int numLevels; // the current number of levels in the graph private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level. -1 if not set private int entryNode; // the current graph entry node on the top level. -1 if not set
// Level 0 is represented as List<NeighborArray> nodes' connections on level 0. // the internal graph representation where the first dimension is node id and second dimension is
// Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond // level
// to vectors // e.g. graph[1][2] is all the neighbours of node 1 at level 2
// added to HnswBuilder, and the node values are the ordinals of those vectors. private NeighborArray[][] graph;
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals. // essentially another 2d map which the first dimension is level and second dimension is node id,
private final List<NeighborArray> graphLevel0; // this is only
// Represents levels 1-N. Each level is represented with a Map that maps a levels level 0 // generated on demand when there's someone calling getNodeOnLevel on a non-zero level
// ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain private List<Integer>[] levelToNodes;
// it in this list. However, to avoid changing list indexing, we always will make the first private int
// element lastFreezeSize; // remember the size we are at last time to freeze the graph and generate
// null. // levelToNodes
private final List<Map<Integer, NeighborArray>> graphUpperLevels; private int size; // graph size, which is number of nodes in level 0
private final int nsize; private int
private final int nsize0; nonZeroLevelSize; // total number of NeighborArrays created that is not on level 0, for now it
// is only used to account memory usage
private int maxNodeId;
private final int nsize; // neighbour array size at non-zero level
private final int nsize0; // neighbour array size at zero level
private final boolean
noGrowth; // if an initial size is passed in, we don't expect the graph to grow itself
// KnnGraphValues iterator members // KnnGraphValues iterator members
private int upto; private int upto;
private NeighborArray cur; private NeighborArray cur;
OnHeapHnswGraph(int M) { /**
* ctor
*
* @param numNodes number of nodes that will be added to this graph, passing in -1 means unbounded
* while passing in a non-negative value will lock the whole graph and disable the graph from
* growing itself (you cannot add a node with has id >= numNodes)
*/
OnHeapHnswGraph(int M, int numNodes) {
this.numLevels = 1; // Implicitly start the graph with a single level this.numLevels = 1; // Implicitly start the graph with a single level
this.graphLevel0 = new ArrayList<>();
this.entryNode = -1; // Entry node should be negative until a node is added this.entryNode = -1; // Entry node should be negative until a node is added
// Neighbours' size on upper levels (nsize) and level 0 (nsize0) // Neighbours' size on upper levels (nsize) and level 0 (nsize0)
// We allocate extra space for neighbours, but then prune them to keep allowed maximum // We allocate extra space for neighbours, but then prune them to keep allowed maximum
this.maxNodeId = -1;
this.nsize = M + 1; this.nsize = M + 1;
this.nsize0 = (M * 2 + 1); this.nsize0 = (M * 2 + 1);
noGrowth = numNodes != -1;
this.graphUpperLevels = new ArrayList<>(numLevels); if (noGrowth == false) {
graphUpperLevels.add(null); // we don't need this for 0th level, as it contains all nodes numNodes = INIT_SIZE;
}
this.graph = new NeighborArray[numNodes][];
} }
/** /**
@ -74,23 +90,32 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0. * @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
*/ */
public NeighborArray getNeighbors(int level, int node) { public NeighborArray getNeighbors(int level, int node) {
if (level == 0) { assert graph[node][level] != null;
return graphLevel0.get(node); return graph[node][level];
}
Map<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
assert levelMap.containsKey(node);
return levelMap.get(node);
} }
@Override @Override
public int size() { public int size() {
return graphLevel0.size(); // all nodes are located on the 0th level return size;
}
/**
* When we initialize from another graph, the max node id is different from {@link #size()},
* because we will add nodes out of order, such that we need two method for each
*
* @return max node id (inclusive)
*/
@Override
public int maxNodeId() {
return maxNodeId;
} }
/** /**
* Add node on the given level. Nodes can be inserted out of order, but it requires that the nodes * Add node on the given level. Nodes can be inserted out of order, but it requires that the nodes
* preceded by the node inserted out of order are eventually added. * preceded by the node inserted out of order are eventually added.
* *
* <p>NOTE: You must add a node starting from the node's top level
*
* @param level level to add a node on * @param level level to add a node on
* @param node the node to add, represented as an ordinal on the level 0. * @param node the node to add, represented as an ordinal on the level 0.
*/ */
@ -99,28 +124,33 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
entryNode = node; entryNode = node;
} }
if (level > 0) { if (node >= graph.length) {
// if the new node introduces a new level, add more levels to the graph, if (noGrowth) {
// and make this node the graph's new entry point throw new IllegalStateException(
if (level >= numLevels) { "The graph does not expect to grow when an initial size is given");
for (int i = numLevels; i <= level; i++) {
graphUpperLevels.add(new HashMap<>());
} }
graph = ArrayUtil.grow(graph, node + 1);
}
if (level >= numLevels) {
numLevels = level + 1; numLevels = level + 1;
entryNode = node; entryNode = node;
} }
graphUpperLevels.get(level).put(node, new NeighborArray(nsize, true)); assert graph[node] == null || graph[node].length > level
: "node must be inserted from the top level";
if (graph[node] == null) {
graph[node] =
new NeighborArray[level + 1]; // assumption: we always call this function from top level
size++;
}
if (level == 0) {
graph[node][level] = new NeighborArray(nsize0, true);
} else { } else {
// Add nodes all the way up to and including "node" in the new graph on level 0. This will graph[node][level] = new NeighborArray(nsize, true);
// cause the size of the nonZeroLevelSize++;
// graph to differ from the number of nodes added to the graph. The size of the graph and the
// number of nodes
// added will only be in sync once all nodes from 0...last_node are added into the graph.
while (node >= graphLevel0.size()) {
graphLevel0.add(new NeighborArray(nsize0, true));
}
} }
maxNodeId = Math.max(maxNodeId, node);
} }
@Override @Override
@ -158,50 +188,83 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
return entryNode; return entryNode;
} }
/**
* WARN: calling this method will essentially iterate through all nodes at level 0 (even if you're
* not getting node at level 0), we have built some caching mechanism such that if graph is not
* changed only the first non-zero level call will pay the cost. So it is highly NOT recommended
* to call this method while the graph is still building.
*
* <p>NOTE: calling this method while the graph is still building is prohibited
*/
@Override @Override
public NodesIterator getNodesOnLevel(int level) { public NodesIterator getNodesOnLevel(int level) {
if (size() != maxNodeId() + 1) {
throw new IllegalStateException(
"graph build not complete, size=" + size() + " maxNodeId=" + maxNodeId());
}
if (level == 0) { if (level == 0) {
return new ArrayNodesIterator(size()); return new ArrayNodesIterator(size());
} else { } else {
return new CollectionNodesIterator(graphUpperLevels.get(level).keySet()); generateLevelToNodes();
return new CollectionNodesIterator(levelToNodes[level]);
} }
} }
@SuppressWarnings({"unchecked", "rawtypes"})
private void generateLevelToNodes() {
if (lastFreezeSize == size) {
return;
}
levelToNodes = new List[numLevels];
for (int i = 1; i < numLevels; i++) {
levelToNodes[i] = new ArrayList<>();
}
int nonNullNode = 0;
for (int node = 0; node < graph.length; node++) {
// when we init from another graph, we could have holes where some slot is null
if (graph[node] == null) {
continue;
}
nonNullNode++;
for (int i = 1; i < graph[node].length; i++) {
levelToNodes[i].add(node);
}
if (nonNullNode == size) {
break;
}
}
lastFreezeSize = size;
}
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
long neighborArrayBytes0 = long neighborArrayBytes0 =
nsize0 * (Integer.BYTES + Float.BYTES) (long) nsize0 * (Integer.BYTES + Float.BYTES)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2 + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
+ Integer.BYTES * 3; + Integer.BYTES * 3;
long neighborArrayBytes = long neighborArrayBytes =
nsize * (Integer.BYTES + Float.BYTES) (long) nsize * (Integer.BYTES + Float.BYTES)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2 + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
+ Integer.BYTES * 3; + Integer.BYTES * 3;
long total = 0; long total = 0;
for (int l = 0; l < numLevels; l++) {
if (l == 0) {
total += total +=
graphLevel0.size() * neighborArrayBytes0 size * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph; + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // for graph and level 0;
} else { total += nonZeroLevelSize * neighborArrayBytes; // for non-zero level
long numNodesOnLevel = graphUpperLevels.get(l).size(); total += 8 * Integer.BYTES; // all int fields
total += RamUsageEstimator.NUM_BYTES_OBJECT_REF; // field: cur
// For levels > 0, we represent the graph structure with a tree map. total += RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // field: levelToNodes
// A single node in the tree contains 3 references (left root, right root, value) as well if (levelToNodes != null) {
// as an Integer for the key and 1 extra byte for the color of the node (this is actually 1
// bit, but
// because we do not have that granularity, we set to 1 byte). In addition, we include 1
// more reference for
// the tree map itself.
total += total +=
numNodesOnLevel * (3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + Integer.BYTES + 1) (long) (numLevels - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF; total +=
(long) nonZeroLevelSize
// Add the size neighbor of each node * (RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
total += numNodesOnLevel * neighborArrayBytes; + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
} + Integer.BYTES);
} }
return total; return total;
} }

View File

@ -195,7 +195,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
} }
} }
// test that sorted index returns the same search results are unsorted // test that sorted index returns the same search results as unsorted
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException { public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
int dim = random().nextInt(10) + 3; int dim = random().nextInt(10) + 3;
int nDoc = random().nextInt(200) + 100; int nDoc = random().nextInt(200) + 100;
@ -454,77 +454,6 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
} }
} }
public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException {
int maxNumLevels = randomIntBetween(2, 10);
int nodeCount = randomIntBetween(1, 100);
List<List<Integer>> nodesPerLevel = new ArrayList<>();
for (int i = 0; i < maxNumLevels; i++) {
nodesPerLevel.add(new ArrayList<>());
}
int numLevels = 0;
for (int currNode = 0; currNode < nodeCount; currNode++) {
int nodeMaxLevel = random().nextInt(1, maxNumLevels + 1);
numLevels = Math.max(numLevels, nodeMaxLevel);
for (int currLevel = 0; currLevel < nodeMaxLevel; currLevel++) {
nodesPerLevel.get(currLevel).add(currNode);
}
}
OnHeapHnswGraph topDownOrderReversedHnsw = new OnHeapHnswGraph(10);
for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
int currLevelNodesSize = currLevelNodes.size();
for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
topDownOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
}
}
OnHeapHnswGraph bottomUpOrderReversedHnsw = new OnHeapHnswGraph(10);
for (int currLevel = 0; currLevel < numLevels; currLevel++) {
List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
int currLevelNodesSize = currLevelNodes.size();
for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
bottomUpOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
}
}
OnHeapHnswGraph topDownOrderRandomHnsw = new OnHeapHnswGraph(10);
for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
List<Integer> currLevelNodes = new ArrayList<>(nodesPerLevel.get(currLevel));
Collections.shuffle(currLevelNodes, random());
for (Integer currNode : currLevelNodes) {
topDownOrderRandomHnsw.addNode(currLevel, currNode);
}
}
OnHeapHnswGraph bottomUpExpectedHnsw = new OnHeapHnswGraph(10);
for (int currLevel = 0; currLevel < numLevels; currLevel++) {
for (Integer currNode : nodesPerLevel.get(currLevel)) {
bottomUpExpectedHnsw.addNode(currLevel, currNode);
}
}
assertEquals(nodeCount, bottomUpExpectedHnsw.getNodesOnLevel(0).size());
for (Integer node : nodesPerLevel.get(0)) {
assertEquals(0, bottomUpExpectedHnsw.getNeighbors(0, node).size());
}
for (int currLevel = 1; currLevel < numLevels; currLevel++) {
List<Integer> expectedNodesOnLevel = nodesPerLevel.get(currLevel);
List<Integer> sortedNodes = sortedNodesOnLevel(bottomUpExpectedHnsw, currLevel);
assertEquals(
String.format(Locale.ROOT, "Nodes on level %d do not match", currLevel),
expectedNodesOnLevel,
sortedNodes);
}
assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw);
assertGraphEqual(bottomUpExpectedHnsw, bottomUpOrderReversedHnsw);
assertGraphEqual(bottomUpExpectedHnsw, topDownOrderRandomHnsw);
}
public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws IOException { public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws IOException {
int totalSize = atLeast(100); int totalSize = atLeast(100);
int initializerSize = random().nextInt(5, totalSize); int initializerSize = random().nextInt(5, totalSize);
@ -547,7 +476,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues); RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
HnswGraphBuilder finalBuilder = HnswGraphBuilder finalBuilder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap); finalscorerSupplier,
10,
30,
seed,
initializerGraph,
initializerOrdMap,
finalVectorValues.size());
// When offset is 0, the graphs should be identical before vectors are added // When offset is 0, the graphs should be identical before vectors are added
assertGraphEqual(initializerGraph, finalBuilder.getGraph()); assertGraphEqual(initializerGraph, finalBuilder.getGraph());
@ -577,7 +512,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues); RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
HnswGraphBuilder finalBuilder = HnswGraphBuilder finalBuilder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap); finalscorerSupplier,
10,
30,
seed,
initializerGraph,
initializerOrdMap,
finalVectorValues.size());
assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap); assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap);
@ -599,35 +540,30 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
} }
private void assertGraphInitializedFromGraph( private void assertGraphInitializedFromGraph(
HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException { HnswGraph g, HnswGraph initializer, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels()); assertEquals(
"the number of levels in the graphs are different!",
initializer.numLevels(),
g.numLevels());
// Confirm that the size of the new graph includes all nodes up to an including the max new // Confirm that the size of the new graph includes all nodes up to an including the max new
// ordinal in the old to // ordinal in the old to
// new ordinal mapping // new ordinal mapping
assertEquals( assertEquals("the number of nodes in the graphs are different!", initializer.size(), g.size());
"the number of nodes in the graphs are different!",
g.size(),
Collections.max(oldToNewOrdMap.values()) + 1);
// assert the nodes from the previous graph are successfully to levels > 0 in the new graph // assert that all the node from initializer graph can be found in the new graph and
for (int level = 1; level < g.numLevels(); level++) { // the neighbors from the old graph are successfully transferred to the new graph
List<Integer> nodesOnLevel = sortedNodesOnLevel(g, level);
List<Integer> nodesOnLevel2 =
sortedNodesOnLevel(h, level).stream().map(oldToNewOrdMap::get).toList();
assertEquals(nodesOnLevel, nodesOnLevel2);
}
// assert that the neighbors from the old graph are successfully transferred to the new graph
for (int level = 0; level < g.numLevels(); level++) { for (int level = 0; level < g.numLevels(); level++) {
NodesIterator nodesOnLevel = h.getNodesOnLevel(level); NodesIterator nodesOnLevel = initializer.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) { while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt(); int node = nodesOnLevel.nextInt();
g.seek(level, oldToNewOrdMap.get(node)); g.seek(level, oldToNewOrdMap.get(node));
h.seek(level, node); initializer.seek(level, node);
assertEquals( assertEquals(
"arcs differ for node " + node, "arcs differ for node " + node,
getNeighborNodes(g), getNeighborNodes(g),
getNeighborNodes(h).stream().map(oldToNewOrdMap::get).collect(Collectors.toSet())); getNeighborNodes(initializer).stream()
.map(oldToNewOrdMap::get)
.collect(Collectors.toSet()));
} }
} }
} }

View File

@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.tests.util.LuceneTestCase;
/**
* Test OnHeapHnswGraph's behavior specifically, for more complex test, see {@link
* HnswGraphTestCase}
*/
public class TestOnHeapHnswGraph extends LuceneTestCase {
/* assert exception will be thrown when we add out of bound node to a fixed size graph */
public void testNoGrowth() {
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 100);
expectThrows(IllegalStateException.class, () -> graph.addNode(1, 100));
}
/* AssertionError will be thrown if we add a node not from top most level,
(likely NPE will be thrown in prod) */
public void testAddLevelOutOfOrder() {
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
graph.addNode(0, 0);
expectThrows(AssertionError.class, () -> graph.addNode(1, 0));
}
/* assert exception will be thrown when we call getNodeOnLevel for an incomplete graph */
public void testIncompleteGraphThrow() {
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 10);
graph.addNode(1, 0);
graph.addNode(0, 0);
assertEquals(1, graph.getNodesOnLevel(1).size());
graph.addNode(0, 5);
expectThrows(IllegalStateException.class, () -> graph.getNodesOnLevel(0));
}
public void testGraphGrowth() {
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
List<List<Integer>> levelToNodes = new ArrayList<>();
int maxLevel = 5;
for (int i = 0; i < maxLevel; i++) {
levelToNodes.add(new ArrayList<>());
}
for (int i = 0; i < 101; i++) {
int level = random().nextInt(maxLevel);
for (int l = level; l >= 0; l--) {
graph.addNode(l, i);
levelToNodes.get(l).add(i);
}
}
assertGraphEquals(graph, levelToNodes);
}
public void testGraphBuildOutOfOrder() {
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
List<List<Integer>> levelToNodes = new ArrayList<>();
int maxLevel = 5;
int numNodes = 100;
for (int i = 0; i < maxLevel; i++) {
levelToNodes.add(new ArrayList<>());
}
int[] insertions = new int[numNodes];
for (int i = 0; i < numNodes; i++) {
insertions[i] = i;
}
// shuffle the insertion order
for (int i = 0; i < 40; i++) {
int pos1 = random().nextInt(numNodes);
int pos2 = random().nextInt(numNodes);
int tmp = insertions[pos1];
insertions[pos1] = insertions[pos2];
insertions[pos2] = tmp;
}
for (int i : insertions) {
int level = random().nextInt(maxLevel);
for (int l = level; l >= 0; l--) {
graph.addNode(l, i);
levelToNodes.get(l).add(i);
}
}
for (int i = 0; i < maxLevel; i++) {
levelToNodes.get(i).sort(Integer::compare);
}
assertGraphEquals(graph, levelToNodes);
}
private static void assertGraphEquals(OnHeapHnswGraph graph, List<List<Integer>> levelToNodes) {
for (int l = 0; l < graph.numLevels(); l++) {
HnswGraph.NodesIterator nodesIterator = graph.getNodesOnLevel(l);
assertEquals(levelToNodes.get(l).size(), nodesIterator.size());
int idx = 0;
while (nodesIterator.hasNext()) {
assertEquals(levelToNodes.get(l).get(idx++), nodesIterator.next());
}
}
}
}