mirror of https://github.com/apache/lucene.git
LUCENE-10054 Make HnswGraph hierarchical (#250)
Currently HNSW has only a single layer. This is the first part to make it multi-layered. To keep changes small, this PR only adds multiple layers in the HnswGraph class. TODO for following PRs: - modify graph construction and search algorithm for a hierarchical graph. - modify Lucene90HnswVectorsWriter and Lucene90HnswVectorsReader to write and read multiple layers\
This commit is contained in:
parent
46fa09d265
commit
257d256def
|
@ -481,7 +481,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void seek(int targetOrd) throws IOException {
|
public void seek(int level, int targetOrd) throws IOException {
|
||||||
// unsafe; no bounds checking
|
// unsafe; no bounds checking
|
||||||
dataIn.seek(entry.ordOffsets[targetOrd]);
|
dataIn.seek(entry.ordOffsets[targetOrd]);
|
||||||
arcCount = dataIn.readInt();
|
arcCount = dataIn.readInt();
|
||||||
|
|
|
@ -208,11 +208,12 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||||
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||||
|
|
||||||
|
// TODO: implement storing of hierarchical graph; for now stores only 0th level
|
||||||
for (int ord = 0; ord < count; ord++) {
|
for (int ord = 0; ord < count; ord++) {
|
||||||
// write graph
|
// write graph
|
||||||
offsets[ord] = graphData.getFilePointer() - graphDataOffset;
|
offsets[ord] = graphData.getFilePointer() - graphDataOffset;
|
||||||
|
|
||||||
NeighborArray neighbors = graph.getNeighbors(ord);
|
NeighborArray neighbors = graph.getNeighbors(0, ord);
|
||||||
int size = neighbors.size();
|
int size = neighbors.size();
|
||||||
|
|
||||||
// Destructively modify; it's ok we are discarding it after this
|
// Destructively modify; it's ok we are discarding it after this
|
||||||
|
|
|
@ -35,17 +35,18 @@ public abstract class KnnGraphValues {
|
||||||
* Move the pointer to exactly {@code target}, the id of a node in the graph. After this method
|
* Move the pointer to exactly {@code target}, the id of a node in the graph. After this method
|
||||||
* returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
|
* returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
|
||||||
*
|
*
|
||||||
|
* @param level level of the graph
|
||||||
* @param target must be a valid node in the graph, ie. ≥ 0 and < {@link
|
* @param target must be a valid node in the graph, ie. ≥ 0 and < {@link
|
||||||
* VectorValues#size()}.
|
* VectorValues#size()}.
|
||||||
*/
|
*/
|
||||||
public abstract void seek(int target) throws IOException;
|
public abstract void seek(int level, int target) throws IOException;
|
||||||
|
|
||||||
/** Returns the number of nodes in the graph */
|
/** Returns the number of nodes in the graph */
|
||||||
public abstract int size();
|
public abstract int size();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Iterates over the neighbor list. It is illegal to call this method after it returns
|
* Iterates over the neighbor list. It is illegal to call this method after it returns
|
||||||
* NO_MORE_DOCS without calling {@link #seek(int)}, which resets the iterator.
|
* NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
|
||||||
*
|
*
|
||||||
* @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
|
* @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
|
||||||
*/
|
*/
|
||||||
|
@ -61,7 +62,7 @@ public abstract class KnnGraphValues {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void seek(int target) {}
|
public void seek(int level, int target) {}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int size() {
|
public int size() {
|
||||||
|
|
|
@ -40,10 +40,10 @@ import org.apache.lucene.util.SparseFixedBitSet;
|
||||||
* <h2>Hyperparameters</h2>
|
* <h2>Hyperparameters</h2>
|
||||||
*
|
*
|
||||||
* <ul>
|
* <ul>
|
||||||
* <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2012 paper; it controls the
|
* <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2014 paper; it controls the
|
||||||
* number of random entry points to sample.
|
* number of random entry points to sample.
|
||||||
* <li><code>beamWidth</code> in {@link HnswGraphBuilder} has the same meaning as <code>efConst
|
* <li><code>beamWidth</code> in {@link HnswGraphBuilder} has the same meaning as <code>efConst
|
||||||
* </code> in the 2016 paper. It is the number of nearest neighbor candidates to track while
|
* </code> in the 2018 paper. It is the number of nearest neighbor candidates to track while
|
||||||
* searching the graph for each newly inserted node.
|
* searching the graph for each newly inserted node.
|
||||||
* <li><code>maxConn</code> has the same meaning as <code>M</code> in the later paper; it controls
|
* <li><code>maxConn</code> has the same meaning as <code>M</code> in the later paper; it controls
|
||||||
* how many of the <code>efConst</code> neighbors are connected to the new node
|
* how many of the <code>efConst</code> neighbors are connected to the new node
|
||||||
|
@ -56,22 +56,28 @@ import org.apache.lucene.util.SparseFixedBitSet;
|
||||||
public final class HnswGraph extends KnnGraphValues {
|
public final class HnswGraph extends KnnGraphValues {
|
||||||
|
|
||||||
private final int maxConn;
|
private final int maxConn;
|
||||||
|
// graph is a list of graph levels.
|
||||||
// Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to
|
// Each level is represented as List<NeighborArray> – nodes' connections on this level.
|
||||||
// HnswBuilder, and the
|
// Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
|
||||||
// node values are the ordinals of those vectors.
|
// added to HnswBuilder, and the node values are the ordinals of those vectors.
|
||||||
private final List<NeighborArray> graph;
|
private final List<List<NeighborArray>> graph;
|
||||||
|
|
||||||
// KnnGraphValues iterator members
|
// KnnGraphValues iterator members
|
||||||
private int upto;
|
private int upto;
|
||||||
private NeighborArray cur;
|
private NeighborArray cur;
|
||||||
|
|
||||||
HnswGraph(int maxConn) {
|
HnswGraph(int maxConn, int numLevels, int levelOfFirstNode) {
|
||||||
graph = new ArrayList<>();
|
|
||||||
// Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
|
|
||||||
// about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
|
|
||||||
graph.add(new NeighborArray(Math.max(32, maxConn / 4)));
|
|
||||||
this.maxConn = maxConn;
|
this.maxConn = maxConn;
|
||||||
|
this.graph = new ArrayList<>(numLevels);
|
||||||
|
for (int i = 0; i < numLevels; i++) {
|
||||||
|
graph.add(new ArrayList<>());
|
||||||
|
}
|
||||||
|
for (int i = 0; i <= levelOfFirstNode; i++) {
|
||||||
|
// Typically with diversity criteria we see nodes not fully occupied;
|
||||||
|
// average fanout seems to be about 1/2 maxConn.
|
||||||
|
// There is some indexing time penalty for under-allocating, but saves RAM
|
||||||
|
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -89,6 +95,7 @@ public final class HnswGraph extends KnnGraphValues {
|
||||||
* @param random a source of randomness, used for generating entry points to the graph
|
* @param random a source of randomness, used for generating entry points to the graph
|
||||||
* @return a priority queue holding the closest neighbors found
|
* @return a priority queue holding the closest neighbors found
|
||||||
*/
|
*/
|
||||||
|
// TODO: implement hierarchical search, currently searches only 0th level
|
||||||
public static NeighborQueue search(
|
public static NeighborQueue search(
|
||||||
float[] query,
|
float[] query,
|
||||||
int topK,
|
int topK,
|
||||||
|
@ -137,7 +144,7 @@ public final class HnswGraph extends KnnGraphValues {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int topCandidateNode = candidates.pop();
|
int topCandidateNode = candidates.pop();
|
||||||
graphValues.seek(topCandidateNode);
|
graphValues.seek(0, topCandidateNode);
|
||||||
int friendOrd;
|
int friendOrd;
|
||||||
while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
|
while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
|
||||||
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
|
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
|
||||||
|
@ -166,25 +173,36 @@ public final class HnswGraph extends KnnGraphValues {
|
||||||
/**
|
/**
|
||||||
* Returns the {@link NeighborQueue} connected to the given node.
|
* Returns the {@link NeighborQueue} connected to the given node.
|
||||||
*
|
*
|
||||||
|
* @param level level of the graph
|
||||||
* @param node the node whose neighbors are returned
|
* @param node the node whose neighbors are returned
|
||||||
*/
|
*/
|
||||||
public NeighborArray getNeighbors(int node) {
|
public NeighborArray getNeighbors(int level, int node) {
|
||||||
return graph.get(node);
|
NeighborArray result = graph.get(level).get(node);
|
||||||
|
assert result != null;
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int size() {
|
public int size() {
|
||||||
return graph.size();
|
return graph.get(0).size(); // all nodes are located on the 0th level
|
||||||
}
|
}
|
||||||
|
|
||||||
int addNode() {
|
// TODO: optimize RAM usage so not to store references for all nodes for levels > 0
|
||||||
graph.add(new NeighborArray(maxConn + 1));
|
public void addNode(int level, int node) {
|
||||||
return graph.size() - 1;
|
if (level > 0) {
|
||||||
|
// Levels above 0th don't contain all nodes,
|
||||||
|
// so for missing nodes we add null NeighborArray
|
||||||
|
int nullsToAdd = node - graph.get(level).size();
|
||||||
|
for (int i = 0; i < nullsToAdd; i++) {
|
||||||
|
graph.get(level).add(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
graph.get(level).add(new NeighborArray(maxConn + 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void seek(int targetNode) {
|
public void seek(int level, int targetNode) {
|
||||||
cur = getNeighbors(targetNode);
|
cur = getNeighbors(level, targetNode);
|
||||||
upto = -1;
|
upto = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -84,7 +84,7 @@ public final class HnswGraphBuilder {
|
||||||
}
|
}
|
||||||
this.maxConn = maxConn;
|
this.maxConn = maxConn;
|
||||||
this.beamWidth = beamWidth;
|
this.beamWidth = beamWidth;
|
||||||
this.hnsw = new HnswGraph(maxConn);
|
this.hnsw = new HnswGraph(maxConn, 1, 0);
|
||||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||||
random = new Random(seed);
|
random = new Random(seed);
|
||||||
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
|
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
|
||||||
|
@ -109,7 +109,7 @@ public final class HnswGraphBuilder {
|
||||||
long start = System.nanoTime(), t = start;
|
long start = System.nanoTime(), t = start;
|
||||||
// start at node 1! node 0 is added implicitly, in the constructor
|
// start at node 1! node 0 is added implicitly, in the constructor
|
||||||
for (int node = 1; node < vectors.size(); node++) {
|
for (int node = 1; node < vectors.size(); node++) {
|
||||||
addGraphNode(vectors.vectorValue(node));
|
addGraphNode(node, vectors.vectorValue(node));
|
||||||
if (node % 10000 == 0) {
|
if (node % 10000 == 0) {
|
||||||
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||||
long now = System.nanoTime();
|
long now = System.nanoTime();
|
||||||
|
@ -133,13 +133,14 @@ public final class HnswGraphBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Inserts a doc with vector value to the graph */
|
/** Inserts a doc with vector value to the graph */
|
||||||
void addGraphNode(float[] value) throws IOException {
|
// TODO: implement hierarchical graph building
|
||||||
|
void addGraphNode(int node, float[] value) throws IOException {
|
||||||
// We pass 'null' for acceptOrds because there are no deletions while building the graph
|
// We pass 'null' for acceptOrds because there are no deletions while building the graph
|
||||||
NeighborQueue candidates =
|
NeighborQueue candidates =
|
||||||
HnswGraph.search(
|
HnswGraph.search(
|
||||||
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
|
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
|
||||||
|
|
||||||
int node = hnsw.addNode();
|
hnsw.addNode(0, node);
|
||||||
|
|
||||||
/* connect neighbors to the new node, using a diversity heuristic that chooses successive
|
/* connect neighbors to the new node, using a diversity heuristic that chooses successive
|
||||||
* nearest neighbors that are closer to the new node than they are to the previously-selected
|
* nearest neighbors that are closer to the new node than they are to the previously-selected
|
||||||
|
@ -158,7 +159,7 @@ public final class HnswGraphBuilder {
|
||||||
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
|
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
|
||||||
* since the node is new and has no prior neighbors).
|
* since the node is new and has no prior neighbors).
|
||||||
*/
|
*/
|
||||||
NeighborArray neighbors = hnsw.getNeighbors(node);
|
NeighborArray neighbors = hnsw.getNeighbors(0, node);
|
||||||
assert neighbors.size() == 0; // new node
|
assert neighbors.size() == 0; // new node
|
||||||
popToScratch(candidates);
|
popToScratch(candidates);
|
||||||
selectDiverse(neighbors, scratch);
|
selectDiverse(neighbors, scratch);
|
||||||
|
@ -168,7 +169,7 @@ public final class HnswGraphBuilder {
|
||||||
int size = neighbors.size();
|
int size = neighbors.size();
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
int nbr = neighbors.node[i];
|
int nbr = neighbors.node[i];
|
||||||
NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
|
NeighborArray nbrNbr = hnsw.getNeighbors(0, nbr);
|
||||||
nbrNbr.add(node, neighbors.score[i]);
|
nbrNbr.add(node, neighbors.score[i]);
|
||||||
if (nbrNbr.size() > maxConn) {
|
if (nbrNbr.size() > maxConn) {
|
||||||
diversityUpdate(nbrNbr);
|
diversityUpdate(nbrNbr);
|
||||||
|
|
|
@ -214,7 +214,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
||||||
int[] scratch = new int[maxConn];
|
int[] scratch = new int[maxConn];
|
||||||
for (int node = 0; node < size; node++) {
|
for (int node = 0; node < size; node++) {
|
||||||
int n, count = 0;
|
int n, count = 0;
|
||||||
values.seek(node);
|
values.seek(0, node);
|
||||||
while ((n = values.nextNeighbor()) != NO_MORE_DOCS) {
|
while ((n = values.nextNeighbor()) != NO_MORE_DOCS) {
|
||||||
scratch[count++] = n;
|
scratch[count++] = n;
|
||||||
// graph[node][i++] = n;
|
// graph[node][i++] = n;
|
||||||
|
@ -352,7 +352,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
int id = Integer.parseInt(reader.document(i).get("id"));
|
int id = Integer.parseInt(reader.document(i).get("id"));
|
||||||
graphValues.seek(graphSize);
|
graphValues.seek(0, graphSize);
|
||||||
// documents with KnnGraphValues have the expected vectors
|
// documents with KnnGraphValues have the expected vectors
|
||||||
float[] scratch = vectorValues.vectorValue();
|
float[] scratch = vectorValues.vectorValue();
|
||||||
assertArrayEquals(
|
assertArrayEquals(
|
||||||
|
|
|
@ -256,7 +256,7 @@ public class KnnGraphTester {
|
||||||
new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0);
|
new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0);
|
||||||
// start at node 1
|
// start at node 1
|
||||||
for (int i = 1; i < numDocs; i++) {
|
for (int i = 1; i < numDocs; i++) {
|
||||||
builder.addGraphNode(values.vectorValue(i));
|
builder.addGraphNode(i, values.vectorValue(i));
|
||||||
System.out.println("\nITERATION " + i);
|
System.out.println("\nITERATION " + i);
|
||||||
dumpGraph(builder.hnsw);
|
dumpGraph(builder.hnsw);
|
||||||
}
|
}
|
||||||
|
@ -265,7 +265,7 @@ public class KnnGraphTester {
|
||||||
|
|
||||||
private void dumpGraph(HnswGraph hnsw) {
|
private void dumpGraph(HnswGraph hnsw) {
|
||||||
for (int i = 0; i < hnsw.size(); i++) {
|
for (int i = 0; i < hnsw.size(); i++) {
|
||||||
NeighborArray neighbors = hnsw.getNeighbors(i);
|
NeighborArray neighbors = hnsw.getNeighbors(0, i);
|
||||||
System.out.printf(Locale.ROOT, "%5d", i);
|
System.out.printf(Locale.ROOT, "%5d", i);
|
||||||
NeighborArray sorted = new NeighborArray(neighbors.size());
|
NeighborArray sorted = new NeighborArray(neighbors.size());
|
||||||
for (int j = 0; j < neighbors.size(); j++) {
|
for (int j = 0; j < neighbors.size(); j++) {
|
||||||
|
@ -297,7 +297,7 @@ public class KnnGraphTester {
|
||||||
int count = 0;
|
int count = 0;
|
||||||
int[] leafHist = new int[numDocs];
|
int[] leafHist = new int[numDocs];
|
||||||
for (int node = 0; node < numDocs; node++) {
|
for (int node = 0; node < numDocs; node++) {
|
||||||
knnValues.seek(node);
|
knnValues.seek(0, node);
|
||||||
int n = 0;
|
int n = 0;
|
||||||
while (knnValues.nextNeighbor() != NO_MORE_DOCS) {
|
while (knnValues.nextNeighbor() != NO_MORE_DOCS) {
|
||||||
++n;
|
++n;
|
||||||
|
|
|
@ -150,7 +150,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
// 45
|
// 45
|
||||||
assertTrue("sum(result docs)=" + sum, sum < 75);
|
assertTrue("sum(result docs)=" + sum, sum < 75);
|
||||||
for (int i = 0; i < nDoc; i++) {
|
for (int i = 0; i < nDoc; i++) {
|
||||||
NeighborArray neighbors = hnsw.getNeighbors(i);
|
NeighborArray neighbors = hnsw.getNeighbors(0, i);
|
||||||
int[] nodes = neighbors.node;
|
int[] nodes = neighbors.node;
|
||||||
for (int j = 0; j < neighbors.size(); j++) {
|
for (int j = 0; j < neighbors.size(); j++) {
|
||||||
// all neighbors should be valid node ids.
|
// all neighbors should be valid node ids.
|
||||||
|
@ -252,15 +252,15 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt());
|
vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt());
|
||||||
// node 0 is added by the builder constructor
|
// node 0 is added by the builder constructor
|
||||||
// builder.addGraphNode(vectors.vectorValue(0));
|
// builder.addGraphNode(vectors.vectorValue(0));
|
||||||
builder.addGraphNode(vectors.vectorValue(1));
|
builder.addGraphNode(1, vectors.vectorValue(1));
|
||||||
builder.addGraphNode(vectors.vectorValue(2));
|
builder.addGraphNode(2, vectors.vectorValue(2));
|
||||||
// now every node has tried to attach every other node as a neighbor, but
|
// now every node has tried to attach every other node as a neighbor, but
|
||||||
// some were excluded based on diversity check.
|
// some were excluded based on diversity check.
|
||||||
assertNeighbors(builder.hnsw, 0, 1, 2);
|
assertNeighbors(builder.hnsw, 0, 1, 2);
|
||||||
assertNeighbors(builder.hnsw, 1, 0);
|
assertNeighbors(builder.hnsw, 1, 0);
|
||||||
assertNeighbors(builder.hnsw, 2, 0);
|
assertNeighbors(builder.hnsw, 2, 0);
|
||||||
|
|
||||||
builder.addGraphNode(vectors.vectorValue(3));
|
builder.addGraphNode(3, vectors.vectorValue(3));
|
||||||
assertNeighbors(builder.hnsw, 0, 1, 2);
|
assertNeighbors(builder.hnsw, 0, 1, 2);
|
||||||
// we added 3 here
|
// we added 3 here
|
||||||
assertNeighbors(builder.hnsw, 1, 0, 3);
|
assertNeighbors(builder.hnsw, 1, 0, 3);
|
||||||
|
@ -268,7 +268,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
assertNeighbors(builder.hnsw, 3, 1);
|
assertNeighbors(builder.hnsw, 3, 1);
|
||||||
|
|
||||||
// supplant an existing neighbor
|
// supplant an existing neighbor
|
||||||
builder.addGraphNode(vectors.vectorValue(4));
|
builder.addGraphNode(4, vectors.vectorValue(4));
|
||||||
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
|
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
|
||||||
assertNeighbors(builder.hnsw, 0, 1, 2);
|
assertNeighbors(builder.hnsw, 0, 1, 2);
|
||||||
// 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so
|
// 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so
|
||||||
|
@ -279,7 +279,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
assertNeighbors(builder.hnsw, 3, 1, 4);
|
assertNeighbors(builder.hnsw, 3, 1, 4);
|
||||||
assertNeighbors(builder.hnsw, 4, 1, 3);
|
assertNeighbors(builder.hnsw, 4, 1, 3);
|
||||||
|
|
||||||
builder.addGraphNode(vectors.vectorValue(5));
|
builder.addGraphNode(5, vectors.vectorValue(5));
|
||||||
assertNeighbors(builder.hnsw, 0, 1, 2);
|
assertNeighbors(builder.hnsw, 0, 1, 2);
|
||||||
assertNeighbors(builder.hnsw, 1, 0, 5);
|
assertNeighbors(builder.hnsw, 1, 0, 5);
|
||||||
assertNeighbors(builder.hnsw, 2, 0);
|
assertNeighbors(builder.hnsw, 2, 0);
|
||||||
|
@ -291,7 +291,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
|
|
||||||
private void assertNeighbors(HnswGraph graph, int node, int... expected) {
|
private void assertNeighbors(HnswGraph graph, int node, int... expected) {
|
||||||
Arrays.sort(expected);
|
Arrays.sort(expected);
|
||||||
NeighborArray nn = graph.getNeighbors(node);
|
NeighborArray nn = graph.getNeighbors(0, node);
|
||||||
int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size());
|
int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size());
|
||||||
Arrays.sort(actual);
|
Arrays.sort(actual);
|
||||||
assertArrayEquals(
|
assertArrayEquals(
|
||||||
|
@ -439,8 +439,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
|
|
||||||
private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException {
|
private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException {
|
||||||
for (int node = 0; node < size; node++) {
|
for (int node = 0; node < size; node++) {
|
||||||
g.seek(node);
|
g.seek(0, node);
|
||||||
h.seek(node);
|
h.seek(0, node);
|
||||||
assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h));
|
assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue