gh-12627: HnswGraphBuilder connects disconnected HNSW graph components (#13566)

* gh-12627: HnswGraphBuilder connects disconnected HNSW graph components
This commit is contained in:
Michael Sokolov 2024-08-08 14:41:52 -04:00 committed by GitHub
parent d26b152117
commit 217828736c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 679 additions and 145 deletions

View File

@ -463,6 +463,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
// unsafe; no bounds checking
dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
arcCount = dataIn.readVInt();
assert arcCount <= currentNeighborsBuffer.length : "too many neighbors: " + arcCount;
if (arcCount > 0) {
currentNeighborsBuffer[0] = dataIn.readVInt();
for (int i = 1; i < arcCount; i++) {

View File

@ -612,7 +612,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
throw new UnsupportedOperationException();
}
OnHeapHnswGraph getGraph() {
OnHeapHnswGraph getGraph() throws IOException {
assert flatFieldVectorsWriter.isFinished();
if (node > 0) {
return hnswGraphBuilder.getCompletedGraph();

View File

@ -48,5 +48,5 @@ public interface HnswBuilder {
* components, re-ordering node ids for better delta compression) may be triggered, so callers
* should expect this call to take some time.
*/
OnHeapHnswGraph getCompletedGraph();
OnHeapHnswGraph getCompletedGraph() throws IOException;
}

View File

@ -91,6 +91,7 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
});
}
taskExecutor.invokeAll(futures);
finish();
frozen = true;
return workers[0].getCompletedGraph();
}
@ -109,11 +110,19 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
}
@Override
public OnHeapHnswGraph getCompletedGraph() {
frozen = true;
public OnHeapHnswGraph getCompletedGraph() throws IOException {
if (frozen == false) {
// should already have been called in build(), but just in case
finish();
frozen = true;
}
return getGraph();
}
// Runs the post-build finishing step by delegating to the first worker's finish().
// NOTE(review): build() and getGraph() also return workers[0]'s graph, so presumably all workers
// operate on one shared graph and finishing workers[0] is sufficient - confirm.
private void finish() throws IOException {
workers[0].finish();
}
@Override
public OnHeapHnswGraph getGraph() {
return workers[0].getGraph();

View File

@ -18,8 +18,11 @@
package org.apache.lucene.util.hnsw;
import static java.lang.Math.log;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
@ -28,6 +31,7 @@ import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.HnswUtil.Component;
/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
@ -137,7 +141,7 @@ public class HnswGraphBuilder implements HnswBuilder {
HnswGraphSearcher graphSearcher)
throws IOException {
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
throw new IllegalArgumentException("M (max connections) must be positive");
}
if (beamWidth <= 0) {
throw new IllegalArgumentException("beamWidth must be positive");
@ -173,8 +177,11 @@ public class HnswGraphBuilder implements HnswBuilder {
}
@Override
public OnHeapHnswGraph getCompletedGraph() {
frozen = true;
public OnHeapHnswGraph getCompletedGraph() throws IOException {
if (!frozen) {
finish();
frozen = true;
}
return getGraph();
}
@ -405,6 +412,93 @@ public class HnswGraphBuilder implements HnswBuilder {
return ((int) (-log(randDouble) * ml));
}
/**
 * Post-processing step applied to the built graph. At present the only work done here is
 * reconnecting disconnected graph components, level by level.
 */
void finish() throws IOException {
connectComponents();
}
/**
 * Runs the component-connection pass on every level of the graph, starting at level 0. When the
 * HNSW_COMPONENT info stream is enabled, each level that could not be fully connected is
 * reported, followed by the total elapsed time in milliseconds.
 */
private void connectComponents() throws IOException {
  final long start = System.nanoTime();
  int level = 0;
  while (level < hnsw.numLevels()) {
    boolean connected = connectComponents(level);
    if (!connected && infoStream.isEnabled(HNSW_COMPONENT)) {
      infoStream.message(HNSW_COMPONENT, "connectComponents failed on level " + level);
    }
    level++;
  }
  if (infoStream.isEnabled(HNSW_COMPONENT)) {
    long elapsedMillis = (System.nanoTime() - start) / 1_000_000;
    infoStream.message(HNSW_COMPONENT, "connectComponents " + elapsedMillis + " ms");
  }
}
/**
 * Tries to connect every disconnected component on {@code level} to the largest component found
 * on that level.
 *
 * <p>For each smaller component, the graph is searched on this level, starting from the largest
 * component's start node and restricted to nodes that still have room for a connection, for the
 * node closest to the smaller component's start node; the two are then linked.
 *
 * @param level the graph level to repair
 * @return true if the level had at most one component, or if every other component was linked to
 *     the largest one; false if at least one component could not be linked
 */
private boolean connectComponents(int level) throws IOException {
  FixedBitSet notFullyConnected = new FixedBitSet(hnsw.size());
  // level 0 is allowed twice as many connections per node as the upper levels
  int maxConn = level == 0 ? M * 2 : M;
  List<Component> components = HnswUtil.components(hnsw, level, notFullyConnected, maxConn);
  if (components.size() <= 1) {
    return true;
  }
  // connect other components to the largest one
  Component c0 = components.stream().max(Comparator.comparingInt(Component::size)).get();
  if (c0.start() == NO_MORE_DOCS) {
    // the component is already fully connected - no room for new connections
    return false;
  }
  boolean result = true;
  // try for more connections? We only do one since otherwise they may become full
  // while linking
  GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(1);
  int[] eps = new int[1];
  for (Component c : components) {
    if (c == c0) {
      continue;
    }
    if (c.start() == NO_MORE_DOCS) {
      // the rooted component reports NO_MORE_DOCS when none of its nodes has room left; there is
      // no node to score against or link to, so this component cannot be connected
      result = false;
      continue;
    }
    beam.clear();
    eps[0] = c0.start();
    RandomVectorScorer scorer = scorerSupplier.scorer(c.start());
    // find the closest node in the largest component to this component's start node that has
    // room to make a connection. The search must run on the level being repaired: the
    // notFullyConnected bits and the neighbor arrays used by link() are all for this level.
    graphSearcher.searchLevel(beam, scorer, level, eps, hnsw, notFullyConnected);
    boolean linked = false;
    while (beam.size() > 0) {
      float score = beam.minimumScore();
      int c0node = beam.popNode();
      if (c0node == c.start() || notFullyConnected.get(c0node) == false) {
        // never link a node to itself, and never add to a node that has filled up meanwhile
        continue;
      }
      // link the nodes
      link(level, c0node, c.start(), score, notFullyConnected);
      linked = true;
    }
    if (!linked) {
      result = false;
    }
  }
  return result;
}
/**
 * Tries to link two nodes bidirectionally on the given level. The forward connection (n0 to n1)
 * is always made; the reverse connection (n1 to n0) is made only if n1 still has room. Bits in
 * {@code notFullyConnected} are cleared for any node that becomes full as a result.
 */
private void link(int level, int n0, int n1, float score, FixedBitSet notFullyConnected) {
  NeighborArray forward = hnsw.getNeighbors(level, n0);
  NeighborArray reverse = hnsw.getNeighbors(level, n1);
  // must subtract 1 here since the nodes array is one larger than the configured
  // max neighbors (M / 2M).
  final int capacity = forward.nodes().length - 1;
  // The caller should already have restricted its search to not-full nodes
  assert notFullyConnected.get(n0);
  assert forward.size() < capacity
      : "node " + n0 + " is full, has " + forward.size() + " friends";
  forward.addOutOfOrder(n1, score);
  if (forward.size() == capacity) {
    notFullyConnected.clear(n0);
  }
  if (reverse.size() >= capacity) {
    // no room for the back link
    return;
  }
  reverse.addOutOfOrder(n0, score);
  if (reverse.size() == capacity) {
    notFullyConnected.clear(n1);
  }
}
/**
* A restricted, specialized knnCollector that can be used when building a graph.
*

View File

@ -0,0 +1,248 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.FixedBitSet;
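/**
 * Utilities for analyzing the connectivity of HNSW graphs: finding the rooted and orphaned
 * components of a graph level. Used by {@link HnswGraphBuilder} when reconnecting components, and
 * by tests.
 */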
public class HnswUtil {
// utility class; only has static methods
private HnswUtil() {}
/*
For each level, check rooted components from previous level nodes, which are entry
points with the goal that each node should be reachable from *some* entry point. For each entry
point, compute a spanning tree, recording the nodes in a single shared bitset.
Also record a bitset marking nodes that are not full to be used when reconnecting in order to
limit the search to include non-full nodes only.
*/
/**
 * Returns true if no level of the graph has an orphaned component, that is if on every level each
 * node is reachable from at least one of that level's entry points (the nodes of the next higher
 * level, or the graph's entry node on the top level).
 */
static boolean isRooted(HnswGraph knnValues) throws IOException {
  int level = 0;
  while (level < knnValues.numLevels()) {
    List<Component> found = components(knnValues, level, null, 0);
    if (found.size() > 1) {
      return false;
    }
    level++;
  }
  return true;
}
/**
 * Returns the sizes of the distinct graph components on level 0, by delegating to {@link
 * #componentSizes(HnswGraph, int)} with level 0. If the graph is fully-rooted the list will have
 * one entry. The intended contract is that an empty graph yields an empty list.
 * NOTE(review): confirm that components() actually honors the empty-graph case.
 */
static List<Integer> componentSizes(HnswGraph hnsw) throws IOException {
return componentSizes(hnsw, 0);
}
/**
 * Returns the size of each component found on the given level, in the order the components were
 * discovered. The forest reachable from the level's entry points (the nodes on the next higher
 * level) counts as a single component and comes first; when every node on the level is reachable
 * from some entry point the list therefore has exactly one entry. If the graph is empty, the
 * returned list is meant to be empty.
 */
static List<Integer> componentSizes(HnswGraph hnsw, int level) throws IOException {
  List<Component> found = components(hnsw, level, null, 0);
  return found.stream().map(component -> component.size()).toList();
}
/**
 * Finds the components of one graph level.
 *
 * <p>The first component returned is the "rooted" one: everything reachable from the level's
 * entry points, which are the nodes on the next higher level, or the graph's entry node when
 * {@code level} is the top level. Its {@code start} is the lowest-numbered node recorded in
 * {@code notFullyConnected} so far (NO_MORE_DOCS if there is none), or, when {@code
 * notFullyConnected} is null, the lowest-numbered connected node. Every further component is an
 * orphan: the set of not-yet-visited nodes reachable from the first unvisited node encountered,
 * which is recorded as its {@code start}.
 *
 * @param hnsw the graph to analyze
 * @param level the level to analyze
 * @param notFullyConnected if non-null, receives a bit for every visited node having fewer than
 *     {@code maxConn} neighbors
 * @param maxConn the neighbor count at which a node is considered full; only used when {@code
 *     notFullyConnected} is non-null
 * @return the components of the level, rooted component first; an empty list for an empty graph
 * @throws IllegalArgumentException if the graph is non-empty and {@code level} is not less than
 *     its number of levels
 */
static List<Component> components(
    HnswGraph hnsw, int level, FixedBitSet notFullyConnected, int maxConn) throws IOException {
  List<Component> components = new ArrayList<>();
  if (hnsw.size() == 0) {
    // An empty graph has no nodes to traverse and therefore no components; returning early also
    // avoids probing zero-length bitsets below.
    return components;
  }
  if (level >= hnsw.numLevels()) {
    throw new IllegalArgumentException(
        "Level " + level + " too large for graph with " + hnsw.numLevels() + " levels");
  }
  FixedBitSet connectedNodes = new FixedBitSet(hnsw.size());
  assert hnsw.size() == hnsw.getNodesOnLevel(0).size();
  int total = 0;
  HnswGraph.NodesIterator entryPoints;
  if (level == hnsw.numLevels() - 1) {
    entryPoints = new HnswGraph.ArrayNodesIterator(new int[] {hnsw.entryNode()}, 1);
  } else {
    entryPoints = hnsw.getNodesOnLevel(level + 1);
  }
  // Everything reachable from any entry point is folded into one rooted component.
  while (entryPoints.hasNext()) {
    int entryPoint = entryPoints.nextInt();
    Component component =
        markRooted(hnsw, level, connectedNodes, notFullyConnected, maxConn, entryPoint);
    total += component.size();
  }
  int entryPoint;
  if (notFullyConnected != null) {
    entryPoint = notFullyConnected.nextSetBit(0);
  } else {
    entryPoint = connectedNodes.nextSetBit(0);
  }
  components.add(new Component(entryPoint, total));
  if (level == 0) {
    // Every node id exists on level 0, so orphans are simply the remaining clear bits.
    int nextClear = nextClearBit(connectedNodes, 0);
    while (nextClear != NO_MORE_DOCS) {
      Component component =
          markRooted(hnsw, level, connectedNodes, notFullyConnected, maxConn, nextClear);
      assert component.size() > 0;
      components.add(component);
      total += component.size();
      nextClear = nextClearBit(connectedNodes, component.start());
    }
  } else {
    // Upper levels are sparse: only nodes actually present on the level can be orphans.
    HnswGraph.NodesIterator nodes = hnsw.getNodesOnLevel(level);
    while (nodes.hasNext()) {
      int nextClear = nodes.nextInt();
      if (connectedNodes.get(nextClear)) {
        continue;
      }
      Component component =
          markRooted(hnsw, level, connectedNodes, notFullyConnected, maxConn, nextClear);
      assert component.size() > 0;
      components.add(component);
      total += component.size();
    }
  }
  assert total == hnsw.getNodesOnLevel(level).size()
      : "total="
          + total
          + " level nodes on level "
          + level
          + " = "
          + hnsw.getNodesOnLevel(level).size();
  return components;
}
/**
 * Traverses the graph level from {@code entryPoint}, following neighbor links, and marks every
 * node reached for the first time in {@code connectedNodes}. Nodes already marked on entry are
 * neither counted nor expanded.
 *
 * @param hnswGraph the graph to check
 * @param level the level of the graph to check
 * @param connectedNodes a bitset the size of the entire graph with 1's indicating nodes that have
 *     been marked as connected. This method updates the bitset.
 * @param notFullyConnected a bitset the size of the entire graph. On output, nodes visited by
 *     this call that have fewer than maxConn connections are marked. May be null.
 * @param maxConn the maximum number of connections for any node (aka M).
 * @param entryPoint a node id to start at
 * @return a component whose start is {@code entryPoint} and whose size is the number of nodes
 *     newly marked by this call
 */
private static Component markRooted(
    HnswGraph hnswGraph,
    int level,
    FixedBitSet connectedNodes,
    FixedBitSet notFullyConnected,
    int maxConn,
    int entryPoint)
    throws IOException {
  Deque<Integer> pending = new ArrayDeque<>();
  pending.push(entryPoint);
  int visited = 0;
  while (pending.isEmpty() == false) {
    int current = pending.pop();
    if (connectedNodes.get(current) == false) {
      visited++;
      connectedNodes.set(current);
      hnswGraph.seek(level, current);
      int neighborCount = 0;
      for (int neighbor = hnswGraph.nextNeighbor();
          neighbor != NO_MORE_DOCS;
          neighbor = hnswGraph.nextNeighbor()) {
        neighborCount++;
        pending.push(neighbor);
      }
      if (notFullyConnected != null && neighborCount < maxConn) {
        notFullyConnected.set(current);
      }
    }
  }
  return new Component(entryPoint, visited);
}
/**
 * Returns the index of the first clear bit at or after {@code index}, or NO_MORE_DOCS if every
 * bit from {@code index} up to the end of the set is set. Works on the raw backing words and does
 * not depend on the ghost bits (those beyond {@code bits.length()}) being clear.
 */
private static int nextClearBit(FixedBitSet bits, int index) {
  final long[] words = bits.getBits();
  assert index >= 0 && index < bits.length() : "index=" + index + ", numBits=" + bits.length();
  int wordIndex = index >> 6;
  // skip all the bits to the right of index, then invert so clear bits become set bits
  long inverted = ~(words[wordIndex] >> index);
  int candidate;
  if (inverted != 0) {
    candidate = index + Long.numberOfTrailingZeros(inverted);
  } else {
    candidate = NO_MORE_DOCS;
    for (wordIndex++; wordIndex < words.length; wordIndex++) {
      inverted = ~words[wordIndex];
      if (inverted != 0) {
        candidate = (wordIndex << 6) + Long.numberOfTrailingZeros(inverted);
        break;
      }
    }
  }
  // a hit at or beyond length() is a ghost bit, not a real clear bit
  return candidate >= bits.length() ? NO_MORE_DOCS : candidate;
}
/**
 * Returns true if the HNSW graph of {@code vectorField} is rooted in every leaf of {@code
 * reader}.
 *
 * <p>In graph theory, "connected components" are really defined only for undirected (ie
 * bidirectional) graphs. Our graphs are directed, because of pruning, but they are *mostly*
 * undirected. Here components are computed by traversal from entry points, so what is really
 * measured is whether the graph is a "rooted graph". TODO: measure whether the graph is "strongly
 * connected" ie there is a path from every node to every other node.
 */
public static boolean graphIsRooted(IndexReader reader, String vectorField) throws IOException {
  for (LeafReaderContext leaf : reader.leaves()) {
    CodecReader codecReader = (CodecReader) FilterLeafReader.unwrap(leaf.reader());
    PerFieldKnnVectorsFormat.FieldsReader perFieldReader =
        (PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader();
    HnswGraphProvider graphProvider =
        (HnswGraphProvider) perFieldReader.getFieldReader(vectorField);
    HnswGraph graph = graphProvider.getGraph(vectorField);
    if (!isRooted(graph)) {
      return false;
    }
  }
  return true;
}
/**
 * A component (also "connected component") of an undirected graph is a collection of nodes that
 * are connected by neighbor links: every node in a connected component is reachable from every
 * other node in the component. See https://en.wikipedia.org/wiki/Component_(graph_theory). Such a
 * graph is said to be "fully connected" <i>iff</i> it has a single component, or it is empty.
 *
 * <p>As produced in this class, a component is the set of previously unvisited nodes reached by a
 * traversal over directed neighbor links.
 *
 * @param start for an orphaned component, the node at which its traversal started. For the
 *     rooted (first) component returned by {@code components}, the lowest-numbered node recorded
 *     as not fully connected, or the lowest-numbered connected node when no such bitset was
 *     supplied; this may be NO_MORE_DOCS when no such node exists.
 * @param size the number of nodes in the component
 */
record Component(int start, int size) {}
}

View File

@ -90,7 +90,10 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
*/
public NeighborArray getNeighbors(int level, int node) {
assert graph[node][level] != null;
assert node < graph.length;
assert level < graph[node].length
: "level=" + level + ", node has only " + graph[node].length + " levels";
assert graph[node][level] != null : "node=" + node + ", level=" + level;
return graph[node][level];
}

View File

@ -41,7 +41,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.hnsw.HnswTestUtil;
import org.apache.lucene.util.hnsw.HnswUtil;
@LuceneTestCase.SuppressCodecs("SimpleText")
abstract class BaseVectorSimilarityQueryTestCase<
@ -135,7 +135,7 @@ abstract class BaseVectorSimilarityQueryTestCase<
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField));
assumeTrue("graph is disconnected", HnswUtil.graphIsRooted(reader, vectorField));
// All vectors are above -Infinity
Query query1 =
@ -171,7 +171,7 @@ abstract class BaseVectorSimilarityQueryTestCase<
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
IndexReader reader = DirectoryReader.open(indexStore)) {
assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField));
assumeTrue("graph is disconnected", HnswUtil.graphIsRooted(reader, vectorField));
IndexSearcher searcher = newSearcher(reader);
Query query =
@ -296,7 +296,7 @@ abstract class BaseVectorSimilarityQueryTestCase<
w.commit();
try (IndexReader reader = DirectoryReader.open(indexStore)) {
assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField));
assumeTrue("graph is disconnected", HnswUtil.graphIsRooted(reader, vectorField));
IndexSearcher searcher = newSearcher(reader);
Query query =

View File

@ -0,0 +1,312 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.FixedBitSet;
public class TestHnswUtil extends LuceneTestCase {
// Verifies that a tree with one extra back edge is reported as rooted with a single component.
public void testTreeWithCycle() throws Exception {
// A binary tree rooted at node 0 with one extra edge from leaf 6 back to the root, which
// creates a cycle. Every node is reachable from node 0 (the entry node), so the graph is
// rooted, although it is not strongly connected (e.g. nothing is reachable from node 3).
int[][][] nodes = {
{
{1, 2}, // node 0
{3, 4}, // node 1
{5, 6}, // node 2
{}, {}, {}, {0}
}
};
HnswGraph graph = new MockGraph(nodes);
assertTrue(HnswUtil.isRooted(graph));
assertEquals(List.of(7), HnswUtil.componentSizes(graph));
}
// Verifies that nodes with only outgoing links into the rooted component count as orphans.
public void testBackLinking() throws Exception {
// Nodes 5 and 6 link *into* the part of the graph reachable from node 0, but nothing links
// back to them, so they are unreachable from the entry node and the graph is not rooted.
int[][][] nodes = {
{
{1, 2}, // node 0
{3, 4}, // node 1
{0}, // node 2
{1}, {1}, {1}, {1}
}
};
HnswGraph graph = new MockGraph(nodes);
assertFalse(HnswUtil.isRooted(graph));
// components: [ {0, 1, 2, 3, 4}, {5}, {6} ]
assertEquals(List.of(5, 1, 1), HnswUtil.componentSizes(graph));
}
public void testChain() throws Exception {
  // A single directed cycle 0 -> 1 -> 2 -> 3 -> 0: the graph is rooted from every node, thus
  // strongly connected, and forms one component of size 4.
  int[][] levelZero = {{1}, {2}, {3}, {0}};
  HnswGraph chain = new MockGraph(new int[][][] {levelZero});
  assertTrue(HnswUtil.isRooted(chain));
  assertEquals(List.of(4), HnswUtil.componentSizes(chain));
}
public void testTwoChains() throws Exception {
  // Two disjoint cycles, 0 <-> 2 and 1 <-> 3: only the first is reachable from node 0, so the
  // graph is not rooted and splits into two components of size 2.
  int[][] levelZero = {{2}, {3}, {0}, {1}};
  HnswGraph twoChains = new MockGraph(new int[][][] {levelZero});
  assertFalse(HnswUtil.isRooted(twoChains));
  assertEquals(List.of(2, 2), HnswUtil.componentSizes(twoChains));
}
public void testLevels() throws Exception {
  // A graph with three levels in which every level is reachable from its entry points:
  // level 2 holds only node 0, level 1 holds nodes 0 and 2, level 0 holds all four nodes.
  int[][] level0 = {{1, 2}, {3}, {0}, {0}};
  int[][] level1 = {{2}, null, {0}, null};
  int[][] level2 = {{}, null, null, null};
  HnswGraph threeLevels = new MockGraph(new int[][][] {level0, level1, level2});
  assertTrue(HnswUtil.isRooted(threeLevels));
  assertEquals(List.of(4), HnswUtil.componentSizes(threeLevels));
}
public void testLevelsNotRooted() throws Exception {
  // Two levels with an orphan on level 0: node 2 links to node 0 but is not reachable from
  // the only level-1 node (node 0), which reaches just nodes 0 and 1.
  int[][] level0 = {{1}, {0}, {0}};
  int[][] level1 = {{}, null, null};
  HnswGraph orphaned = new MockGraph(new int[][][] {level0, level1});
  assertFalse(HnswUtil.isRooted(orphaned));
  assertEquals(List.of(2, 1), HnswUtil.componentSizes(orphaned));
}
/**
 * Compares {@code HnswUtil.isRooted} against the brute-force {@code isRooted(int[][][])} on
 * randomly generated directed, multi-level graphs.
 */
public void testRandom() throws Exception {
  // atLeast() is randomized; evaluate it once so the iteration count is fixed for the run
  // instead of being re-drawn on every evaluation of the loop condition
  final int iters = atLeast(10);
  for (int i = 0; i < iters; i++) {
    // test on a random directed graph comparing against a brute force algorithm
    int numNodes = random().nextInt(1, 100);
    // always build at least level 0: ceil(log(1)) is 0, which would produce a graph with no
    // levels at all and make the iteration vacuous
    int numLevels = Math.max(1, (int) Math.ceil(Math.log(numNodes)));
    int[][][] nodes = new int[numLevels][][];
    for (int level = numLevels - 1; level >= 0; level--) {
      nodes[level] = new int[numNodes][];
      for (int node = 0; node < numNodes; node++) {
        if (level > 0) {
          if ((level == numLevels - 1 && node > 0)
              || (level < numLevels - 1 && nodes[level + 1][node] == null)) {
            if (random().nextFloat() > Math.pow(Math.E, -level)) {
              // skip some nodes, more on higher levels while ensuring every node present on a
              // given level is present on all lower levels. Also ensure node 0 is always present.
              continue;
            }
          }
        }
        int numNbrs = random().nextInt((numNodes + 7) / 8);
        if (level == 0) {
          numNbrs *= 2;
        }
        nodes[level][node] = new int[numNbrs];
        for (int nbr = 0; nbr < numNbrs; nbr++) {
          // always terminates: the node itself is already present on this level
          while (true) {
            int randomNbr = random().nextInt(numNodes);
            if (nodes[level][randomNbr] != null) {
              // allow self-linking; this doesn't arise in HNSW but it's valid more generally
              nodes[level][node][nbr] = randomNbr;
              break;
            }
            // nbr not on this level, try again
          }
        }
      }
    }
    MockGraph graph = new MockGraph(nodes);
    assertEquals(isRooted(nodes), HnswUtil.isRooted(graph));
  }
}
// Brute-force reference: the graph is rooted iff every level is rooted. Levels are checked
// from the top level down, stopping at the first level that is not rooted.
private boolean isRooted(int[][][] nodes) {
  int level = nodes.length - 1;
  while (level >= 0) {
    if (!isRooted(nodes, level)) {
      return false;
    }
    level--;
  }
  return true;
}
/**
 * Brute-force reference for a single level: returns true if every node present on {@code level}
 * is reachable from one of the level's entry points.
 *
 * <p>Entry points mirror what {@code HnswUtil.components} uses: on the top level only the graph's
 * entry node (fixed at node 0 by {@code MockGraph.entryNode()}), on every other level all nodes
 * present on the next higher level. Treating every top-level node as an entry point would make
 * the top level trivially rooted and disagree with the implementation under test whenever the top
 * level contains a node that node 0 cannot reach.
 */
private boolean isRooted(int[][][] nodes, int level) {
  // check that the graph is rooted in the union of the entry nodes' trees
  final boolean topLevel = level == nodes.length - 1;
  final int[][] entryPointCandidates = topLevel ? nodes[level] : nodes[level + 1];
  // on the top level only node 0, the entry node, may be used as an entry point
  final int numCandidates =
      topLevel ? Math.min(1, entryPointCandidates.length) : entryPointCandidates.length;
  FixedBitSet connected = new FixedBitSet(nodes[level].length);
  int count = 0;
  for (int entryPoint = 0; entryPoint < numCandidates; entryPoint++) {
    if (entryPointCandidates[entryPoint] == null) {
      // not present on the entry point level
      continue;
    }
    ArrayDeque<Integer> stack = new ArrayDeque<>();
    stack.push(entryPoint);
    while (!stack.isEmpty()) {
      int node = stack.pop();
      if (connected.get(node)) {
        continue;
      }
      connected.set(node);
      count++;
      for (int nbr : nodes[level][node]) {
        stack.push(nbr);
      }
    }
  }
  return count == levelSize(nodes[level]);
}
// Counts the nodes present on a level; a null entry means the node is absent from that level.
static int levelSize(int[][] nodes) {
  return (int) Arrays.stream(nodes).filter(node -> node != null).count();
}
/**
 * An in-memory {@link HnswGraph} for tests, backed by a {@code nodes[level][node][neighbor]}
 * array. A null {@code nodes[level][node]} means the node is not present on that level. The entry
 * node is always node 0.
 */
static class MockGraph extends HnswGraph {
private final int[][][] nodes;
// cursor state set by seek() and advanced by nextNeighbor()
private int currentLevel;
private int currentNode;
private int currentNeighbor;
MockGraph(int[][][] nodes) {
this.nodes = nodes;
}
// Returns the next neighbor of the node last passed to seek(), or NO_MORE_DOCS when exhausted.
@Override
public int nextNeighbor() {
if (currentNeighbor >= nodes[currentLevel][currentNode].length) {
return NO_MORE_DOCS;
} else {
return nodes[currentLevel][currentNode][currentNeighbor++];
}
}
// Positions the neighbor cursor at the first neighbor of target on the given level; target
// must be present on that level.
@Override
public void seek(int level, int target) {
assert level >= 0 && level < nodes.length;
assert target >= 0 && target < nodes[level].length
: "target out of range: "
+ target
+ " for level "
+ level
+ "; should be less than "
+ nodes[level].length;
assert nodes[level][target] != null : "target " + target + " not on level " + level;
currentLevel = level;
currentNode = target;
currentNeighbor = 0;
}
// The graph size is the number of node slots on level 0.
@Override
public int size() {
return nodes[0].length;
}
@Override
public int numLevels() {
return nodes.length;
}
// The entry node is fixed at node 0.
@Override
public int entryNode() {
return 0;
}
// Renders the graph from the top level down, listing each present node with its neighbors.
@Override
public String toString() {
StringBuilder buf = new StringBuilder();
for (int level = nodes.length - 1; level >= 0; level--) {
buf.append("\nLEVEL ").append(level).append("\n");
for (int node = 0; node < nodes[level].length; node++) {
if (nodes[level][node] != null) {
buf.append(" ")
.append(node)
.append(':')
.append(Arrays.toString(nodes[level][node]))
.append("\n");
}
}
}
return buf.toString();
}
// Iterates, in increasing node order, over the nodes present (non-null) on the given level.
@Override
public NodesIterator getNodesOnLevel(int level) {
// count the present nodes up front; the iterator is sized with this count
int count = 0;
for (int i = 0; i < nodes[level].length; i++) {
if (nodes[level][i] != null) {
count++;
}
}
final int finalCount = count;
return new NodesIterator(finalCount) {
// cur is the last array slot examined; curCount is the number of nodes returned so far
int cur = -1;
int curCount = 0;
@Override
public boolean hasNext() {
return curCount < finalCount;
}
@Override
public int nextInt() {
// advance past absent (null) slots to the next present node
while (curCount < finalCount) {
if (nodes[level][++cur] != null) {
curCount++;
return cur;
}
}
throw new IllegalStateException("exhausted");
}
@Override
public int consume(int[] dest) {
throw new UnsupportedOperationException();
}
};
}
}
}

View File

@ -49,7 +49,6 @@ module org.apache.lucene.test_framework {
exports org.apache.lucene.tests.store;
exports org.apache.lucene.tests.util.automaton;
exports org.apache.lucene.tests.util.fst;
exports org.apache.lucene.tests.util.hnsw;
exports org.apache.lucene.tests.util;
provides org.apache.lucene.codecs.Codec with

View File

@ -1,132 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.tests.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.hnsw.HnswGraph;
/** Utilities for use in tests involving HNSW graphs */
public class HnswTestUtil {
/**
 * Returns true iff level 0 of the graph forms at most one component as computed by {@link
 * #componentSizes(HnswGraph)}, i.e. every node on level 0 is reachable from the node at which the
 * first traversal starts. An empty component list also counts as fully connected.
 */
public static boolean isFullyConnected(HnswGraph knnValues) throws IOException {
return componentSizes(knnValues).size() < 2;
}
/**
 * Returns the sizes of the distinct graph components on level 0, in discovery order. A
 * fully-connected graph yields a single entry; an empty graph yields an empty list.
 */
public static List<Integer> componentSizes(HnswGraph hnsw) throws IOException {
  List<Integer> sizes = new ArrayList<>();
  FixedBitSet visited = new FixedBitSet(hnsw.size());
  assert hnsw.size() == hnsw.getNodesOnLevel(0).size();
  // keep traversing from the next unvisited node until every node has been accounted for
  for (int covered = 0; covered < visited.length(); ) {
    int componentSize = traverseConnectedNodes(hnsw, visited);
    assert componentSize > 0;
    sizes.add(componentSize);
    covered += componentSize;
  }
  return sizes;
}
// Starting from the lowest-numbered node not yet marked in connectedNodes, traverses level 0
// along neighbor links, marks every newly reached node in connectedNodes and returns how many
// nodes were newly marked (0 if no unmarked node remains).
private static int traverseConnectedNodes(HnswGraph hnswGraph, FixedBitSet connectedNodes)
    throws IOException {
  int entryPoint = nextClearBit(connectedNodes, 0);
  if (entryPoint == NO_MORE_DOCS) {
    return 0;
  }
  Deque<Integer> pending = new ArrayDeque<>();
  pending.push(entryPoint);
  int visited = 0;
  while (pending.isEmpty() == false) {
    int current = pending.pop();
    if (connectedNodes.get(current) == false) {
      visited++;
      connectedNodes.set(current);
      hnswGraph.seek(0, current);
      for (int neighbor = hnswGraph.nextNeighbor();
          neighbor != NO_MORE_DOCS;
          neighbor = hnswGraph.nextNeighbor()) {
        pending.push(neighbor);
      }
    }
  }
  return visited;
}
/**
 * Returns the index of the first clear bit at or after {@code index}, or NO_MORE_DOCS if there is
 * none below {@code bits.length()}.
 *
 * <p>Does not depend on the ghost bits (those beyond {@code bits.length()} in the backing words)
 * being clear: a hit at or beyond the logical length is reported as NO_MORE_DOCS, whether it was
 * found in the word containing {@code index} or in a later word.
 */
private static int nextClearBit(FixedBitSet bits, int index) {
  long[] barray = bits.getBits();
  assert index >= 0 && index < bits.length() : "index=" + index + ", numBits=" + bits.length();
  int i = index >> 6;
  long word = ~(barray[i] >> index); // skip all the bits to the right of index
  int next = NO_MORE_DOCS;
  if (word != 0) {
    next = index + Long.numberOfTrailingZeros(word);
  } else {
    while (++i < barray.length) {
      word = ~barray[i];
      if (word != 0) {
        next = (i << 6) + Long.numberOfTrailingZeros(word);
        break;
      }
    }
  }
  // applies to both branches: a clear ghost bit in the first word must not be returned either
  return next >= bits.length() ? NO_MORE_DOCS : next;
}
// Returns true if, in every leaf of the reader, level 0 of the HNSW graph for the given vector
// field is fully connected.
public static boolean graphIsConnected(IndexReader reader, String vectorField)
    throws IOException {
  for (LeafReaderContext leaf : reader.leaves()) {
    CodecReader codecReader = (CodecReader) FilterLeafReader.unwrap(leaf.reader());
    PerFieldKnnVectorsFormat.FieldsReader perFieldReader =
        (PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader();
    HnswGraphProvider graphProvider =
        (HnswGraphProvider) perFieldReader.getFieldReader(vectorField);
    HnswGraph graph = graphProvider.getGraph(vectorField);
    if (!isFullyConnected(graph)) {
      return false;
    }
  }
  return true;
}
}