Optimize OnHeapHnswGraph's data structure (#12651)

make the internal graph representation a 2d array
This commit is contained in:
Patrick Zhai 2023-10-16 13:13:37 -07:00 committed by GitHub
parent 2482f7688b
commit a1cf22e6a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 344 additions and 187 deletions

View File

@ -195,6 +195,8 @@ Optimizations
* GITHUB#12668: ImpactsEnums now decode frequencies lazily like PostingsEnums.
(Adrien Grand)
* GITHUB#12651: Use 2d array for OnHeapHnswGraph representation. (Patrick Zhai)
Changes in runtime behavior
---------------------

View File

@ -56,7 +56,11 @@ public class Word2VecSynonymProvider {
RandomVectorScorerSupplier.createFloats(word2VecModel, SIMILARITY_FUNCTION);
HnswGraphBuilder builder =
HnswGraphBuilder.create(
scorerSupplier, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, HnswGraphBuilder.randSeed);
scorerSupplier,
DEFAULT_MAX_CONN,
DEFAULT_BEAM_WIDTH,
HnswGraphBuilder.randSeed,
word2VecModel.size());
this.hnswGraph = builder.build(word2VecModel.size());
}

View File

@ -438,7 +438,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
RandomVectorScorerSupplier.createBytes(
vectorValues, fieldInfo.getVectorSimilarityFunction());
HnswGraphBuilder hnswGraphBuilder =
createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex);
createHnswGraphBuilder(
mergeState,
fieldInfo,
scorerSupplier,
initializerIndex,
vectorValues.size());
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.size());
}
@ -453,7 +458,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
RandomVectorScorerSupplier.createFloats(
vectorValues, fieldInfo.getVectorSimilarityFunction());
HnswGraphBuilder hnswGraphBuilder =
createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex);
createHnswGraphBuilder(
mergeState,
fieldInfo,
scorerSupplier,
initializerIndex,
vectorValues.size());
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.size());
}
@ -488,10 +498,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
MergeState mergeState,
FieldInfo fieldInfo,
RandomVectorScorerSupplier scorerSupplier,
int initializerIndex)
int initializerIndex,
int graphSize)
throws IOException {
if (initializerIndex == -1) {
return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
return HnswGraphBuilder.create(
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, graphSize);
}
HnswGraph initializerGraph =
@ -499,7 +511,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
Map<Integer, Integer> ordinalMapper =
getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
return HnswGraphBuilder.create(
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, initializerGraph, ordinalMapper);
scorerSupplier,
M,
beamWidth,
HnswGraphBuilder.randSeed,
initializerGraph,
ordinalMapper,
graphSize);
}
private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)

View File

@ -66,6 +66,11 @@ public abstract class HnswGraph {
/** Returns the number of nodes in the graph */
public abstract int size();
/**
 * Returns the largest node id in the graph, inclusive. This base implementation reports
 * {@code size() - 1}; an implementation whose node ids are not contiguous may override it.
 *
 * @return the maximum node id (inclusive)
 */
public int maxNodeId() {
  final int nodeCount = size();
  return nodeCount - 1;
}
/**
* Iterates over the neighbor list. It is illegal to call this method after it returns
* NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.

View File

@ -76,7 +76,12 @@ public final class HnswGraphBuilder {
public static HnswGraphBuilder create(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
throws IOException {
return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed);
return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, -1);
}
/**
 * Creates a builder for a graph whose size is passed to the builder constructor.
 *
 * @param scorerSupplier scorer supplier handed to the builder constructor
 * @param M maximum connection count; the constructor rejects non-positive values
 * @param beamWidth the size of the beam search to use when finding nearest neighbors
 * @param seed the seed for the random number generator used during graph construction
 * @param graphSize size of graph, if unknown, pass in -1
 * @return the newly constructed builder
 */
public static HnswGraphBuilder create(
    RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
  final HnswGraphBuilder builder =
      new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
  return builder;
}
public static HnswGraphBuilder create(
@ -85,9 +90,11 @@ public final class HnswGraphBuilder {
int beamWidth,
long seed,
HnswGraph initializerGraph,
Map<Integer, Integer> oldToNewOrdinalMap)
Map<Integer, Integer> oldToNewOrdinalMap,
int graphSize)
throws IOException {
HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed);
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
return hnswGraphBuilder;
}
@ -102,10 +109,10 @@ public final class HnswGraphBuilder {
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
* @param seed the seed for a random number generator used during graph construction. Provide this
* to ensure repeatable construction.
* @param graphSize size of graph, if unknown, pass in -1
*/
private HnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
throws IOException {
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
}
@ -118,7 +125,7 @@ public final class HnswGraphBuilder {
// normalization factor for level generation; currently not configurable
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
this.hnsw = new OnHeapHnswGraph(M);
this.hnsw = new OnHeapHnswGraph(M, graphSize);
this.graphSearcher =
new HnswGraphSearcher(
new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
@ -155,7 +162,7 @@ public final class HnswGraphBuilder {
private void initializeFromGraph(
HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
assert hnsw.size() == 0;
for (int level = 0; level < initializerGraph.numLevels(); level++) {
for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
while (it.hasNext()) {
@ -288,7 +295,7 @@ public final class HnswGraphBuilder {
// only adding it if it is closer to the target than to any of the other selected neighbors
int cNode = candidates.node[i];
float cScore = candidates.score[i];
assert cNode < hnsw.size();
assert cNode <= hnsw.maxNodeId();
if (diversityCheck(cNode, cScore, neighbors)) {
neighbors.addInOrder(cNode, cScore);
}

View File

@ -66,7 +66,7 @@ public class HnswGraphSearcher {
throws IOException {
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(graph.size()));
new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph)));
search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
}
@ -88,7 +88,7 @@ public class HnswGraphSearcher {
KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit);
OnHeapHnswGraphSearcher graphSearcher =
new OnHeapHnswGraphSearcher(
new NeighborQueue(topK, true), new SparseFixedBitSet(graph.size()));
new NeighborQueue(topK, true), new SparseFixedBitSet(getGraphSize(graph)));
search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
return knnCollector;
}
@ -150,9 +150,9 @@ public class HnswGraphSearcher {
*/
private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit)
throws IOException {
int size = graph.size();
int size = getGraphSize(graph);
int visitedCount = 1;
prepareScratchState(graph.size());
prepareScratchState(size);
int currentEp = graph.entryNode();
float currentScore = scorer.score(currentEp);
boolean foundBetter;
@ -201,8 +201,9 @@ public class HnswGraphSearcher {
Bits acceptOrds)
throws IOException {
int size = graph.size();
prepareScratchState(graph.size());
int size = getGraphSize(graph);
prepareScratchState(size);
for (int ep : eps) {
if (visited.getAndSet(ep) == false) {
@ -284,6 +285,10 @@ public class HnswGraphSearcher {
return graph.nextNeighbor();
}
/**
 * Returns the number of slots needed to address every node id of {@code graph}: its inclusive
 * maximum node id plus one.
 */
private static int getGraphSize(HnswGraph graph) {
  final int highestNodeId = graph.maxNodeId();
  return highestNodeId + 1;
}
/**
* This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner by avoiding
* the unsafe methods (seek and nextNeighbor, which maintain state in the graph object) and

View File

@ -20,10 +20,9 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.RamUsageEstimator;
/**
@ -32,39 +31,56 @@ import org.apache.lucene.util.RamUsageEstimator;
*/
public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
private static final int INIT_SIZE = 128;
private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level. -1 if not set
// Level 0 is represented as List<NeighborArray> nodes' connections on level 0.
// Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond
// to vectors
// added to HnswBuilder, and the node values are the ordinals of those vectors.
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
private final List<NeighborArray> graphLevel0;
// Represents levels 1-N. Each level is represented with a Map that maps a levels level 0
// ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain
// it in this list. However, to avoid changing list indexing, we always will make the first
// element
// null.
private final List<Map<Integer, NeighborArray>> graphUpperLevels;
private final int nsize;
private final int nsize0;
// the internal graph representation where the first dimension is node id and second dimension is
// level
// e.g. graph[1][2] is all the neighbours of node 1 at level 2
private NeighborArray[][] graph;
// essentially another 2d map where the first dimension is the level and the second dimension is
// the node id; it is only generated on demand, when someone calls getNodesOnLevel on a
// non-zero level
private List<Integer>[] levelToNodes;
private int
lastFreezeSize; // remember the size we are at last time to freeze the graph and generate
// levelToNodes
private int size; // graph size, which is number of nodes in level 0
private int
nonZeroLevelSize; // total number of NeighborArrays created that is not on level 0, for now it
// is only used to account memory usage
private int maxNodeId;
private final int nsize; // neighbour array size at non-zero level
private final int nsize0; // neighbour array size at zero level
private final boolean
noGrowth; // if an initial size is passed in, we don't expect the graph to grow itself
// KnnGraphValues iterator members
private int upto;
private NeighborArray cur;
OnHeapHnswGraph(int M) {
/**
* ctor
*
* @param numNodes number of nodes that will be added to this graph; passing in -1 means
* unbounded, while passing in a non-negative value will lock the whole graph and disable the
* graph from growing itself (you cannot add a node whose id is >= numNodes)
*/
OnHeapHnswGraph(int M, int numNodes) {
this.numLevels = 1; // Implicitly start the graph with a single level
this.graphLevel0 = new ArrayList<>();
this.entryNode = -1; // Entry node should be negative until a node is added
// Neighbours' size on upper levels (nsize) and level 0 (nsize0)
// We allocate extra space for neighbours, but then prune them to keep allowed maximum
this.maxNodeId = -1;
this.nsize = M + 1;
this.nsize0 = (M * 2 + 1);
this.graphUpperLevels = new ArrayList<>(numLevels);
graphUpperLevels.add(null); // we don't need this for 0th level, as it contains all nodes
noGrowth = numNodes != -1;
if (noGrowth == false) {
numNodes = INIT_SIZE;
}
this.graph = new NeighborArray[numNodes][];
}
/**
@ -74,23 +90,32 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
*/
public NeighborArray getNeighbors(int level, int node) {
if (level == 0) {
return graphLevel0.get(node);
}
Map<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
assert levelMap.containsKey(node);
return levelMap.get(node);
assert graph[node][level] != null;
return graph[node][level];
}
@Override
public int size() {
return graphLevel0.size(); // all nodes are located on the 0th level
return size;
}
/**
 * Returns the largest node id added to this graph so far. When the graph is initialized from
 * another graph, nodes are added out of order, so this value can differ from {@link #size()}
 * minus one; that is why both methods exist.
 *
 * @return max node id (inclusive)
 */
@Override
public int maxNodeId() {
  return this.maxNodeId;
}
/**
* Add node on the given level. Nodes can be inserted out of order, but it requires that the nodes
* preceded by the node inserted out of order are eventually added.
*
* <p>NOTE: You must add a node starting from the node's top level
*
* @param level level to add a node on
* @param node the node to add, represented as an ordinal on the level 0.
*/
@ -99,28 +124,33 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
entryNode = node;
}
if (level > 0) {
// if the new node introduces a new level, add more levels to the graph,
// and make this node the graph's new entry point
if (level >= numLevels) {
for (int i = numLevels; i <= level; i++) {
graphUpperLevels.add(new HashMap<>());
}
numLevels = level + 1;
entryNode = node;
}
graphUpperLevels.get(level).put(node, new NeighborArray(nsize, true));
} else {
// Add nodes all the way up to and including "node" in the new graph on level 0. This will
// cause the size of the
// graph to differ from the number of nodes added to the graph. The size of the graph and the
// number of nodes
// added will only be in sync once all nodes from 0...last_node are added into the graph.
while (node >= graphLevel0.size()) {
graphLevel0.add(new NeighborArray(nsize0, true));
if (node >= graph.length) {
if (noGrowth) {
throw new IllegalStateException(
"The graph does not expect to grow when an initial size is given");
}
graph = ArrayUtil.grow(graph, node + 1);
}
if (level >= numLevels) {
numLevels = level + 1;
entryNode = node;
}
assert graph[node] == null || graph[node].length > level
: "node must be inserted from the top level";
if (graph[node] == null) {
graph[node] =
new NeighborArray[level + 1]; // assumption: we always call this function from top level
size++;
}
if (level == 0) {
graph[node][level] = new NeighborArray(nsize0, true);
} else {
graph[node][level] = new NeighborArray(nsize, true);
nonZeroLevelSize++;
}
maxNodeId = Math.max(maxNodeId, node);
}
@Override
@ -158,50 +188,83 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
return entryNode;
}
/**
* WARN: calling this method for a non-zero level essentially iterates through all nodes at
* level 0 in order to collect the nodes on each level. The result is cached, so as long as the
* graph is not changed only the first non-zero level call pays the cost. For that reason it is
* highly NOT recommended to call this method while the graph is still building.
*
* <p>NOTE: calling this method while the graph is incomplete (size() != maxNodeId() + 1) is
* prohibited and throws an IllegalStateException
*/
@Override
public NodesIterator getNodesOnLevel(int level) {
if (size() != maxNodeId() + 1) {
throw new IllegalStateException(
"graph build not complete, size=" + size() + " maxNodeId=" + maxNodeId());
}
if (level == 0) {
return new ArrayNodesIterator(size());
} else {
return new CollectionNodesIterator(graphUpperLevels.get(level).keySet());
generateLevelToNodes();
return new CollectionNodesIterator(levelToNodes[level]);
}
}
/**
 * Rebuilds {@code levelToNodes}, the per-level list of node ids for every non-zero level. The
 * rebuild is skipped when the graph size has not changed since the previous rebuild.
 */
@SuppressWarnings({"unchecked", "rawtypes"})
private void generateLevelToNodes() {
  if (lastFreezeSize == size) {
    // nothing was added since the last rebuild, the cached lists are still valid
    return;
  }

  levelToNodes = new List[numLevels];
  // slot 0 is intentionally left null: level 0 is never looked up through this structure
  for (int level = 1; level < numLevels; level++) {
    levelToNodes[level] = new ArrayList<>();
  }

  int visited = 0;
  for (int node = 0; node < graph.length; node++) {
    final NeighborArray[] levelsOfNode = graph[node];
    // when we init from another graph, we could have holes where some slot is null
    if (levelsOfNode != null) {
      visited++;
      for (int level = 1; level < levelsOfNode.length; level++) {
        levelToNodes[level].add(node);
      }
      if (visited == size) {
        // every existing node has been seen, the remaining slots are empty
        break;
      }
    }
  }

  lastFreezeSize = size;
}
@Override
public long ramBytesUsed() {
long neighborArrayBytes0 =
nsize0 * (Integer.BYTES + Float.BYTES)
(long) nsize0 * (Integer.BYTES + Float.BYTES)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
+ Integer.BYTES * 3;
long neighborArrayBytes =
nsize * (Integer.BYTES + Float.BYTES)
(long) nsize * (Integer.BYTES + Float.BYTES)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
+ Integer.BYTES * 3;
long total = 0;
for (int l = 0; l < numLevels; l++) {
if (l == 0) {
total +=
graphLevel0.size() * neighborArrayBytes0
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
} else {
long numNodesOnLevel = graphUpperLevels.get(l).size();
// For levels > 0, we represent the graph structure with a tree map.
// A single node in the tree contains 3 references (left root, right root, value) as well
// as an Integer for the key and 1 extra byte for the color of the node (this is actually 1
// bit, but
// because we do not have that granularity, we set to 1 byte). In addition, we include 1
// more reference for
// the tree map itself.
total +=
numNodesOnLevel * (3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + Integer.BYTES + 1)
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF;
// Add the size neighbor of each node
total += numNodesOnLevel * neighborArrayBytes;
}
total +=
size * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // for graph and level 0;
total += nonZeroLevelSize * neighborArrayBytes; // for non-zero level
total += 8 * Integer.BYTES; // all int fields
total += RamUsageEstimator.NUM_BYTES_OBJECT_REF; // field: cur
total += RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // field: levelToNodes
if (levelToNodes != null) {
total +=
(long) (numLevels - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0
total +=
(long) nonZeroLevelSize
* (RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ Integer.BYTES);
}
return total;
}

View File

@ -195,7 +195,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
}
// test that sorted index returns the same search results are unsorted
// test that sorted index returns the same search results as unsorted
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
int dim = random().nextInt(10) + 3;
int nDoc = random().nextInt(200) + 100;
@ -454,77 +454,6 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
}
public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException {
int maxNumLevels = randomIntBetween(2, 10);
int nodeCount = randomIntBetween(1, 100);
List<List<Integer>> nodesPerLevel = new ArrayList<>();
for (int i = 0; i < maxNumLevels; i++) {
nodesPerLevel.add(new ArrayList<>());
}
int numLevels = 0;
for (int currNode = 0; currNode < nodeCount; currNode++) {
int nodeMaxLevel = random().nextInt(1, maxNumLevels + 1);
numLevels = Math.max(numLevels, nodeMaxLevel);
for (int currLevel = 0; currLevel < nodeMaxLevel; currLevel++) {
nodesPerLevel.get(currLevel).add(currNode);
}
}
OnHeapHnswGraph topDownOrderReversedHnsw = new OnHeapHnswGraph(10);
for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
int currLevelNodesSize = currLevelNodes.size();
for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
topDownOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
}
}
OnHeapHnswGraph bottomUpOrderReversedHnsw = new OnHeapHnswGraph(10);
for (int currLevel = 0; currLevel < numLevels; currLevel++) {
List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
int currLevelNodesSize = currLevelNodes.size();
for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
bottomUpOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
}
}
OnHeapHnswGraph topDownOrderRandomHnsw = new OnHeapHnswGraph(10);
for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
List<Integer> currLevelNodes = new ArrayList<>(nodesPerLevel.get(currLevel));
Collections.shuffle(currLevelNodes, random());
for (Integer currNode : currLevelNodes) {
topDownOrderRandomHnsw.addNode(currLevel, currNode);
}
}
OnHeapHnswGraph bottomUpExpectedHnsw = new OnHeapHnswGraph(10);
for (int currLevel = 0; currLevel < numLevels; currLevel++) {
for (Integer currNode : nodesPerLevel.get(currLevel)) {
bottomUpExpectedHnsw.addNode(currLevel, currNode);
}
}
assertEquals(nodeCount, bottomUpExpectedHnsw.getNodesOnLevel(0).size());
for (Integer node : nodesPerLevel.get(0)) {
assertEquals(0, bottomUpExpectedHnsw.getNeighbors(0, node).size());
}
for (int currLevel = 1; currLevel < numLevels; currLevel++) {
List<Integer> expectedNodesOnLevel = nodesPerLevel.get(currLevel);
List<Integer> sortedNodes = sortedNodesOnLevel(bottomUpExpectedHnsw, currLevel);
assertEquals(
String.format(Locale.ROOT, "Nodes on level %d do not match", currLevel),
expectedNodesOnLevel,
sortedNodes);
}
assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw);
assertGraphEqual(bottomUpExpectedHnsw, bottomUpOrderReversedHnsw);
assertGraphEqual(bottomUpExpectedHnsw, topDownOrderRandomHnsw);
}
public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws IOException {
int totalSize = atLeast(100);
int initializerSize = random().nextInt(5, totalSize);
@ -547,7 +476,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
HnswGraphBuilder finalBuilder =
HnswGraphBuilder.create(
finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap);
finalscorerSupplier,
10,
30,
seed,
initializerGraph,
initializerOrdMap,
finalVectorValues.size());
// When offset is 0, the graphs should be identical before vectors are added
assertGraphEqual(initializerGraph, finalBuilder.getGraph());
@ -577,7 +512,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
HnswGraphBuilder finalBuilder =
HnswGraphBuilder.create(
finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap);
finalscorerSupplier,
10,
30,
seed,
initializerGraph,
initializerOrdMap,
finalVectorValues.size());
assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap);
@ -599,35 +540,30 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
private void assertGraphInitializedFromGraph(
HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
HnswGraph g, HnswGraph initializer, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
assertEquals(
"the number of levels in the graphs are different!",
initializer.numLevels(),
g.numLevels());
// Confirm that the size of the new graph includes all nodes up to an including the max new
// ordinal in the old to
// new ordinal mapping
assertEquals(
"the number of nodes in the graphs are different!",
g.size(),
Collections.max(oldToNewOrdMap.values()) + 1);
assertEquals("the number of nodes in the graphs are different!", initializer.size(), g.size());
// assert the nodes from the previous graph are successfully to levels > 0 in the new graph
for (int level = 1; level < g.numLevels(); level++) {
List<Integer> nodesOnLevel = sortedNodesOnLevel(g, level);
List<Integer> nodesOnLevel2 =
sortedNodesOnLevel(h, level).stream().map(oldToNewOrdMap::get).toList();
assertEquals(nodesOnLevel, nodesOnLevel2);
}
// assert that the neighbors from the old graph are successfully transferred to the new graph
// assert that all the nodes from the initializer graph can be found in the new graph and that
// the neighbors from the old graph are successfully transferred to the new graph
for (int level = 0; level < g.numLevels(); level++) {
NodesIterator nodesOnLevel = h.getNodesOnLevel(level);
NodesIterator nodesOnLevel = initializer.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
g.seek(level, oldToNewOrdMap.get(node));
h.seek(level, node);
initializer.seek(level, node);
assertEquals(
"arcs differ for node " + node,
getNeighborNodes(g),
getNeighborNodes(h).stream().map(oldToNewOrdMap::get).collect(Collectors.toSet()));
getNeighborNodes(initializer).stream()
.map(oldToNewOrdMap::get)
.collect(Collectors.toSet()));
}
}
}

View File

@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.tests.util.LuceneTestCase;
/**
 * Tests behavior that is specific to {@link OnHeapHnswGraph}; more complex scenarios are covered
 * by {@link HnswGraphTestCase}.
 */
public class TestOnHeapHnswGraph extends LuceneTestCase {

  /** Adding a node whose id is out of bounds for a fixed-size graph must throw. */
  public void testNoGrowth() {
    OnHeapHnswGraph fixedSizeGraph = new OnHeapHnswGraph(10, 100);
    expectThrows(IllegalStateException.class, () -> fixedSizeGraph.addNode(1, 100));
  }

  /**
   * Adding a node on a level above the one it was first inserted on trips an assertion
   * (AssertionError here; likely an NPE in production where assertions are disabled).
   */
  public void testAddLevelOutOfOrder() {
    OnHeapHnswGraph growableGraph = new OnHeapHnswGraph(10, -1);
    growableGraph.addNode(0, 0);
    expectThrows(AssertionError.class, () -> growableGraph.addNode(1, 0));
  }

  /** Calling getNodesOnLevel on a graph that still has holes must throw. */
  public void testIncompleteGraphThrow() {
    OnHeapHnswGraph fixedSizeGraph = new OnHeapHnswGraph(10, 10);
    fixedSizeGraph.addNode(1, 0);
    fixedSizeGraph.addNode(0, 0);
    assertEquals(1, fixedSizeGraph.getNodesOnLevel(1).size());
    // node 5 leaves ids 1..4 missing, so size and max node id no longer agree
    fixedSizeGraph.addNode(0, 5);
    expectThrows(IllegalStateException.class, () -> fixedSizeGraph.getNodesOnLevel(0));
  }

  /** Adds 101 nodes in id order to an unbounded graph and checks the nodes on every level. */
  public void testGraphGrowth() {
    final int maxLevel = 5;
    OnHeapHnswGraph growableGraph = new OnHeapHnswGraph(10, -1);
    List<List<Integer>> expectedLevels = emptyLevels(maxLevel);
    for (int node = 0; node <= 100; node++) {
      insertFromTopLevel(growableGraph, expectedLevels, random().nextInt(maxLevel), node);
    }
    assertGraphEquals(growableGraph, expectedLevels);
  }

  /** Adds nodes in a partially shuffled id order and checks the nodes on every level. */
  public void testGraphBuildOutOfOrder() {
    final int maxLevel = 5;
    final int numNodes = 100;
    OnHeapHnswGraph growableGraph = new OnHeapHnswGraph(10, -1);
    List<List<Integer>> expectedLevels = emptyLevels(maxLevel);

    int[] insertionOrder = new int[numNodes];
    for (int slot = 0; slot < numNodes; slot++) {
      insertionOrder[slot] = slot;
    }
    // shuffle the insertion order
    for (int swap = 0; swap < 40; swap++) {
      int first = random().nextInt(numNodes);
      int second = random().nextInt(numNodes);
      int held = insertionOrder[first];
      insertionOrder[first] = insertionOrder[second];
      insertionOrder[second] = held;
    }

    for (int node : insertionOrder) {
      insertFromTopLevel(growableGraph, expectedLevels, random().nextInt(maxLevel), node);
    }
    // the graph reports nodes in ascending id order, so the expectation has to be sorted too
    for (List<Integer> nodesOnLevel : expectedLevels) {
      nodesOnLevel.sort(Integer::compare);
    }
    assertGraphEquals(growableGraph, expectedLevels);
  }

  /** Creates {@code levelCount} empty, mutable node lists, one per level. */
  private static List<List<Integer>> emptyLevels(int levelCount) {
    List<List<Integer>> levels = new ArrayList<>();
    for (int level = 0; level < levelCount; level++) {
      levels.add(new ArrayList<>());
    }
    return levels;
  }

  /**
   * Adds {@code node} to the graph on every level from {@code topLevel} down to 0 and records it
   * in the expected per-level lists.
   */
  private static void insertFromTopLevel(
      OnHeapHnswGraph graph, List<List<Integer>> expectedLevels, int topLevel, int node) {
    for (int level = topLevel; level >= 0; level--) {
      graph.addNode(level, node);
      expectedLevels.get(level).add(node);
    }
  }

  /** Asserts that every level of the graph yields exactly the expected nodes, in order. */
  private static void assertGraphEquals(OnHeapHnswGraph graph, List<List<Integer>> levelToNodes) {
    for (int level = 0; level < graph.numLevels(); level++) {
      List<Integer> expectedNodes = levelToNodes.get(level);
      HnswGraph.NodesIterator actualNodes = graph.getNodesOnLevel(level);
      assertEquals(expectedNodes.size(), actualNodes.size());
      int position = 0;
      while (actualNodes.hasNext()) {
        assertEquals(expectedNodes.get(position), actualNodes.next());
        position++;
      }
    }
  }
}