mirror of https://github.com/apache/lucene.git
Optimize OnHeapHnswGraph's data structure (#12651)
make the internal graph representation a 2d array
This commit is contained in:
parent
2482f7688b
commit
a1cf22e6a9
|
@ -195,6 +195,8 @@ Optimizations
|
|||
|
||||
* GITHUB#12668: ImpactsEnums now decode frequencies lazily like PostingsEnums.
|
||||
(Adrien Grand)
|
||||
|
||||
* GITHUB#12651: Use 2d array for OnHeapHnswGraph representation. (Patrick Zhai)
|
||||
|
||||
Changes in runtime behavior
|
||||
---------------------
|
||||
|
|
|
@ -56,7 +56,11 @@ public class Word2VecSynonymProvider {
|
|||
RandomVectorScorerSupplier.createFloats(word2VecModel, SIMILARITY_FUNCTION);
|
||||
HnswGraphBuilder builder =
|
||||
HnswGraphBuilder.create(
|
||||
scorerSupplier, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, HnswGraphBuilder.randSeed);
|
||||
scorerSupplier,
|
||||
DEFAULT_MAX_CONN,
|
||||
DEFAULT_BEAM_WIDTH,
|
||||
HnswGraphBuilder.randSeed,
|
||||
word2VecModel.size());
|
||||
this.hnswGraph = builder.build(word2VecModel.size());
|
||||
}
|
||||
|
||||
|
|
|
@ -438,7 +438,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
RandomVectorScorerSupplier.createBytes(
|
||||
vectorValues, fieldInfo.getVectorSimilarityFunction());
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex);
|
||||
createHnswGraphBuilder(
|
||||
mergeState,
|
||||
fieldInfo,
|
||||
scorerSupplier,
|
||||
initializerIndex,
|
||||
vectorValues.size());
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
yield hnswGraphBuilder.build(vectorValues.size());
|
||||
}
|
||||
|
@ -453,7 +458,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
RandomVectorScorerSupplier.createFloats(
|
||||
vectorValues, fieldInfo.getVectorSimilarityFunction());
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex);
|
||||
createHnswGraphBuilder(
|
||||
mergeState,
|
||||
fieldInfo,
|
||||
scorerSupplier,
|
||||
initializerIndex,
|
||||
vectorValues.size());
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
yield hnswGraphBuilder.build(vectorValues.size());
|
||||
}
|
||||
|
@ -488,10 +498,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
MergeState mergeState,
|
||||
FieldInfo fieldInfo,
|
||||
RandomVectorScorerSupplier scorerSupplier,
|
||||
int initializerIndex)
|
||||
int initializerIndex,
|
||||
int graphSize)
|
||||
throws IOException {
|
||||
if (initializerIndex == -1) {
|
||||
return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
|
||||
return HnswGraphBuilder.create(
|
||||
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, graphSize);
|
||||
}
|
||||
|
||||
HnswGraph initializerGraph =
|
||||
|
@ -499,7 +511,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
Map<Integer, Integer> ordinalMapper =
|
||||
getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
|
||||
return HnswGraphBuilder.create(
|
||||
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, initializerGraph, ordinalMapper);
|
||||
scorerSupplier,
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed,
|
||||
initializerGraph,
|
||||
ordinalMapper,
|
||||
graphSize);
|
||||
}
|
||||
|
||||
private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)
|
||||
|
|
|
@ -66,6 +66,11 @@ public abstract class HnswGraph {
|
|||
/** Returns the number of nodes in the graph */
|
||||
public abstract int size();
|
||||
|
||||
/** Returns the max node id (inclusive); normally this value will be size - 1. */
|
||||
public int maxNodeId() {
|
||||
return size() - 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Iterates over the neighbor list. It is illegal to call this method after it returns
|
||||
* NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
|
||||
|
|
|
@ -76,7 +76,12 @@ public final class HnswGraphBuilder {
|
|||
public static HnswGraphBuilder create(
|
||||
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
|
||||
throws IOException {
|
||||
return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed);
|
||||
return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, -1);
|
||||
}
|
||||
|
||||
public static HnswGraphBuilder create(
|
||||
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
|
||||
return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
|
||||
}
|
||||
|
||||
public static HnswGraphBuilder create(
|
||||
|
@ -85,9 +90,11 @@ public final class HnswGraphBuilder {
|
|||
int beamWidth,
|
||||
long seed,
|
||||
HnswGraph initializerGraph,
|
||||
Map<Integer, Integer> oldToNewOrdinalMap)
|
||||
Map<Integer, Integer> oldToNewOrdinalMap,
|
||||
int graphSize)
|
||||
throws IOException {
|
||||
HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed);
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
|
||||
hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
|
||||
return hnswGraphBuilder;
|
||||
}
|
||||
|
@ -102,10 +109,10 @@ public final class HnswGraphBuilder {
|
|||
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
|
||||
* @param seed the seed for a random number generator used during graph construction. Provide this
|
||||
* to ensure repeatable construction.
|
||||
* @param graphSize size of graph, if unknown, pass in -1
|
||||
*/
|
||||
private HnswGraphBuilder(
|
||||
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
|
||||
throws IOException {
|
||||
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) {
|
||||
if (M <= 0) {
|
||||
throw new IllegalArgumentException("maxConn must be positive");
|
||||
}
|
||||
|
@ -118,7 +125,7 @@ public final class HnswGraphBuilder {
|
|||
// normalization factor for level generation; currently not configurable
|
||||
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
|
||||
this.random = new SplittableRandom(seed);
|
||||
this.hnsw = new OnHeapHnswGraph(M);
|
||||
this.hnsw = new OnHeapHnswGraph(M, graphSize);
|
||||
this.graphSearcher =
|
||||
new HnswGraphSearcher(
|
||||
new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
|
||||
|
@ -155,7 +162,7 @@ public final class HnswGraphBuilder {
|
|||
private void initializeFromGraph(
|
||||
HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
|
||||
assert hnsw.size() == 0;
|
||||
for (int level = 0; level < initializerGraph.numLevels(); level++) {
|
||||
for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
|
||||
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
|
||||
|
||||
while (it.hasNext()) {
|
||||
|
@ -288,7 +295,7 @@ public final class HnswGraphBuilder {
|
|||
// only adding it if it is closer to the target than to any of the other selected neighbors
|
||||
int cNode = candidates.node[i];
|
||||
float cScore = candidates.score[i];
|
||||
assert cNode < hnsw.size();
|
||||
assert cNode <= hnsw.maxNodeId();
|
||||
if (diversityCheck(cNode, cScore, neighbors)) {
|
||||
neighbors.addInOrder(cNode, cScore);
|
||||
}
|
||||
|
|
|
@ -66,7 +66,7 @@ public class HnswGraphSearcher {
|
|||
throws IOException {
|
||||
HnswGraphSearcher graphSearcher =
|
||||
new HnswGraphSearcher(
|
||||
new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(graph.size()));
|
||||
new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph)));
|
||||
search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
|
||||
}
|
||||
|
||||
|
@ -88,7 +88,7 @@ public class HnswGraphSearcher {
|
|||
KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit);
|
||||
OnHeapHnswGraphSearcher graphSearcher =
|
||||
new OnHeapHnswGraphSearcher(
|
||||
new NeighborQueue(topK, true), new SparseFixedBitSet(graph.size()));
|
||||
new NeighborQueue(topK, true), new SparseFixedBitSet(getGraphSize(graph)));
|
||||
search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
|
||||
return knnCollector;
|
||||
}
|
||||
|
@ -150,9 +150,9 @@ public class HnswGraphSearcher {
|
|||
*/
|
||||
private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit)
|
||||
throws IOException {
|
||||
int size = graph.size();
|
||||
int size = getGraphSize(graph);
|
||||
int visitedCount = 1;
|
||||
prepareScratchState(graph.size());
|
||||
prepareScratchState(size);
|
||||
int currentEp = graph.entryNode();
|
||||
float currentScore = scorer.score(currentEp);
|
||||
boolean foundBetter;
|
||||
|
@ -201,8 +201,9 @@ public class HnswGraphSearcher {
|
|||
Bits acceptOrds)
|
||||
throws IOException {
|
||||
|
||||
int size = graph.size();
|
||||
prepareScratchState(graph.size());
|
||||
int size = getGraphSize(graph);
|
||||
|
||||
prepareScratchState(size);
|
||||
|
||||
for (int ep : eps) {
|
||||
if (visited.getAndSet(ep) == false) {
|
||||
|
@ -284,6 +285,10 @@ public class HnswGraphSearcher {
|
|||
return graph.nextNeighbor();
|
||||
}
|
||||
|
||||
private static int getGraphSize(HnswGraph graph) {
|
||||
return graph.maxNodeId() + 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner by avoiding
|
||||
* the unsafe methods (seek and nextNeighbor, which maintain state in the graph object) and
|
||||
|
|
|
@ -20,10 +20,9 @@ package org.apache.lucene.util.hnsw;
|
|||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
||||
/**
|
||||
|
@ -32,39 +31,56 @@ import org.apache.lucene.util.RamUsageEstimator;
|
|||
*/
|
||||
public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
||||
|
||||
private static final int INIT_SIZE = 128;
|
||||
|
||||
private int numLevels; // the current number of levels in the graph
|
||||
private int entryNode; // the current graph entry node on the top level. -1 if not set
|
||||
|
||||
// Level 0 is represented as List<NeighborArray> – nodes' connections on level 0.
|
||||
// Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond
|
||||
// to vectors
|
||||
// added to HnswBuilder, and the node values are the ordinals of those vectors.
|
||||
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
|
||||
private final List<NeighborArray> graphLevel0;
|
||||
// Represents levels 1-N. Each level is represented with a Map that maps a levels level 0
|
||||
// ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain
|
||||
// it in this list. However, to avoid changing list indexing, we always will make the first
|
||||
// element
|
||||
// null.
|
||||
private final List<Map<Integer, NeighborArray>> graphUpperLevels;
|
||||
private final int nsize;
|
||||
private final int nsize0;
|
||||
// the internal graph representation where the first dimension is node id and second dimension is
|
||||
// level
|
||||
// e.g. graph[1][2] is all the neighbours of node 1 at level 2
|
||||
private NeighborArray[][] graph;
|
||||
// essentially another 2d map which the first dimension is level and second dimension is node id,
|
||||
// this is only
|
||||
// generated on demand when someone calls getNodesOnLevel on a non-zero level
|
||||
private List<Integer>[] levelToNodes;
|
||||
private int
|
||||
lastFreezeSize; // remember the size we are at last time to freeze the graph and generate
|
||||
// levelToNodes
|
||||
private int size; // graph size, which is number of nodes in level 0
|
||||
private int
|
||||
nonZeroLevelSize; // total number of NeighborArrays created that is not on level 0, for now it
|
||||
// is only used to account memory usage
|
||||
private int maxNodeId;
|
||||
private final int nsize; // neighbour array size at non-zero level
|
||||
private final int nsize0; // neighbour array size at zero level
|
||||
private final boolean
|
||||
noGrowth; // if an initial size is passed in, we don't expect the graph to grow itself
|
||||
|
||||
// KnnGraphValues iterator members
|
||||
private int upto;
|
||||
private NeighborArray cur;
|
||||
|
||||
OnHeapHnswGraph(int M) {
|
||||
/**
|
||||
* ctor
|
||||
*
|
||||
* @param numNodes number of nodes that will be added to this graph, passing in -1 means unbounded
|
||||
* while passing in a non-negative value will lock the whole graph and disable the graph from
|
||||
* growing itself (you cannot add a node whose id is >= numNodes)
|
||||
*/
|
||||
OnHeapHnswGraph(int M, int numNodes) {
|
||||
this.numLevels = 1; // Implicitly start the graph with a single level
|
||||
this.graphLevel0 = new ArrayList<>();
|
||||
this.entryNode = -1; // Entry node should be negative until a node is added
|
||||
// Neighbours' size on upper levels (nsize) and level 0 (nsize0)
|
||||
// We allocate extra space for neighbours, but then prune them to keep allowed maximum
|
||||
this.maxNodeId = -1;
|
||||
this.nsize = M + 1;
|
||||
this.nsize0 = (M * 2 + 1);
|
||||
|
||||
this.graphUpperLevels = new ArrayList<>(numLevels);
|
||||
graphUpperLevels.add(null); // we don't need this for 0th level, as it contains all nodes
|
||||
noGrowth = numNodes != -1;
|
||||
if (noGrowth == false) {
|
||||
numNodes = INIT_SIZE;
|
||||
}
|
||||
this.graph = new NeighborArray[numNodes][];
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -74,23 +90,32 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
|
||||
*/
|
||||
public NeighborArray getNeighbors(int level, int node) {
|
||||
if (level == 0) {
|
||||
return graphLevel0.get(node);
|
||||
}
|
||||
Map<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
|
||||
assert levelMap.containsKey(node);
|
||||
return levelMap.get(node);
|
||||
assert graph[node][level] != null;
|
||||
return graph[node][level];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return graphLevel0.size(); // all nodes are located on the 0th level
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
|
||||
* When we initialize from another graph, the max node id is different from {@link #size()},
|
||||
* because we will add nodes out of order, so we need a separate method for each
|
||||
*
|
||||
* @return max node id (inclusive)
|
||||
*/
|
||||
@Override
|
||||
public int maxNodeId() {
|
||||
return maxNodeId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add node on the given level. Nodes can be inserted out of order, but it requires that the nodes
|
||||
* preceded by the node inserted out of order are eventually added.
|
||||
*
|
||||
* <p>NOTE: You must add a node starting from the node's top level
|
||||
*
|
||||
* @param level level to add a node on
|
||||
* @param node the node to add, represented as an ordinal on the level 0.
|
||||
*/
|
||||
|
@ -99,28 +124,33 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
entryNode = node;
|
||||
}
|
||||
|
||||
if (level > 0) {
|
||||
// if the new node introduces a new level, add more levels to the graph,
|
||||
// and make this node the graph's new entry point
|
||||
if (level >= numLevels) {
|
||||
for (int i = numLevels; i <= level; i++) {
|
||||
graphUpperLevels.add(new HashMap<>());
|
||||
}
|
||||
numLevels = level + 1;
|
||||
entryNode = node;
|
||||
}
|
||||
|
||||
graphUpperLevels.get(level).put(node, new NeighborArray(nsize, true));
|
||||
} else {
|
||||
// Add nodes all the way up to and including "node" in the new graph on level 0. This will
|
||||
// cause the size of the
|
||||
// graph to differ from the number of nodes added to the graph. The size of the graph and the
|
||||
// number of nodes
|
||||
// added will only be in sync once all nodes from 0...last_node are added into the graph.
|
||||
while (node >= graphLevel0.size()) {
|
||||
graphLevel0.add(new NeighborArray(nsize0, true));
|
||||
if (node >= graph.length) {
|
||||
if (noGrowth) {
|
||||
throw new IllegalStateException(
|
||||
"The graph does not expect to grow when an initial size is given");
|
||||
}
|
||||
graph = ArrayUtil.grow(graph, node + 1);
|
||||
}
|
||||
|
||||
if (level >= numLevels) {
|
||||
numLevels = level + 1;
|
||||
entryNode = node;
|
||||
}
|
||||
|
||||
assert graph[node] == null || graph[node].length > level
|
||||
: "node must be inserted from the top level";
|
||||
if (graph[node] == null) {
|
||||
graph[node] =
|
||||
new NeighborArray[level + 1]; // assumption: we always call this function from top level
|
||||
size++;
|
||||
}
|
||||
if (level == 0) {
|
||||
graph[node][level] = new NeighborArray(nsize0, true);
|
||||
} else {
|
||||
graph[node][level] = new NeighborArray(nsize, true);
|
||||
nonZeroLevelSize++;
|
||||
}
|
||||
maxNodeId = Math.max(maxNodeId, node);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -158,50 +188,83 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
return entryNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* WARN: calling this method will essentially iterate through all nodes at level 0 (even if you're
|
||||
* not getting node at level 0), we have built some caching mechanism such that if graph is not
|
||||
* changed only the first non-zero level call will pay the cost. So it is highly NOT recommended
|
||||
* to call this method while the graph is still building.
|
||||
*
|
||||
* <p>NOTE: calling this method while the graph is still building is prohibited
|
||||
*/
|
||||
@Override
|
||||
public NodesIterator getNodesOnLevel(int level) {
|
||||
if (size() != maxNodeId() + 1) {
|
||||
throw new IllegalStateException(
|
||||
"graph build not complete, size=" + size() + " maxNodeId=" + maxNodeId());
|
||||
}
|
||||
if (level == 0) {
|
||||
return new ArrayNodesIterator(size());
|
||||
} else {
|
||||
return new CollectionNodesIterator(graphUpperLevels.get(level).keySet());
|
||||
generateLevelToNodes();
|
||||
return new CollectionNodesIterator(levelToNodes[level]);
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings({"unchecked", "rawtypes"})
|
||||
private void generateLevelToNodes() {
|
||||
if (lastFreezeSize == size) {
|
||||
return;
|
||||
}
|
||||
|
||||
levelToNodes = new List[numLevels];
|
||||
for (int i = 1; i < numLevels; i++) {
|
||||
levelToNodes[i] = new ArrayList<>();
|
||||
}
|
||||
int nonNullNode = 0;
|
||||
for (int node = 0; node < graph.length; node++) {
|
||||
// when we init from another graph, we could have holes where some slot is null
|
||||
if (graph[node] == null) {
|
||||
continue;
|
||||
}
|
||||
nonNullNode++;
|
||||
for (int i = 1; i < graph[node].length; i++) {
|
||||
levelToNodes[i].add(node);
|
||||
}
|
||||
if (nonNullNode == size) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
lastFreezeSize = size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long neighborArrayBytes0 =
|
||||
nsize0 * (Integer.BYTES + Float.BYTES)
|
||||
(long) nsize0 * (Integer.BYTES + Float.BYTES)
|
||||
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
|
||||
+ Integer.BYTES * 3;
|
||||
long neighborArrayBytes =
|
||||
nsize * (Integer.BYTES + Float.BYTES)
|
||||
(long) nsize * (Integer.BYTES + Float.BYTES)
|
||||
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
|
||||
+ Integer.BYTES * 3;
|
||||
long total = 0;
|
||||
for (int l = 0; l < numLevels; l++) {
|
||||
if (l == 0) {
|
||||
total +=
|
||||
graphLevel0.size() * neighborArrayBytes0
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
|
||||
} else {
|
||||
long numNodesOnLevel = graphUpperLevels.get(l).size();
|
||||
|
||||
// For levels > 0, we represent the graph structure with a tree map.
|
||||
// A single node in the tree contains 3 references (left root, right root, value) as well
|
||||
// as an Integer for the key and 1 extra byte for the color of the node (this is actually 1
|
||||
// bit, but
|
||||
// because we do not have that granularity, we set to 1 byte). In addition, we include 1
|
||||
// more reference for
|
||||
// the tree map itself.
|
||||
total +=
|
||||
numNodesOnLevel * (3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + Integer.BYTES + 1)
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF;
|
||||
|
||||
// Add the size neighbor of each node
|
||||
total += numNodesOnLevel * neighborArrayBytes;
|
||||
}
|
||||
total +=
|
||||
size * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // for graph and level 0;
|
||||
total += nonZeroLevelSize * neighborArrayBytes; // for non-zero level
|
||||
total += 8 * Integer.BYTES; // all int fields
|
||||
total += RamUsageEstimator.NUM_BYTES_OBJECT_REF; // field: cur
|
||||
total += RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // field: levelToNodes
|
||||
if (levelToNodes != null) {
|
||||
total +=
|
||||
(long) (numLevels - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0
|
||||
total +=
|
||||
(long) nonZeroLevelSize
|
||||
* (RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
|
||||
+ Integer.BYTES);
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
|
|
@ -195,7 +195,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
// test that sorted index returns the same search results are unsorted
|
||||
// test that sorted index returns the same search results as unsorted
|
||||
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
|
||||
int dim = random().nextInt(10) + 3;
|
||||
int nDoc = random().nextInt(200) + 100;
|
||||
|
@ -454,77 +454,6 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException {
|
||||
int maxNumLevels = randomIntBetween(2, 10);
|
||||
int nodeCount = randomIntBetween(1, 100);
|
||||
|
||||
List<List<Integer>> nodesPerLevel = new ArrayList<>();
|
||||
for (int i = 0; i < maxNumLevels; i++) {
|
||||
nodesPerLevel.add(new ArrayList<>());
|
||||
}
|
||||
|
||||
int numLevels = 0;
|
||||
for (int currNode = 0; currNode < nodeCount; currNode++) {
|
||||
int nodeMaxLevel = random().nextInt(1, maxNumLevels + 1);
|
||||
numLevels = Math.max(numLevels, nodeMaxLevel);
|
||||
for (int currLevel = 0; currLevel < nodeMaxLevel; currLevel++) {
|
||||
nodesPerLevel.get(currLevel).add(currNode);
|
||||
}
|
||||
}
|
||||
|
||||
OnHeapHnswGraph topDownOrderReversedHnsw = new OnHeapHnswGraph(10);
|
||||
for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
|
||||
List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
|
||||
int currLevelNodesSize = currLevelNodes.size();
|
||||
for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
|
||||
topDownOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
|
||||
}
|
||||
}
|
||||
|
||||
OnHeapHnswGraph bottomUpOrderReversedHnsw = new OnHeapHnswGraph(10);
|
||||
for (int currLevel = 0; currLevel < numLevels; currLevel++) {
|
||||
List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
|
||||
int currLevelNodesSize = currLevelNodes.size();
|
||||
for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
|
||||
bottomUpOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
|
||||
}
|
||||
}
|
||||
|
||||
OnHeapHnswGraph topDownOrderRandomHnsw = new OnHeapHnswGraph(10);
|
||||
for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
|
||||
List<Integer> currLevelNodes = new ArrayList<>(nodesPerLevel.get(currLevel));
|
||||
Collections.shuffle(currLevelNodes, random());
|
||||
for (Integer currNode : currLevelNodes) {
|
||||
topDownOrderRandomHnsw.addNode(currLevel, currNode);
|
||||
}
|
||||
}
|
||||
|
||||
OnHeapHnswGraph bottomUpExpectedHnsw = new OnHeapHnswGraph(10);
|
||||
for (int currLevel = 0; currLevel < numLevels; currLevel++) {
|
||||
for (Integer currNode : nodesPerLevel.get(currLevel)) {
|
||||
bottomUpExpectedHnsw.addNode(currLevel, currNode);
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(nodeCount, bottomUpExpectedHnsw.getNodesOnLevel(0).size());
|
||||
for (Integer node : nodesPerLevel.get(0)) {
|
||||
assertEquals(0, bottomUpExpectedHnsw.getNeighbors(0, node).size());
|
||||
}
|
||||
|
||||
for (int currLevel = 1; currLevel < numLevels; currLevel++) {
|
||||
List<Integer> expectedNodesOnLevel = nodesPerLevel.get(currLevel);
|
||||
List<Integer> sortedNodes = sortedNodesOnLevel(bottomUpExpectedHnsw, currLevel);
|
||||
assertEquals(
|
||||
String.format(Locale.ROOT, "Nodes on level %d do not match", currLevel),
|
||||
expectedNodesOnLevel,
|
||||
sortedNodes);
|
||||
}
|
||||
|
||||
assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw);
|
||||
assertGraphEqual(bottomUpExpectedHnsw, bottomUpOrderReversedHnsw);
|
||||
assertGraphEqual(bottomUpExpectedHnsw, topDownOrderRandomHnsw);
|
||||
}
|
||||
|
||||
public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws IOException {
|
||||
int totalSize = atLeast(100);
|
||||
int initializerSize = random().nextInt(5, totalSize);
|
||||
|
@ -547,7 +476,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
|
||||
HnswGraphBuilder finalBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap);
|
||||
finalscorerSupplier,
|
||||
10,
|
||||
30,
|
||||
seed,
|
||||
initializerGraph,
|
||||
initializerOrdMap,
|
||||
finalVectorValues.size());
|
||||
|
||||
// When offset is 0, the graphs should be identical before vectors are added
|
||||
assertGraphEqual(initializerGraph, finalBuilder.getGraph());
|
||||
|
@ -577,7 +512,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
|
||||
HnswGraphBuilder finalBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap);
|
||||
finalscorerSupplier,
|
||||
10,
|
||||
30,
|
||||
seed,
|
||||
initializerGraph,
|
||||
initializerOrdMap,
|
||||
finalVectorValues.size());
|
||||
|
||||
assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap);
|
||||
|
||||
|
@ -599,35 +540,30 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
|
||||
private void assertGraphInitializedFromGraph(
|
||||
HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
|
||||
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
|
||||
HnswGraph g, HnswGraph initializer, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
|
||||
assertEquals(
|
||||
"the number of levels in the graphs are different!",
|
||||
initializer.numLevels(),
|
||||
g.numLevels());
|
||||
// Confirm that the size of the new graph includes all nodes up to and including the max new
|
||||
// ordinal in the old to
|
||||
// new ordinal mapping
|
||||
assertEquals(
|
||||
"the number of nodes in the graphs are different!",
|
||||
g.size(),
|
||||
Collections.max(oldToNewOrdMap.values()) + 1);
|
||||
assertEquals("the number of nodes in the graphs are different!", initializer.size(), g.size());
|
||||
|
||||
// assert the nodes from the previous graph are successfully to levels > 0 in the new graph
|
||||
for (int level = 1; level < g.numLevels(); level++) {
|
||||
List<Integer> nodesOnLevel = sortedNodesOnLevel(g, level);
|
||||
List<Integer> nodesOnLevel2 =
|
||||
sortedNodesOnLevel(h, level).stream().map(oldToNewOrdMap::get).toList();
|
||||
assertEquals(nodesOnLevel, nodesOnLevel2);
|
||||
}
|
||||
|
||||
// assert that the neighbors from the old graph are successfully transferred to the new graph
|
||||
// assert that all the nodes from the initializer graph can be found in the new graph and
|
||||
// the neighbors from the old graph are successfully transferred to the new graph
|
||||
for (int level = 0; level < g.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = h.getNodesOnLevel(level);
|
||||
NodesIterator nodesOnLevel = initializer.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
g.seek(level, oldToNewOrdMap.get(node));
|
||||
h.seek(level, node);
|
||||
initializer.seek(level, node);
|
||||
assertEquals(
|
||||
"arcs differ for node " + node,
|
||||
getNeighborNodes(g),
|
||||
getNeighborNodes(h).stream().map(oldToNewOrdMap::get).collect(Collectors.toSet()));
|
||||
getNeighborNodes(initializer).stream()
|
||||
.map(oldToNewOrdMap::get)
|
||||
.collect(Collectors.toSet()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
|
||||
/**
|
||||
* Tests OnHeapHnswGraph's behavior specifically; for more complex tests, see {@link
|
||||
* HnswGraphTestCase}
|
||||
*/
|
||||
public class TestOnHeapHnswGraph extends LuceneTestCase {
|
||||
|
||||
/* assert exception will be thrown when we add out of bound node to a fixed size graph */
|
||||
public void testNoGrowth() {
|
||||
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 100);
|
||||
expectThrows(IllegalStateException.class, () -> graph.addNode(1, 100));
|
||||
}
|
||||
|
||||
/* AssertionError will be thrown if we add a node not from top most level,
|
||||
(likely NPE will be thrown in prod) */
|
||||
public void testAddLevelOutOfOrder() {
|
||||
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
|
||||
graph.addNode(0, 0);
|
||||
expectThrows(AssertionError.class, () -> graph.addNode(1, 0));
|
||||
}
|
||||
|
||||
/* assert exception will be thrown when we call getNodesOnLevel for an incomplete graph */
|
||||
public void testIncompleteGraphThrow() {
|
||||
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 10);
|
||||
graph.addNode(1, 0);
|
||||
graph.addNode(0, 0);
|
||||
assertEquals(1, graph.getNodesOnLevel(1).size());
|
||||
graph.addNode(0, 5);
|
||||
expectThrows(IllegalStateException.class, () -> graph.getNodesOnLevel(0));
|
||||
}
|
||||
|
||||
public void testGraphGrowth() {
|
||||
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
|
||||
List<List<Integer>> levelToNodes = new ArrayList<>();
|
||||
int maxLevel = 5;
|
||||
for (int i = 0; i < maxLevel; i++) {
|
||||
levelToNodes.add(new ArrayList<>());
|
||||
}
|
||||
for (int i = 0; i < 101; i++) {
|
||||
int level = random().nextInt(maxLevel);
|
||||
for (int l = level; l >= 0; l--) {
|
||||
graph.addNode(l, i);
|
||||
levelToNodes.get(l).add(i);
|
||||
}
|
||||
}
|
||||
assertGraphEquals(graph, levelToNodes);
|
||||
}
|
||||
|
||||
public void testGraphBuildOutOfOrder() {
|
||||
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
|
||||
List<List<Integer>> levelToNodes = new ArrayList<>();
|
||||
int maxLevel = 5;
|
||||
int numNodes = 100;
|
||||
for (int i = 0; i < maxLevel; i++) {
|
||||
levelToNodes.add(new ArrayList<>());
|
||||
}
|
||||
int[] insertions = new int[numNodes];
|
||||
for (int i = 0; i < numNodes; i++) {
|
||||
insertions[i] = i;
|
||||
}
|
||||
// shuffle the insertion order
|
||||
for (int i = 0; i < 40; i++) {
|
||||
int pos1 = random().nextInt(numNodes);
|
||||
int pos2 = random().nextInt(numNodes);
|
||||
int tmp = insertions[pos1];
|
||||
insertions[pos1] = insertions[pos2];
|
||||
insertions[pos2] = tmp;
|
||||
}
|
||||
|
||||
for (int i : insertions) {
|
||||
int level = random().nextInt(maxLevel);
|
||||
for (int l = level; l >= 0; l--) {
|
||||
graph.addNode(l, i);
|
||||
levelToNodes.get(l).add(i);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < maxLevel; i++) {
|
||||
levelToNodes.get(i).sort(Integer::compare);
|
||||
}
|
||||
|
||||
assertGraphEquals(graph, levelToNodes);
|
||||
}
|
||||
|
||||
private static void assertGraphEquals(OnHeapHnswGraph graph, List<List<Integer>> levelToNodes) {
|
||||
for (int l = 0; l < graph.numLevels(); l++) {
|
||||
HnswGraph.NodesIterator nodesIterator = graph.getNodesOnLevel(l);
|
||||
assertEquals(levelToNodes.get(l).size(), nodesIterator.size());
|
||||
int idx = 0;
|
||||
while (nodesIterator.hasNext()) {
|
||||
assertEquals(levelToNodes.get(l).get(idx++), nodesIterator.next());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue