LUCENE-10054 Make HnswGraph hierarchical (#250)

Currently the HNSW graph has only a single layer.
This is the first step toward making it multi-layered.

To keep the changes small, this PR only adds
multiple layers to the HnswGraph class.

TODO for following PRs:
- modify graph construction and search algorithm for a hierarchical
graph.
- modify Lucene90HnswVectorsWriter and Lucene90HnswVectorsReader to
write and read multiple layers.
This commit is contained in:
Mayya Sharipova 2021-08-23 15:54:26 -04:00 committed by GitHub
parent 46fa09d265
commit 257d256def
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 67 additions and 46 deletions

View File

@ -481,7 +481,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
} }
@Override @Override
public void seek(int targetOrd) throws IOException { public void seek(int level, int targetOrd) throws IOException {
// unsafe; no bounds checking // unsafe; no bounds checking
dataIn.seek(entry.ordOffsets[targetOrd]); dataIn.seek(entry.ordOffsets[targetOrd]);
arcCount = dataIn.readInt(); arcCount = dataIn.readInt();

View File

@ -208,11 +208,12 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess()); HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
// TODO: implement storing of hierarchical graph; for now stores only 0th level
for (int ord = 0; ord < count; ord++) { for (int ord = 0; ord < count; ord++) {
// write graph // write graph
offsets[ord] = graphData.getFilePointer() - graphDataOffset; offsets[ord] = graphData.getFilePointer() - graphDataOffset;
NeighborArray neighbors = graph.getNeighbors(ord); NeighborArray neighbors = graph.getNeighbors(0, ord);
int size = neighbors.size(); int size = neighbors.size();
// Destructively modify; it's ok we are discarding it after this // Destructively modify; it's ok we are discarding it after this

View File

@ -35,17 +35,18 @@ public abstract class KnnGraphValues {
* Move the pointer to exactly {@code target}, the id of a node in the graph. After this method * Move the pointer to exactly {@code target}, the id of a node in the graph. After this method
* returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals. * returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
* *
* @param level level of the graph
* @param target must be a valid node in the graph, ie. &ge; 0 and &lt; {@link * @param target must be a valid node in the graph, ie. &ge; 0 and &lt; {@link
* VectorValues#size()}. * VectorValues#size()}.
*/ */
public abstract void seek(int target) throws IOException; public abstract void seek(int level, int target) throws IOException;
/** Returns the number of nodes in the graph */ /** Returns the number of nodes in the graph */
public abstract int size(); public abstract int size();
/** /**
* Iterates over the neighbor list. It is illegal to call this method after it returns * Iterates over the neighbor list. It is illegal to call this method after it returns
* NO_MORE_DOCS without calling {@link #seek(int)}, which resets the iterator. * NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
* *
* @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete. * @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
*/ */
@ -61,7 +62,7 @@ public abstract class KnnGraphValues {
} }
@Override @Override
public void seek(int target) {} public void seek(int level, int target) {}
@Override @Override
public int size() { public int size() {

View File

@ -40,10 +40,10 @@ import org.apache.lucene.util.SparseFixedBitSet;
* <h2>Hyperparameters</h2> * <h2>Hyperparameters</h2>
* *
* <ul> * <ul>
* <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2012 paper; it controls the * <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2014 paper; it controls the
* number of random entry points to sample. * number of random entry points to sample.
* <li><code>beamWidth</code> in {@link HnswGraphBuilder} has the same meaning as <code>efConst * <li><code>beamWidth</code> in {@link HnswGraphBuilder} has the same meaning as <code>efConst
* </code> in the 2016 paper. It is the number of nearest neighbor candidates to track while * </code> in the 2018 paper. It is the number of nearest neighbor candidates to track while
* searching the graph for each newly inserted node. * searching the graph for each newly inserted node.
* <li><code>maxConn</code> has the same meaning as <code>M</code> in the later paper; it controls * <li><code>maxConn</code> has the same meaning as <code>M</code> in the later paper; it controls
* how many of the <code>efConst</code> neighbors are connected to the new node * how many of the <code>efConst</code> neighbors are connected to the new node
@ -56,22 +56,28 @@ import org.apache.lucene.util.SparseFixedBitSet;
public final class HnswGraph extends KnnGraphValues { public final class HnswGraph extends KnnGraphValues {
private final int maxConn; private final int maxConn;
// graph is a list of graph levels.
// Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to // Each level is represented as List<NeighborArray> nodes' connections on this level.
// HnswBuilder, and the // Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
// node values are the ordinals of those vectors. // added to HnswBuilder, and the node values are the ordinals of those vectors.
private final List<NeighborArray> graph; private final List<List<NeighborArray>> graph;
// KnnGraphValues iterator members // KnnGraphValues iterator members
private int upto; private int upto;
private NeighborArray cur; private NeighborArray cur;
HnswGraph(int maxConn) { HnswGraph(int maxConn, int numLevels, int levelOfFirstNode) {
graph = new ArrayList<>();
// Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
// about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
graph.add(new NeighborArray(Math.max(32, maxConn / 4)));
this.maxConn = maxConn; this.maxConn = maxConn;
this.graph = new ArrayList<>(numLevels);
for (int i = 0; i < numLevels; i++) {
graph.add(new ArrayList<>());
}
for (int i = 0; i <= levelOfFirstNode; i++) {
// Typically with diversity criteria we see nodes not fully occupied;
// average fanout seems to be about 1/2 maxConn.
// There is some indexing time penalty for under-allocating, but saves RAM
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
}
} }
/** /**
@ -89,6 +95,7 @@ public final class HnswGraph extends KnnGraphValues {
* @param random a source of randomness, used for generating entry points to the graph * @param random a source of randomness, used for generating entry points to the graph
* @return a priority queue holding the closest neighbors found * @return a priority queue holding the closest neighbors found
*/ */
// TODO: implement hierarchical search, currently searches only 0th level
public static NeighborQueue search( public static NeighborQueue search(
float[] query, float[] query,
int topK, int topK,
@ -137,7 +144,7 @@ public final class HnswGraph extends KnnGraphValues {
} }
} }
int topCandidateNode = candidates.pop(); int topCandidateNode = candidates.pop();
graphValues.seek(topCandidateNode); graphValues.seek(0, topCandidateNode);
int friendOrd; int friendOrd;
while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) { while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
@ -166,25 +173,36 @@ public final class HnswGraph extends KnnGraphValues {
/** /**
* Returns the {@link NeighborQueue} connected to the given node. * Returns the {@link NeighborQueue} connected to the given node.
* *
* @param level level of the graph
* @param node the node whose neighbors are returned * @param node the node whose neighbors are returned
*/ */
public NeighborArray getNeighbors(int node) { public NeighborArray getNeighbors(int level, int node) {
return graph.get(node); NeighborArray result = graph.get(level).get(node);
assert result != null;
return result;
} }
@Override @Override
public int size() { public int size() {
return graph.size(); return graph.get(0).size(); // all nodes are located on the 0th level
} }
int addNode() { // TODO: optimize RAM usage so not to store references for all nodes for levels > 0
graph.add(new NeighborArray(maxConn + 1)); public void addNode(int level, int node) {
return graph.size() - 1; if (level > 0) {
// Levels above 0th don't contain all nodes,
// so for missing nodes we add null NeighborArray
int nullsToAdd = node - graph.get(level).size();
for (int i = 0; i < nullsToAdd; i++) {
graph.get(level).add(null);
}
}
graph.get(level).add(new NeighborArray(maxConn + 1));
} }
@Override @Override
public void seek(int targetNode) { public void seek(int level, int targetNode) {
cur = getNeighbors(targetNode); cur = getNeighbors(level, targetNode);
upto = -1; upto = -1;
} }

View File

@ -84,7 +84,7 @@ public final class HnswGraphBuilder {
} }
this.maxConn = maxConn; this.maxConn = maxConn;
this.beamWidth = beamWidth; this.beamWidth = beamWidth;
this.hnsw = new HnswGraph(maxConn); this.hnsw = new HnswGraph(maxConn, 1, 0);
bound = BoundsChecker.create(similarityFunction.reversed); bound = BoundsChecker.create(similarityFunction.reversed);
random = new Random(seed); random = new Random(seed);
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1)); scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
@ -109,7 +109,7 @@ public final class HnswGraphBuilder {
long start = System.nanoTime(), t = start; long start = System.nanoTime(), t = start;
// start at node 1! node 0 is added implicitly, in the constructor // start at node 1! node 0 is added implicitly, in the constructor
for (int node = 1; node < vectors.size(); node++) { for (int node = 1; node < vectors.size(); node++) {
addGraphNode(vectors.vectorValue(node)); addGraphNode(node, vectors.vectorValue(node));
if (node % 10000 == 0) { if (node % 10000 == 0) {
if (infoStream.isEnabled(HNSW_COMPONENT)) { if (infoStream.isEnabled(HNSW_COMPONENT)) {
long now = System.nanoTime(); long now = System.nanoTime();
@ -133,13 +133,14 @@ public final class HnswGraphBuilder {
} }
/** Inserts a doc with vector value to the graph */ /** Inserts a doc with vector value to the graph */
void addGraphNode(float[] value) throws IOException { // TODO: implement hierarchical graph building
void addGraphNode(int node, float[] value) throws IOException {
// We pass 'null' for acceptOrds because there are no deletions while building the graph // We pass 'null' for acceptOrds because there are no deletions while building the graph
NeighborQueue candidates = NeighborQueue candidates =
HnswGraph.search( HnswGraph.search(
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random); value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
int node = hnsw.addNode(); hnsw.addNode(0, node);
/* connect neighbors to the new node, using a diversity heuristic that chooses successive /* connect neighbors to the new node, using a diversity heuristic that chooses successive
* nearest neighbors that are closer to the new node than they are to the previously-selected * nearest neighbors that are closer to the new node than they are to the previously-selected
@ -158,7 +159,7 @@ public final class HnswGraphBuilder {
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method, * is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
* since the node is new and has no prior neighbors). * since the node is new and has no prior neighbors).
*/ */
NeighborArray neighbors = hnsw.getNeighbors(node); NeighborArray neighbors = hnsw.getNeighbors(0, node);
assert neighbors.size() == 0; // new node assert neighbors.size() == 0; // new node
popToScratch(candidates); popToScratch(candidates);
selectDiverse(neighbors, scratch); selectDiverse(neighbors, scratch);
@ -168,7 +169,7 @@ public final class HnswGraphBuilder {
int size = neighbors.size(); int size = neighbors.size();
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i]; int nbr = neighbors.node[i];
NeighborArray nbrNbr = hnsw.getNeighbors(nbr); NeighborArray nbrNbr = hnsw.getNeighbors(0, nbr);
nbrNbr.add(node, neighbors.score[i]); nbrNbr.add(node, neighbors.score[i]);
if (nbrNbr.size() > maxConn) { if (nbrNbr.size() > maxConn) {
diversityUpdate(nbrNbr); diversityUpdate(nbrNbr);

View File

@ -214,7 +214,7 @@ public class TestKnnGraph extends LuceneTestCase {
int[] scratch = new int[maxConn]; int[] scratch = new int[maxConn];
for (int node = 0; node < size; node++) { for (int node = 0; node < size; node++) {
int n, count = 0; int n, count = 0;
values.seek(node); values.seek(0, node);
while ((n = values.nextNeighbor()) != NO_MORE_DOCS) { while ((n = values.nextNeighbor()) != NO_MORE_DOCS) {
scratch[count++] = n; scratch[count++] = n;
// graph[node][i++] = n; // graph[node][i++] = n;
@ -352,7 +352,7 @@ public class TestKnnGraph extends LuceneTestCase {
break; break;
} }
int id = Integer.parseInt(reader.document(i).get("id")); int id = Integer.parseInt(reader.document(i).get("id"));
graphValues.seek(graphSize); graphValues.seek(0, graphSize);
// documents with KnnGraphValues have the expected vectors // documents with KnnGraphValues have the expected vectors
float[] scratch = vectorValues.vectorValue(); float[] scratch = vectorValues.vectorValue();
assertArrayEquals( assertArrayEquals(

View File

@ -256,7 +256,7 @@ public class KnnGraphTester {
new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0); new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0);
// start at node 1 // start at node 1
for (int i = 1; i < numDocs; i++) { for (int i = 1; i < numDocs; i++) {
builder.addGraphNode(values.vectorValue(i)); builder.addGraphNode(i, values.vectorValue(i));
System.out.println("\nITERATION " + i); System.out.println("\nITERATION " + i);
dumpGraph(builder.hnsw); dumpGraph(builder.hnsw);
} }
@ -265,7 +265,7 @@ public class KnnGraphTester {
private void dumpGraph(HnswGraph hnsw) { private void dumpGraph(HnswGraph hnsw) {
for (int i = 0; i < hnsw.size(); i++) { for (int i = 0; i < hnsw.size(); i++) {
NeighborArray neighbors = hnsw.getNeighbors(i); NeighborArray neighbors = hnsw.getNeighbors(0, i);
System.out.printf(Locale.ROOT, "%5d", i); System.out.printf(Locale.ROOT, "%5d", i);
NeighborArray sorted = new NeighborArray(neighbors.size()); NeighborArray sorted = new NeighborArray(neighbors.size());
for (int j = 0; j < neighbors.size(); j++) { for (int j = 0; j < neighbors.size(); j++) {
@ -297,7 +297,7 @@ public class KnnGraphTester {
int count = 0; int count = 0;
int[] leafHist = new int[numDocs]; int[] leafHist = new int[numDocs];
for (int node = 0; node < numDocs; node++) { for (int node = 0; node < numDocs; node++) {
knnValues.seek(node); knnValues.seek(0, node);
int n = 0; int n = 0;
while (knnValues.nextNeighbor() != NO_MORE_DOCS) { while (knnValues.nextNeighbor() != NO_MORE_DOCS) {
++n; ++n;

View File

@ -150,7 +150,7 @@ public class TestHnswGraph extends LuceneTestCase {
// 45 // 45
assertTrue("sum(result docs)=" + sum, sum < 75); assertTrue("sum(result docs)=" + sum, sum < 75);
for (int i = 0; i < nDoc; i++) { for (int i = 0; i < nDoc; i++) {
NeighborArray neighbors = hnsw.getNeighbors(i); NeighborArray neighbors = hnsw.getNeighbors(0, i);
int[] nodes = neighbors.node; int[] nodes = neighbors.node;
for (int j = 0; j < neighbors.size(); j++) { for (int j = 0; j < neighbors.size(); j++) {
// all neighbors should be valid node ids. // all neighbors should be valid node ids.
@ -252,15 +252,15 @@ public class TestHnswGraph extends LuceneTestCase {
vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt()); vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt());
// node 0 is added by the builder constructor // node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0)); // builder.addGraphNode(vectors.vectorValue(0));
builder.addGraphNode(vectors.vectorValue(1)); builder.addGraphNode(1, vectors.vectorValue(1));
builder.addGraphNode(vectors.vectorValue(2)); builder.addGraphNode(2, vectors.vectorValue(2));
// now every node has tried to attach every other node as a neighbor, but // now every node has tried to attach every other node as a neighbor, but
// some were excluded based on diversity check. // some were excluded based on diversity check.
assertNeighbors(builder.hnsw, 0, 1, 2); assertNeighbors(builder.hnsw, 0, 1, 2);
assertNeighbors(builder.hnsw, 1, 0); assertNeighbors(builder.hnsw, 1, 0);
assertNeighbors(builder.hnsw, 2, 0); assertNeighbors(builder.hnsw, 2, 0);
builder.addGraphNode(vectors.vectorValue(3)); builder.addGraphNode(3, vectors.vectorValue(3));
assertNeighbors(builder.hnsw, 0, 1, 2); assertNeighbors(builder.hnsw, 0, 1, 2);
// we added 3 here // we added 3 here
assertNeighbors(builder.hnsw, 1, 0, 3); assertNeighbors(builder.hnsw, 1, 0, 3);
@ -268,7 +268,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertNeighbors(builder.hnsw, 3, 1); assertNeighbors(builder.hnsw, 3, 1);
// supplant an existing neighbor // supplant an existing neighbor
builder.addGraphNode(vectors.vectorValue(4)); builder.addGraphNode(4, vectors.vectorValue(4));
// 4 is the same distance from 0 that 2 is; we leave the existing node in place // 4 is the same distance from 0 that 2 is; we leave the existing node in place
assertNeighbors(builder.hnsw, 0, 1, 2); assertNeighbors(builder.hnsw, 0, 1, 2);
// 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so // 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so
@ -279,7 +279,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertNeighbors(builder.hnsw, 3, 1, 4); assertNeighbors(builder.hnsw, 3, 1, 4);
assertNeighbors(builder.hnsw, 4, 1, 3); assertNeighbors(builder.hnsw, 4, 1, 3);
builder.addGraphNode(vectors.vectorValue(5)); builder.addGraphNode(5, vectors.vectorValue(5));
assertNeighbors(builder.hnsw, 0, 1, 2); assertNeighbors(builder.hnsw, 0, 1, 2);
assertNeighbors(builder.hnsw, 1, 0, 5); assertNeighbors(builder.hnsw, 1, 0, 5);
assertNeighbors(builder.hnsw, 2, 0); assertNeighbors(builder.hnsw, 2, 0);
@ -291,7 +291,7 @@ public class TestHnswGraph extends LuceneTestCase {
private void assertNeighbors(HnswGraph graph, int node, int... expected) { private void assertNeighbors(HnswGraph graph, int node, int... expected) {
Arrays.sort(expected); Arrays.sort(expected);
NeighborArray nn = graph.getNeighbors(node); NeighborArray nn = graph.getNeighbors(0, node);
int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size()); int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size());
Arrays.sort(actual); Arrays.sort(actual);
assertArrayEquals( assertArrayEquals(
@ -439,8 +439,8 @@ public class TestHnswGraph extends LuceneTestCase {
private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException { private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException {
for (int node = 0; node < size; node++) { for (int node = 0; node < size; node++) {
g.seek(node); g.seek(0, node);
h.seek(node); h.seek(0, node);
assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h)); assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h));
} }
} }