diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 5846480e6c1..662f57f04b7 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -195,6 +195,8 @@ Optimizations * GITHUB#12668: ImpactsEnums now decode frequencies lazily like PostingsEnums. (Adrien Grand) + +* GITHUB#12651: Use 2d array for OnHeapHnswGraph representation. (Patrick Zhai) Changes in runtime behavior --------------------- diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java index be5dd7f4d2c..d7a37bbba9b 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java @@ -56,7 +56,11 @@ public class Word2VecSynonymProvider { RandomVectorScorerSupplier.createFloats(word2VecModel, SIMILARITY_FUNCTION); HnswGraphBuilder builder = HnswGraphBuilder.create( - scorerSupplier, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, HnswGraphBuilder.randSeed); + scorerSupplier, + DEFAULT_MAX_CONN, + DEFAULT_BEAM_WIDTH, + HnswGraphBuilder.randSeed, + word2VecModel.size()); this.hnswGraph = builder.build(word2VecModel.size()); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java index ee7dede5a7e..10bef300c2c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -438,7 +438,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { RandomVectorScorerSupplier.createBytes( vectorValues, fieldInfo.getVectorSimilarityFunction()); HnswGraphBuilder hnswGraphBuilder = - createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex); + createHnswGraphBuilder( + mergeState, + fieldInfo, + scorerSupplier, + initializerIndex, + vectorValues.size()); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); yield hnswGraphBuilder.build(vectorValues.size()); } @@ -453,7 +458,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { RandomVectorScorerSupplier.createFloats( vectorValues, fieldInfo.getVectorSimilarityFunction()); HnswGraphBuilder hnswGraphBuilder = - createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex); + createHnswGraphBuilder( + mergeState, + fieldInfo, + scorerSupplier, + initializerIndex, + vectorValues.size()); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); yield hnswGraphBuilder.build(vectorValues.size()); } @@ -488,10 +498,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { MergeState mergeState, FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier, - int initializerIndex) + int initializerIndex, + int graphSize) throws IOException { if (initializerIndex == -1) { - return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); + return HnswGraphBuilder.create( + scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, graphSize); } HnswGraph initializerGraph = @@ -499,7 +511,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { Map ordinalMapper = getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex); return HnswGraphBuilder.create( - scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, initializerGraph, ordinalMapper); + scorerSupplier, + M, + beamWidth, + HnswGraphBuilder.randSeed, + initializerGraph, + ordinalMapper, + graphSize); } private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo) diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java index 9b3d0d62c90..54ff49f0602 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java @@ -66,6 +66,11 @@ public abstract class HnswGraph { /** Returns the number of nodes in the graph */ public abstract int size(); + /** Returns max node id, inclusive, normally this value will be size - 1 */ + public int maxNodeId() { + return size() - 1; + } + /** * Iterates over the neighbor list. It is illegal to call this method after it returns * NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator. diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 8fb85df372b..48e6cd219d3 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -76,7 +76,12 @@ public final class HnswGraphBuilder { public static HnswGraphBuilder create( RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) throws IOException { - return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed); + return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, -1); + } + + public static HnswGraphBuilder create( + RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) { + return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize); } public static HnswGraphBuilder create( @@ -85,9 +90,11 @@ public final class HnswGraphBuilder { int beamWidth, long seed, HnswGraph initializerGraph, - Map oldToNewOrdinalMap) + Map oldToNewOrdinalMap, + int graphSize) throws IOException { - HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed); + HnswGraphBuilder hnswGraphBuilder = + new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize); hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap); return hnswGraphBuilder; } @@ -102,10 +109,10 @@ public final class HnswGraphBuilder { * @param beamWidth the size of the beam search to use when finding nearest neighbors. * @param seed the seed for a random number generator used during graph construction. Provide this * to ensure repeatable construction. + * @param graphSize size of graph, if unknown, pass in -1 */ private HnswGraphBuilder( - RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) - throws IOException { + RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) { if (M <= 0) { throw new IllegalArgumentException("maxConn must be positive"); } @@ -118,7 +125,7 @@ public final class HnswGraphBuilder { // normalization factor for level generation; currently not configurable this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); this.random = new SplittableRandom(seed); - this.hnsw = new OnHeapHnswGraph(M); + this.hnsw = new OnHeapHnswGraph(M, graphSize); this.graphSearcher = new HnswGraphSearcher( new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size())); @@ -155,7 +162,7 @@ public final class HnswGraphBuilder { private void initializeFromGraph( HnswGraph initializerGraph, Map oldToNewOrdinalMap) throws IOException { assert hnsw.size() == 0; - for (int level = 0; level < initializerGraph.numLevels(); level++) { + for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) { HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level); while (it.hasNext()) { @@ -288,7 +295,7 @@ public final class HnswGraphBuilder { // only adding it if it is closer to the target than to any of the other selected neighbors int cNode = candidates.node[i]; float cScore = candidates.score[i]; - assert cNode < hnsw.size(); + assert cNode <= hnsw.maxNodeId(); if (diversityCheck(cNode, cScore, neighbors)) { neighbors.addInOrder(cNode, cScore); } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index aeddbeb56fa..c18c16c38f5 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -66,7 +66,7 @@ public class HnswGraphSearcher { throws IOException { HnswGraphSearcher graphSearcher = new HnswGraphSearcher( - new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(graph.size())); + new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph))); search(scorer, knnCollector, graph, graphSearcher, acceptOrds); } @@ -88,7 +88,7 @@ public class HnswGraphSearcher { KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit); OnHeapHnswGraphSearcher graphSearcher = new OnHeapHnswGraphSearcher( - new NeighborQueue(topK, true), new SparseFixedBitSet(graph.size())); + new NeighborQueue(topK, true), new SparseFixedBitSet(getGraphSize(graph))); search(scorer, knnCollector, graph, graphSearcher, acceptOrds); return knnCollector; } @@ -150,9 +150,9 @@ public class HnswGraphSearcher { */ private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit) throws IOException { - int size = graph.size(); + int size = getGraphSize(graph); int visitedCount = 1; - prepareScratchState(graph.size()); + prepareScratchState(size); int currentEp = graph.entryNode(); float currentScore = scorer.score(currentEp); boolean foundBetter; @@ -201,8 +201,9 @@ public class HnswGraphSearcher { Bits acceptOrds) throws IOException { - int size = graph.size(); - prepareScratchState(graph.size()); + int size = getGraphSize(graph); + + prepareScratchState(size); for (int ep : eps) { if (visited.getAndSet(ep) == false) { @@ -284,6 +285,10 @@ public class HnswGraphSearcher { return graph.nextNeighbor(); } + private static int getGraphSize(HnswGraph graph) { + return graph.maxNodeId() + 1; + } + /** * This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner by avoiding * the unsafe methods (seek and nextNeighbor, which maintain state in the graph object) and diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index b7f0ecfd075..bcb78e85182 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -20,10 +20,9 @@ package org.apache.lucene.util.hnsw; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.RamUsageEstimator; /** @@ -32,39 +31,56 @@ import org.apache.lucene.util.RamUsageEstimator; */ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { + private static final int INIT_SIZE = 128; + private int numLevels; // the current number of levels in the graph private int entryNode; // the current graph entry node on the top level. -1 if not set - // Level 0 is represented as List – nodes' connections on level 0. - // Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond - // to vectors - // added to HnswBuilder, and the node values are the ordinals of those vectors. - // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals. - private final List graphLevel0; - // Represents levels 1-N. Each level is represented with a Map that maps a levels level 0 - // ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain - // it in this list. However, to avoid changing list indexing, we always will make the first - // element - // null. - private final List> graphUpperLevels; - private final int nsize; - private final int nsize0; + // the internal graph representation where the first dimension is node id and second dimension is + // level + // e.g. graph[1][2] is all the neighbours of node 1 at level 2 + private NeighborArray[][] graph; + // essentially another 2d map which the first dimension is level and second dimension is node id, + // this is only + // generated on demand when there's someone calling getNodeOnLevel on a non-zero level + private List[] levelToNodes; + private int + lastFreezeSize; // remember the size we are at last time to freeze the graph and generate + // levelToNodes + private int size; // graph size, which is number of nodes in level 0 + private int + nonZeroLevelSize; // total number of NeighborArrays created that is not on level 0, for now it + // is only used to account memory usage + private int maxNodeId; + private final int nsize; // neighbour array size at non-zero level + private final int nsize0; // neighbour array size at zero level + private final boolean + noGrowth; // if an initial size is passed in, we don't expect the graph to grow itself // KnnGraphValues iterator members private int upto; private NeighborArray cur; - OnHeapHnswGraph(int M) { + /** + * ctor + * + * @param numNodes number of nodes that will be added to this graph, passing in -1 means unbounded + * while passing in a non-negative value will lock the whole graph and disable the graph from + * growing itself (you cannot add a node with has id >= numNodes) + */ + OnHeapHnswGraph(int M, int numNodes) { this.numLevels = 1; // Implicitly start the graph with a single level - this.graphLevel0 = new ArrayList<>(); this.entryNode = -1; // Entry node should be negative until a node is added // Neighbours' size on upper levels (nsize) and level 0 (nsize0) // We allocate extra space for neighbours, but then prune them to keep allowed maximum + this.maxNodeId = -1; this.nsize = M + 1; this.nsize0 = (M * 2 + 1); - - this.graphUpperLevels = new ArrayList<>(numLevels); - graphUpperLevels.add(null); // we don't need this for 0th level, as it contains all nodes + noGrowth = numNodes != -1; + if (noGrowth == false) { + numNodes = INIT_SIZE; + } + this.graph = new NeighborArray[numNodes][]; } /** @@ -74,23 +90,32 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { * @param node the node whose neighbors are returned, represented as an ordinal on the level 0. */ public NeighborArray getNeighbors(int level, int node) { - if (level == 0) { - return graphLevel0.get(node); - } - Map levelMap = graphUpperLevels.get(level); - assert levelMap.containsKey(node); - return levelMap.get(node); + assert graph[node][level] != null; + return graph[node][level]; } @Override public int size() { - return graphLevel0.size(); // all nodes are located on the 0th level + return size; + } + + /** + * When we initialize from another graph, the max node id is different from {@link #size()}, + * because we will add nodes out of order, such that we need two method for each + * + * @return max node id (inclusive) + */ + @Override + public int maxNodeId() { + return maxNodeId; } /** * Add node on the given level. Nodes can be inserted out of order, but it requires that the nodes * preceded by the node inserted out of order are eventually added. * + *

NOTE: You must add a node starting from the node's top level + * * @param level level to add a node on * @param node the node to add, represented as an ordinal on the level 0. */ @@ -99,28 +124,33 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { entryNode = node; } - if (level > 0) { - // if the new node introduces a new level, add more levels to the graph, - // and make this node the graph's new entry point - if (level >= numLevels) { - for (int i = numLevels; i <= level; i++) { - graphUpperLevels.add(new HashMap<>()); - } - numLevels = level + 1; - entryNode = node; - } - - graphUpperLevels.get(level).put(node, new NeighborArray(nsize, true)); - } else { - // Add nodes all the way up to and including "node" in the new graph on level 0. This will - // cause the size of the - // graph to differ from the number of nodes added to the graph. The size of the graph and the - // number of nodes - // added will only be in sync once all nodes from 0...last_node are added into the graph. - while (node >= graphLevel0.size()) { - graphLevel0.add(new NeighborArray(nsize0, true)); + if (node >= graph.length) { + if (noGrowth) { + throw new IllegalStateException( + "The graph does not expect to grow when an initial size is given"); } + graph = ArrayUtil.grow(graph, node + 1); } + + if (level >= numLevels) { + numLevels = level + 1; + entryNode = node; + } + + assert graph[node] == null || graph[node].length > level + : "node must be inserted from the top level"; + if (graph[node] == null) { + graph[node] = + new NeighborArray[level + 1]; // assumption: we always call this function from top level + size++; + } + if (level == 0) { + graph[node][level] = new NeighborArray(nsize0, true); + } else { + graph[node][level] = new NeighborArray(nsize, true); + nonZeroLevelSize++; + } + maxNodeId = Math.max(maxNodeId, node); } @Override @@ -158,50 +188,83 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { return entryNode; } + /** + * WARN: calling this method will essentially iterate through all nodes at level 0 (even if you're + * not getting node at level 0), we have built some caching mechanism such that if graph is not + * changed only the first non-zero level call will pay the cost. So it is highly NOT recommended + * to call this method while the graph is still building. + * + *

NOTE: calling this method while the graph is still building is prohibited + */ @Override public NodesIterator getNodesOnLevel(int level) { + if (size() != maxNodeId() + 1) { + throw new IllegalStateException( + "graph build not complete, size=" + size() + " maxNodeId=" + maxNodeId()); + } if (level == 0) { return new ArrayNodesIterator(size()); } else { - return new CollectionNodesIterator(graphUpperLevels.get(level).keySet()); + generateLevelToNodes(); + return new CollectionNodesIterator(levelToNodes[level]); } } + @SuppressWarnings({"unchecked", "rawtypes"}) + private void generateLevelToNodes() { + if (lastFreezeSize == size) { + return; + } + + levelToNodes = new List[numLevels]; + for (int i = 1; i < numLevels; i++) { + levelToNodes[i] = new ArrayList<>(); + } + int nonNullNode = 0; + for (int node = 0; node < graph.length; node++) { + // when we init from another graph, we could have holes where some slot is null + if (graph[node] == null) { + continue; + } + nonNullNode++; + for (int i = 1; i < graph[node].length; i++) { + levelToNodes[i].add(node); + } + if (nonNullNode == size) { + break; + } + } + lastFreezeSize = size; + } + @Override public long ramBytesUsed() { long neighborArrayBytes0 = - nsize0 * (Integer.BYTES + Float.BYTES) + (long) nsize0 * (Integer.BYTES + Float.BYTES) + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER - + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2 + + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L + Integer.BYTES * 3; long neighborArrayBytes = - nsize * (Integer.BYTES + Float.BYTES) + (long) nsize * (Integer.BYTES + Float.BYTES) + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER - + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2 + + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L + Integer.BYTES * 3; long total = 0; - for (int l = 0; l < numLevels; l++) { - if (l == 0) { - total += - graphLevel0.size() * neighborArrayBytes0 - + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph; - } else { - long numNodesOnLevel = graphUpperLevels.get(l).size(); - - // For levels > 0, we represent the graph structure with a tree map. - // A single node in the tree contains 3 references (left root, right root, value) as well - // as an Integer for the key and 1 extra byte for the color of the node (this is actually 1 - // bit, but - // because we do not have that granularity, we set to 1 byte). In addition, we include 1 - // more reference for - // the tree map itself. - total += - numNodesOnLevel * (3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + Integer.BYTES + 1) - + RamUsageEstimator.NUM_BYTES_OBJECT_REF; - - // Add the size neighbor of each node - total += numNodesOnLevel * neighborArrayBytes; - } + total += + size * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // for graph and level 0; + total += nonZeroLevelSize * neighborArrayBytes; // for non-zero level + total += 8 * Integer.BYTES; // all int fields + total += RamUsageEstimator.NUM_BYTES_OBJECT_REF; // field: cur + total += RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // field: levelToNodes + if (levelToNodes != null) { + total += + (long) (numLevels - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0 + total += + (long) nonZeroLevelSize + * (RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + Integer.BYTES); } return total; } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index b758b441c50..06f79a2e63a 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -195,7 +195,7 @@ abstract class HnswGraphTestCase extends LuceneTestCase { } } - // test that sorted index returns the same search results are unsorted + // test that sorted index returns the same search results as unsorted public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException { int dim = random().nextInt(10) + 3; int nDoc = random().nextInt(200) + 100; @@ -454,77 +454,6 @@ abstract class HnswGraphTestCase extends LuceneTestCase { } } - public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException { - int maxNumLevels = randomIntBetween(2, 10); - int nodeCount = randomIntBetween(1, 100); - - List> nodesPerLevel = new ArrayList<>(); - for (int i = 0; i < maxNumLevels; i++) { - nodesPerLevel.add(new ArrayList<>()); - } - - int numLevels = 0; - for (int currNode = 0; currNode < nodeCount; currNode++) { - int nodeMaxLevel = random().nextInt(1, maxNumLevels + 1); - numLevels = Math.max(numLevels, nodeMaxLevel); - for (int currLevel = 0; currLevel < nodeMaxLevel; currLevel++) { - nodesPerLevel.get(currLevel).add(currNode); - } - } - - OnHeapHnswGraph topDownOrderReversedHnsw = new OnHeapHnswGraph(10); - for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) { - List currLevelNodes = nodesPerLevel.get(currLevel); - int currLevelNodesSize = currLevelNodes.size(); - for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) { - topDownOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd)); - } - } - - OnHeapHnswGraph bottomUpOrderReversedHnsw = new OnHeapHnswGraph(10); - for (int currLevel = 0; currLevel < numLevels; currLevel++) { - List currLevelNodes = nodesPerLevel.get(currLevel); - int currLevelNodesSize = currLevelNodes.size(); - for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) { - bottomUpOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd)); - } - } - - OnHeapHnswGraph topDownOrderRandomHnsw = new OnHeapHnswGraph(10); - for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) { - List currLevelNodes = new ArrayList<>(nodesPerLevel.get(currLevel)); - Collections.shuffle(currLevelNodes, random()); - for (Integer currNode : currLevelNodes) { - topDownOrderRandomHnsw.addNode(currLevel, currNode); - } - } - - OnHeapHnswGraph bottomUpExpectedHnsw = new OnHeapHnswGraph(10); - for (int currLevel = 0; currLevel < numLevels; currLevel++) { - for (Integer currNode : nodesPerLevel.get(currLevel)) { - bottomUpExpectedHnsw.addNode(currLevel, currNode); - } - } - - assertEquals(nodeCount, bottomUpExpectedHnsw.getNodesOnLevel(0).size()); - for (Integer node : nodesPerLevel.get(0)) { - assertEquals(0, bottomUpExpectedHnsw.getNeighbors(0, node).size()); - } - - for (int currLevel = 1; currLevel < numLevels; currLevel++) { - List expectedNodesOnLevel = nodesPerLevel.get(currLevel); - List sortedNodes = sortedNodesOnLevel(bottomUpExpectedHnsw, currLevel); - assertEquals( - String.format(Locale.ROOT, "Nodes on level %d do not match", currLevel), - expectedNodesOnLevel, - sortedNodes); - } - - assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw); - assertGraphEqual(bottomUpExpectedHnsw, bottomUpOrderReversedHnsw); - assertGraphEqual(bottomUpExpectedHnsw, topDownOrderRandomHnsw); - } - public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws IOException { int totalSize = atLeast(100); int initializerSize = random().nextInt(5, totalSize); @@ -547,7 +476,13 @@ abstract class HnswGraphTestCase extends LuceneTestCase { RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues); HnswGraphBuilder finalBuilder = HnswGraphBuilder.create( - finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap); + finalscorerSupplier, + 10, + 30, + seed, + initializerGraph, + initializerOrdMap, + finalVectorValues.size()); // When offset is 0, the graphs should be identical before vectors are added assertGraphEqual(initializerGraph, finalBuilder.getGraph()); @@ -577,7 +512,13 @@ abstract class HnswGraphTestCase extends LuceneTestCase { RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues); HnswGraphBuilder finalBuilder = HnswGraphBuilder.create( - finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap); + finalscorerSupplier, + 10, + 30, + seed, + initializerGraph, + initializerOrdMap, + finalVectorValues.size()); assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap); @@ -599,35 +540,30 @@ abstract class HnswGraphTestCase extends LuceneTestCase { } private void assertGraphInitializedFromGraph( - HnswGraph g, HnswGraph h, Map oldToNewOrdMap) throws IOException { - assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels()); + HnswGraph g, HnswGraph initializer, Map oldToNewOrdMap) throws IOException { + assertEquals( + "the number of levels in the graphs are different!", + initializer.numLevels(), + g.numLevels()); // Confirm that the size of the new graph includes all nodes up to an including the max new // ordinal in the old to // new ordinal mapping - assertEquals( - "the number of nodes in the graphs are different!", - g.size(), - Collections.max(oldToNewOrdMap.values()) + 1); + assertEquals("the number of nodes in the graphs are different!", initializer.size(), g.size()); - // assert the nodes from the previous graph are successfully to levels > 0 in the new graph - for (int level = 1; level < g.numLevels(); level++) { - List nodesOnLevel = sortedNodesOnLevel(g, level); - List nodesOnLevel2 = - sortedNodesOnLevel(h, level).stream().map(oldToNewOrdMap::get).toList(); - assertEquals(nodesOnLevel, nodesOnLevel2); - } - - // assert that the neighbors from the old graph are successfully transferred to the new graph + // assert that all the node from initializer graph can be found in the new graph and + // the neighbors from the old graph are successfully transferred to the new graph for (int level = 0; level < g.numLevels(); level++) { - NodesIterator nodesOnLevel = h.getNodesOnLevel(level); + NodesIterator nodesOnLevel = initializer.getNodesOnLevel(level); while (nodesOnLevel.hasNext()) { int node = nodesOnLevel.nextInt(); g.seek(level, oldToNewOrdMap.get(node)); - h.seek(level, node); + initializer.seek(level, node); assertEquals( "arcs differ for node " + node, getNeighborNodes(g), - getNeighborNodes(h).stream().map(oldToNewOrdMap::get).collect(Collectors.toSet())); + getNeighborNodes(initializer).stream() + .map(oldToNewOrdMap::get) + .collect(Collectors.toSet())); } } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java new file mode 100644 index 00000000000..23c3511eb2e --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.tests.util.LuceneTestCase; + +/** + * Test OnHeapHnswGraph's behavior specifically, for more complex test, see {@link + * HnswGraphTestCase} + */ +public class TestOnHeapHnswGraph extends LuceneTestCase { + + /* assert exception will be thrown when we add out of bound node to a fixed size graph */ + public void testNoGrowth() { + OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 100); + expectThrows(IllegalStateException.class, () -> graph.addNode(1, 100)); + } + + /* AssertionError will be thrown if we add a node not from top most level, + (likely NPE will be thrown in prod) */ + public void testAddLevelOutOfOrder() { + OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1); + graph.addNode(0, 0); + expectThrows(AssertionError.class, () -> graph.addNode(1, 0)); + } + + /* assert exception will be thrown when we call getNodeOnLevel for an incomplete graph */ + public void testIncompleteGraphThrow() { + OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 10); + graph.addNode(1, 0); + graph.addNode(0, 0); + assertEquals(1, graph.getNodesOnLevel(1).size()); + graph.addNode(0, 5); + expectThrows(IllegalStateException.class, () -> graph.getNodesOnLevel(0)); + } + + public void testGraphGrowth() { + OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1); + List> levelToNodes = new ArrayList<>(); + int maxLevel = 5; + for (int i = 0; i < maxLevel; i++) { + levelToNodes.add(new ArrayList<>()); + } + for (int i = 0; i < 101; i++) { + int level = random().nextInt(maxLevel); + for (int l = level; l >= 0; l--) { + graph.addNode(l, i); + levelToNodes.get(l).add(i); + } + } + assertGraphEquals(graph, levelToNodes); + } + + public void testGraphBuildOutOfOrder() { + OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1); + List> levelToNodes = new ArrayList<>(); + int maxLevel = 5; + int numNodes = 100; + for (int i = 0; i < maxLevel; i++) { + levelToNodes.add(new ArrayList<>()); + } + int[] insertions = new int[numNodes]; + for (int i = 0; i < numNodes; i++) { + insertions[i] = i; + } + // shuffle the insertion order + for (int i = 0; i < 40; i++) { + int pos1 = random().nextInt(numNodes); + int pos2 = random().nextInt(numNodes); + int tmp = insertions[pos1]; + insertions[pos1] = insertions[pos2]; + insertions[pos2] = tmp; + } + + for (int i : insertions) { + int level = random().nextInt(maxLevel); + for (int l = level; l >= 0; l--) { + graph.addNode(l, i); + levelToNodes.get(l).add(i); + } + } + + for (int i = 0; i < maxLevel; i++) { + levelToNodes.get(i).sort(Integer::compare); + } + + assertGraphEquals(graph, levelToNodes); + } + + private static void assertGraphEquals(OnHeapHnswGraph graph, List> levelToNodes) { + for (int l = 0; l < graph.numLevels(); l++) { + HnswGraph.NodesIterator nodesIterator = graph.getNodesOnLevel(l); + assertEquals(levelToNodes.get(l).size(), nodesIterator.size()); + int idx = 0; + while (nodesIterator.hasNext()) { + assertEquals(levelToNodes.get(l).get(idx++), nodesIterator.next()); + } + } + } +}