Use HashMap (was TreeMap) for OnHeapHnswGraph neighbors

This commit is contained in:
Jonathan Ellis 2023-04-27 10:50:22 -05:00 committed by Michael Sokolov
parent 1fa2be90ea
commit 3c163745bb
7 changed files with 146 additions and 64 deletions

View File

@ -25,6 +25,7 @@ import java.nio.ByteOrder;
import java.util.Arrays;
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
@ -36,7 +37,6 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
@ -227,11 +227,10 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
} else {
meta.writeInt(graph.numLevels());
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
meta.writeInt(sortedNodes.length); // number of nodes on a level
if (level > 0) {
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
for (int node : sortedNodes) {
meta.writeInt(node); // list of nodes on a level
}
}
@ -257,9 +256,8 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
for (int node : sortedNodes) {
Lucene91NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
vectorIndex.writeInt(size);

View File

@ -27,6 +27,7 @@ import java.util.Arrays;
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
@ -39,7 +40,6 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
@ -261,11 +261,10 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
} else {
meta.writeInt(graph.numLevels());
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
meta.writeInt(sortedNodes.length); // number of nodes on a level
if (level > 0) {
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
for (int node : sortedNodes) {
meta.writeInt(node); // list of nodes on a level
}
}
@ -293,9 +292,8 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
int countOnLevel0 = graph.size();
for (int level = 0; level < graph.numLevels(); level++) {
int maxConnOnLevel = level == 0 ? (M * 2) : M;
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
for (int node : sortedNodes) {
NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
vectorIndex.writeInt(size);

View File

@ -30,6 +30,7 @@ import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
@ -303,9 +304,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
for (int level = 1; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
int[] newNodes = new int[nodesOnLevel.size()];
int n = 0;
while (nodesOnLevel.hasNext()) {
newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()];
for (int n = 0; nodesOnLevel.hasNext(); n++) {
newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()];
}
Arrays.sort(newNodes);
nodesByLevel.add(newNodes);
@ -481,9 +481,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
int countOnLevel0 = graph.size();
for (int level = 0; level < graph.numLevels(); level++) {
int maxConnOnLevel = level == 0 ? (M * 2) : M;
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
for (int node : sortedNodes) {
NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
vectorIndex.writeInt(size);
@ -570,11 +569,10 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
} else {
meta.writeInt(graph.numLevels());
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
meta.writeInt(sortedNodes.length); // number of nodes on a level
if (level > 0) {
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
for (int node : sortedNodes) {
meta.writeInt(node); // list of nodes on a level
}
}

View File

@ -315,9 +315,8 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
for (int level = 1; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
int[] newNodes = new int[nodesOnLevel.size()];
int n = 0;
while (nodesOnLevel.hasNext()) {
newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()];
for (int n = 0; nodesOnLevel.hasNext(); n++) {
newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()];
}
Arrays.sort(newNodes);
nodesByLevel.add(newNodes);
@ -677,11 +676,10 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
int countOnLevel0 = graph.size();
int[][] offsets = new int[graph.numLevels()][];
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
offsets[level] = new int[nodesOnLevel.size()];
int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level));
offsets[level] = new int[sortedNodes.length];
int nodeOffsetId = 0;
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
for (int node : sortedNodes) {
NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
// Write size in VInt as the neighbors list is typically small
@ -706,6 +704,15 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
return offsets;
}
public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
int[] sortedNodes = new int[nodesOnLevel.size()];
for (int n = 0; nodesOnLevel.hasNext(); n++) {
sortedNodes[n] = nodesOnLevel.nextInt();
}
Arrays.sort(sortedNodes);
return sortedNodes;
}
private void writeMeta(
FieldInfo field,
int maxDoc,
@ -779,6 +786,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
if (level > 0) {
int[] nol = new int[nodesOnLevel.size()];
int numberConsumed = nodesOnLevel.consume(nol);
Arrays.sort(nol);
assert numberConsumed == nodesOnLevel.size();
meta.writeVInt(nol.length); // number of nodes on a level
for (int i = nodesOnLevel.size() - 1; i > 0; --i) {

View File

@ -81,7 +81,8 @@ public abstract class HnswGraph {
public abstract int entryNode() throws IOException;
/**
* Get all nodes on a given level as node 0th ordinals
* Get all nodes on a given level as node 0th ordinals. The nodes are NOT guaranteed to be
* presented in any particular order.
*
* @param level level for which to get all nodes
* @return an iterator over nodes where {@code nextInt} returns a next node on the level
@ -123,7 +124,8 @@ public abstract class HnswGraph {
/**
* Iterator over the graph nodes on a certain level, Iterator also provides the size the total
* number of nodes to be iterated over.
* number of nodes to be iterated over. The nodes are NOT guaranteed to be presented in any
* particular order.
*/
public abstract static class NodesIterator implements PrimitiveIterator.OfInt {
protected final int size;

View File

@ -20,8 +20,9 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.TreeMap;
import java.util.Map;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
@ -40,12 +41,12 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
// added to HnswBuilder, and the node values are the ordinals of those vectors.
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
private final List<NeighborArray> graphLevel0;
// Represents levels 1-N. Each level is represented with a TreeMap that maps a levels level 0
// Represents levels 1-N. Each level is represented with a Map that maps a levels level 0
// ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain
// it in this list. However, to avoid changing list indexing, we always will make the first
// element
// null.
private final List<TreeMap<Integer, NeighborArray>> graphUpperLevels;
private final List<Map<Integer, NeighborArray>> graphUpperLevels;
private final int nsize;
private final int nsize0;
@ -76,7 +77,7 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
if (level == 0) {
return graphLevel0.get(node);
}
TreeMap<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
Map<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
assert levelMap.containsKey(node);
return levelMap.get(node);
}
@ -103,7 +104,7 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
// and make this node the graph's new entry point
if (level >= numLevels) {
for (int i = numLevels; i <= level; i++) {
graphUpperLevels.add(new TreeMap<>());
graphUpperLevels.add(new HashMap<>());
}
numLevels = level + 1;
entryNode = node;
@ -204,4 +205,15 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
}
return total;
}
@Override
public String toString() {
return "OnHeapHnswGraph(size="
+ size()
+ ", numLevels="
+ numLevels
+ ", entryNode="
+ entryNode
+ ")";
}
}

View File

@ -29,6 +29,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
@ -265,19 +266,50 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
}
List<Integer> sortedNodesOnLevel(HnswGraph h, int level) throws IOException {
NodesIterator nodesOnLevel = h.getNodesOnLevel(level);
List<Integer> nodes = new ArrayList<>();
while (nodesOnLevel.hasNext()) {
nodes.add(nodesOnLevel.next());
}
Collections.sort(nodes);
return nodes;
}
void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
// construct these up front since they call seek which will mess up our test loop
String prettyG = prettyPrint(g);
String prettyH = prettyPrint(h);
assertEquals(
String.format(
Locale.ROOT,
"the number of levels in the graphs are different:%n%s%n%s",
prettyG,
prettyH),
g.numLevels(),
h.numLevels());
assertEquals(
String.format(
Locale.ROOT,
"the number of nodes in the graphs are different:%n%s%n%s",
prettyG,
prettyH),
g.size(),
h.size());
// assert equal nodes on each level
for (int level = 0; level < g.numLevels(); level++) {
NodesIterator nodesOnLevel = g.getNodesOnLevel(level);
NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level);
while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) {
int node = nodesOnLevel.nextInt();
int node2 = nodesOnLevel2.nextInt();
assertEquals("nodes in the graphs are different", node, node2);
}
List<Integer> hNodes = sortedNodesOnLevel(h, level);
List<Integer> gNodes = sortedNodesOnLevel(g, level);
assertEquals(
String.format(
Locale.ROOT,
"nodes in the graphs are different on level %d:%n%s%n%s",
level,
prettyG,
prettyH),
gNodes,
hNodes);
}
// assert equal nodes' neighbours on each level
@ -287,7 +319,16 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int node = nodesOnLevel.nextInt();
g.seek(level, node);
h.seek(level, node);
assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h));
assertEquals(
String.format(
Locale.ROOT,
"arcs differ for node %d on level %d:%n%s%n%s",
node,
level,
prettyG,
prettyH),
getNeighborNodes(g),
getNeighborNodes(h));
}
}
}
@ -495,14 +536,12 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
for (int currLevel = 1; currLevel < numLevels; currLevel++) {
NodesIterator nodesIterator = bottomUpExpectedHnsw.getNodesOnLevel(currLevel);
List<Integer> expectedNodesOnLevel = nodesPerLevel.get(currLevel);
assertEquals(expectedNodesOnLevel.size(), nodesIterator.size());
for (Integer expectedNode : expectedNodesOnLevel) {
int currentNode = nodesIterator.nextInt();
assertEquals(expectedNode.intValue(), currentNode);
assertEquals(0, bottomUpExpectedHnsw.getNeighbors(currLevel, currentNode).size());
}
List<Integer> sortedNodes = sortedNodesOnLevel(bottomUpExpectedHnsw, currLevel);
assertEquals(
String.format(Locale.ROOT, "Nodes on level %d do not match", currLevel),
expectedNodesOnLevel,
sortedNodes);
}
assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw);
@ -607,13 +646,10 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
// assert the nodes from the previous graph are successfully to levels > 0 in the new graph
for (int level = 1; level < g.numLevels(); level++) {
NodesIterator nodesOnLevel = g.getNodesOnLevel(level);
NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level);
while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) {
int node = nodesOnLevel.nextInt();
int node2 = oldToNewOrdMap.get(nodesOnLevel2.nextInt());
assertEquals("nodes in the graphs are different", node, node2);
}
List<Integer> nodesOnLevel = sortedNodesOnLevel(g, level);
List<Integer> nodesOnLevel2 =
sortedNodesOnLevel(h, level).stream().map(oldToNewOrdMap::get).toList();
assertEquals(nodesOnLevel, nodesOnLevel2);
}
// assert that the neighbors from the old graph are successfully transferred to the new graph
@ -1196,4 +1232,34 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
return bvec;
}
static String prettyPrint(HnswGraph hnsw) {
StringBuilder sb = new StringBuilder();
sb.append(hnsw);
sb.append("\n");
try {
for (int level = 0; level < hnsw.numLevels(); level++) {
sb.append("# Level ").append(level).append("\n");
NodesIterator it = hnsw.getNodesOnLevel(level);
while (it.hasNext()) {
int node = it.nextInt();
sb.append(" ").append(node).append(" -> ");
hnsw.seek(level, node);
while (true) {
int neighbor = hnsw.nextNeighbor();
if (neighbor == NO_MORE_DOCS) {
break;
}
sb.append(" ").append(neighbor);
}
sb.append("\n");
}
}
} catch (IOException e) {
throw new RuntimeException(e);
}
return sb.toString();
}
}