mirror of https://github.com/apache/lucene.git
Use HashMap (was TreeMap) for OnHeapHnswGraph neighbors
This commit is contained in:
parent
1fa2be90ea
commit
3c163745bb
|
@ -25,6 +25,7 @@ import java.nio.ByteOrder;
|
|||
import java.util.Arrays;
|
||||
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
|
@ -36,7 +37,6 @@ import org.apache.lucene.search.DocIdSetIterator;
|
|||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/**
|
||||
|
@ -227,11 +227,10 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
} else {
|
||||
meta.writeInt(graph.numLevels());
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
meta.writeInt(sortedNodes.length); // number of nodes on a level
|
||||
if (level > 0) {
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
for (int node : sortedNodes) {
|
||||
meta.writeInt(node); // list of nodes on a level
|
||||
}
|
||||
}
|
||||
|
@ -257,9 +256,8 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
// write vectors' neighbours on each level into the vectorIndex file
|
||||
int countOnLevel0 = graph.size();
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
for (int node : sortedNodes) {
|
||||
Lucene91NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
vectorIndex.writeInt(size);
|
||||
|
|
|
@ -27,6 +27,7 @@ import java.util.Arrays;
|
|||
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
|
@ -39,7 +40,6 @@ import org.apache.lucene.search.DocIdSetIterator;
|
|||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||
|
@ -261,11 +261,10 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
} else {
|
||||
meta.writeInt(graph.numLevels());
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
meta.writeInt(sortedNodes.length); // number of nodes on a level
|
||||
if (level > 0) {
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
for (int node : sortedNodes) {
|
||||
meta.writeInt(node); // list of nodes on a level
|
||||
}
|
||||
}
|
||||
|
@ -293,9 +292,8 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
int countOnLevel0 = graph.size();
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int maxConnOnLevel = level == 0 ? (M * 2) : M;
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
for (int node : sortedNodes) {
|
||||
NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
vectorIndex.writeInt(size);
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.lucene.codecs.CodecUtil;
|
|||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
|
@ -303,9 +304,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
for (int level = 1; level < graph.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
int[] newNodes = new int[nodesOnLevel.size()];
|
||||
int n = 0;
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()];
|
||||
for (int n = 0; nodesOnLevel.hasNext(); n++) {
|
||||
newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()];
|
||||
}
|
||||
Arrays.sort(newNodes);
|
||||
nodesByLevel.add(newNodes);
|
||||
|
@ -481,9 +481,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
int countOnLevel0 = graph.size();
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int maxConnOnLevel = level == 0 ? (M * 2) : M;
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
for (int node : sortedNodes) {
|
||||
NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
vectorIndex.writeInt(size);
|
||||
|
@ -570,11 +569,10 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
} else {
|
||||
meta.writeInt(graph.numLevels());
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
|
||||
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
|
||||
meta.writeInt(sortedNodes.length); // number of nodes on a level
|
||||
if (level > 0) {
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
for (int node : sortedNodes) {
|
||||
meta.writeInt(node); // list of nodes on a level
|
||||
}
|
||||
}
|
||||
|
|
|
@ -315,9 +315,8 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
for (int level = 1; level < graph.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
int[] newNodes = new int[nodesOnLevel.size()];
|
||||
int n = 0;
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()];
|
||||
for (int n = 0; nodesOnLevel.hasNext(); n++) {
|
||||
newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()];
|
||||
}
|
||||
Arrays.sort(newNodes);
|
||||
nodesByLevel.add(newNodes);
|
||||
|
@ -677,11 +676,10 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
int countOnLevel0 = graph.size();
|
||||
int[][] offsets = new int[graph.numLevels()][];
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
offsets[level] = new int[nodesOnLevel.size()];
|
||||
int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level));
|
||||
offsets[level] = new int[sortedNodes.length];
|
||||
int nodeOffsetId = 0;
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
for (int node : sortedNodes) {
|
||||
NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
// Write size in VInt as the neighbors list is typically small
|
||||
|
@ -706,6 +704,15 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
return offsets;
|
||||
}
|
||||
|
||||
public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
|
||||
int[] sortedNodes = new int[nodesOnLevel.size()];
|
||||
for (int n = 0; nodesOnLevel.hasNext(); n++) {
|
||||
sortedNodes[n] = nodesOnLevel.nextInt();
|
||||
}
|
||||
Arrays.sort(sortedNodes);
|
||||
return sortedNodes;
|
||||
}
|
||||
|
||||
private void writeMeta(
|
||||
FieldInfo field,
|
||||
int maxDoc,
|
||||
|
@ -779,6 +786,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
if (level > 0) {
|
||||
int[] nol = new int[nodesOnLevel.size()];
|
||||
int numberConsumed = nodesOnLevel.consume(nol);
|
||||
Arrays.sort(nol);
|
||||
assert numberConsumed == nodesOnLevel.size();
|
||||
meta.writeVInt(nol.length); // number of nodes on a level
|
||||
for (int i = nodesOnLevel.size() - 1; i > 0; --i) {
|
||||
|
|
|
@ -81,7 +81,8 @@ public abstract class HnswGraph {
|
|||
public abstract int entryNode() throws IOException;
|
||||
|
||||
/**
|
||||
* Get all nodes on a given level as node 0th ordinals
|
||||
* Get all nodes on a given level as node 0th ordinals. The nodes are NOT guaranteed to be
|
||||
* presented in any particular order.
|
||||
*
|
||||
* @param level level for which to get all nodes
|
||||
* @return an iterator over nodes where {@code nextInt} returns a next node on the level
|
||||
|
@ -123,7 +124,8 @@ public abstract class HnswGraph {
|
|||
|
||||
/**
|
||||
* Iterator over the graph nodes on a certain level, Iterator also provides the size – the total
|
||||
* number of nodes to be iterated over.
|
||||
* number of nodes to be iterated over. The nodes are NOT guaranteed to be presented in any
|
||||
* particular order.
|
||||
*/
|
||||
public abstract static class NodesIterator implements PrimitiveIterator.OfInt {
|
||||
protected final int size;
|
||||
|
|
|
@ -20,8 +20,9 @@ package org.apache.lucene.util.hnsw;
|
|||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.TreeMap;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
||||
|
@ -40,12 +41,12 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
// added to HnswBuilder, and the node values are the ordinals of those vectors.
|
||||
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
|
||||
private final List<NeighborArray> graphLevel0;
|
||||
// Represents levels 1-N. Each level is represented with a TreeMap that maps a levels level 0
|
||||
// Represents levels 1-N. Each level is represented with a Map that maps a levels level 0
|
||||
// ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain
|
||||
// it in this list. However, to avoid changing list indexing, we always will make the first
|
||||
// element
|
||||
// null.
|
||||
private final List<TreeMap<Integer, NeighborArray>> graphUpperLevels;
|
||||
private final List<Map<Integer, NeighborArray>> graphUpperLevels;
|
||||
private final int nsize;
|
||||
private final int nsize0;
|
||||
|
||||
|
@ -76,7 +77,7 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
if (level == 0) {
|
||||
return graphLevel0.get(node);
|
||||
}
|
||||
TreeMap<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
|
||||
Map<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
|
||||
assert levelMap.containsKey(node);
|
||||
return levelMap.get(node);
|
||||
}
|
||||
|
@ -103,7 +104,7 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
// and make this node the graph's new entry point
|
||||
if (level >= numLevels) {
|
||||
for (int i = numLevels; i <= level; i++) {
|
||||
graphUpperLevels.add(new TreeMap<>());
|
||||
graphUpperLevels.add(new HashMap<>());
|
||||
}
|
||||
numLevels = level + 1;
|
||||
entryNode = node;
|
||||
|
@ -204,4 +205,15 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "OnHeapHnswGraph(size="
|
||||
+ size()
|
||||
+ ", numLevels="
|
||||
+ numLevels
|
||||
+ ", entryNode="
|
||||
+ entryNode
|
||||
+ ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import java.util.Collections;
|
|||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
|
@ -265,19 +266,50 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
List<Integer> sortedNodesOnLevel(HnswGraph h, int level) throws IOException {
|
||||
NodesIterator nodesOnLevel = h.getNodesOnLevel(level);
|
||||
List<Integer> nodes = new ArrayList<>();
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
nodes.add(nodesOnLevel.next());
|
||||
}
|
||||
Collections.sort(nodes);
|
||||
return nodes;
|
||||
}
|
||||
|
||||
void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
|
||||
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
|
||||
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
|
||||
// construct these up front since they call seek which will mess up our test loop
|
||||
String prettyG = prettyPrint(g);
|
||||
String prettyH = prettyPrint(h);
|
||||
assertEquals(
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"the number of levels in the graphs are different:%n%s%n%s",
|
||||
prettyG,
|
||||
prettyH),
|
||||
g.numLevels(),
|
||||
h.numLevels());
|
||||
assertEquals(
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"the number of nodes in the graphs are different:%n%s%n%s",
|
||||
prettyG,
|
||||
prettyH),
|
||||
g.size(),
|
||||
h.size());
|
||||
|
||||
// assert equal nodes on each level
|
||||
for (int level = 0; level < g.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = g.getNodesOnLevel(level);
|
||||
NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
int node2 = nodesOnLevel2.nextInt();
|
||||
assertEquals("nodes in the graphs are different", node, node2);
|
||||
}
|
||||
List<Integer> hNodes = sortedNodesOnLevel(h, level);
|
||||
List<Integer> gNodes = sortedNodesOnLevel(g, level);
|
||||
assertEquals(
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"nodes in the graphs are different on level %d:%n%s%n%s",
|
||||
level,
|
||||
prettyG,
|
||||
prettyH),
|
||||
gNodes,
|
||||
hNodes);
|
||||
}
|
||||
|
||||
// assert equal nodes' neighbours on each level
|
||||
|
@ -287,7 +319,16 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
int node = nodesOnLevel.nextInt();
|
||||
g.seek(level, node);
|
||||
h.seek(level, node);
|
||||
assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h));
|
||||
assertEquals(
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"arcs differ for node %d on level %d:%n%s%n%s",
|
||||
node,
|
||||
level,
|
||||
prettyG,
|
||||
prettyH),
|
||||
getNeighborNodes(g),
|
||||
getNeighborNodes(h));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -495,14 +536,12 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
|
||||
for (int currLevel = 1; currLevel < numLevels; currLevel++) {
|
||||
NodesIterator nodesIterator = bottomUpExpectedHnsw.getNodesOnLevel(currLevel);
|
||||
List<Integer> expectedNodesOnLevel = nodesPerLevel.get(currLevel);
|
||||
assertEquals(expectedNodesOnLevel.size(), nodesIterator.size());
|
||||
for (Integer expectedNode : expectedNodesOnLevel) {
|
||||
int currentNode = nodesIterator.nextInt();
|
||||
assertEquals(expectedNode.intValue(), currentNode);
|
||||
assertEquals(0, bottomUpExpectedHnsw.getNeighbors(currLevel, currentNode).size());
|
||||
}
|
||||
List<Integer> sortedNodes = sortedNodesOnLevel(bottomUpExpectedHnsw, currLevel);
|
||||
assertEquals(
|
||||
String.format(Locale.ROOT, "Nodes on level %d do not match", currLevel),
|
||||
expectedNodesOnLevel,
|
||||
sortedNodes);
|
||||
}
|
||||
|
||||
assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw);
|
||||
|
@ -607,13 +646,10 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
|
||||
// assert the nodes from the previous graph are successfully to levels > 0 in the new graph
|
||||
for (int level = 1; level < g.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = g.getNodesOnLevel(level);
|
||||
NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
int node2 = oldToNewOrdMap.get(nodesOnLevel2.nextInt());
|
||||
assertEquals("nodes in the graphs are different", node, node2);
|
||||
}
|
||||
List<Integer> nodesOnLevel = sortedNodesOnLevel(g, level);
|
||||
List<Integer> nodesOnLevel2 =
|
||||
sortedNodesOnLevel(h, level).stream().map(oldToNewOrdMap::get).toList();
|
||||
assertEquals(nodesOnLevel, nodesOnLevel2);
|
||||
}
|
||||
|
||||
// assert that the neighbors from the old graph are successfully transferred to the new graph
|
||||
|
@ -1196,4 +1232,34 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
return bvec;
|
||||
}
|
||||
|
||||
static String prettyPrint(HnswGraph hnsw) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append(hnsw);
|
||||
sb.append("\n");
|
||||
|
||||
try {
|
||||
for (int level = 0; level < hnsw.numLevels(); level++) {
|
||||
sb.append("# Level ").append(level).append("\n");
|
||||
NodesIterator it = hnsw.getNodesOnLevel(level);
|
||||
while (it.hasNext()) {
|
||||
int node = it.nextInt();
|
||||
sb.append(" ").append(node).append(" -> ");
|
||||
hnsw.seek(level, node);
|
||||
while (true) {
|
||||
int neighbor = hnsw.nextNeighbor();
|
||||
if (neighbor == NO_MORE_DOCS) {
|
||||
break;
|
||||
}
|
||||
sb.append(" ").append(neighbor);
|
||||
}
|
||||
sb.append("\n");
|
||||
}
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue