mirror of https://github.com/apache/lucene.git
gh-12627: HnswGraphBuilder connects disconnected HNSW graph components (#13566)
* gh-12627: HnswGraphBuilder connects disconnected HNSW graph components
parent d26b152117
commit 217828736c
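At a high level, the patch makes graph construction end with a repair pass: on each level, find the components that are not reachable from the entry points and attach each of them to the largest component. A toy sketch of that idea on a plain adjacency list (illustration only, not part of the patch; the real change below works on HnswGraph levels, scores candidate edges with the vector scorer, and only links through nodes that still have room for another connection):

  // Toy illustration: DFS the components of an undirected adjacency list,
  // then bridge every smaller component to the largest one.
  import java.util.*;

  class ConnectComponentsSketch {
    static void connect(List<List<Integer>> adj) {
      boolean[] seen = new boolean[adj.size()];
      List<List<Integer>> components = new ArrayList<>();
      for (int start = 0; start < adj.size(); start++) {
        if (seen[start]) continue;
        List<Integer> component = new ArrayList<>();
        Deque<Integer> stack = new ArrayDeque<>(List.of(start));
        while (!stack.isEmpty()) {
          int node = stack.pop();
          if (seen[node]) continue;
          seen[node] = true;
          component.add(node);
          stack.addAll(adj.get(node));
        }
        components.add(component);
      }
      // attach every other component to the largest one with a single bridging edge
      List<Integer> largest = Collections.max(components, Comparator.comparingInt(List::size));
      for (List<Integer> component : components) {
        if (component != largest) {
          int a = largest.get(0);
          int b = component.get(0);
          adj.get(a).add(b); // the patch instead searches for the closest not-yet-full node
          adj.get(b).add(a); // and respects the per-node connection limit
        }
      }
    }
  }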
@@ -463,6 +463,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
       // unsafe; no bounds checking
       dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
       arcCount = dataIn.readVInt();
+      assert arcCount <= currentNeighborsBuffer.length : "too many neighbors: " + arcCount;
       if (arcCount > 0) {
         currentNeighborsBuffer[0] = dataIn.readVInt();
         for (int i = 1; i < arcCount; i++) {
@@ -612,7 +612,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
       throw new UnsupportedOperationException();
     }

-    OnHeapHnswGraph getGraph() {
+    OnHeapHnswGraph getGraph() throws IOException {
       assert flatFieldVectorsWriter.isFinished();
       if (node > 0) {
         return hnswGraphBuilder.getCompletedGraph();
@@ -48,5 +48,5 @@ public interface HnswBuilder {
    * components, re-ordering node ids for better delta compression) may be triggered, so callers
    * should expect this call to take some time.
    */
-  OnHeapHnswGraph getCompletedGraph();
+  OnHeapHnswGraph getCompletedGraph() throws IOException;
 }
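Since completing the graph can now trigger the component-connection pass (which searches the graph), the interface method declares a checked IOException. A minimal caller sketch, assuming an already-configured HnswBuilder named builder and a node count maxOrd (both placeholders, not part of the patch):

  // Sketch only: builder construction is omitted.
  OnHeapHnswGraph buildAndFreeze(HnswBuilder builder, int maxOrd) throws IOException {
    builder.build(maxOrd);              // insert all nodes, as before
    return builder.getCompletedGraph(); // may connect stray components before returning
  }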
@@ -91,6 +91,7 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
           });
     }
     taskExecutor.invokeAll(futures);
+    finish();
     frozen = true;
     return workers[0].getCompletedGraph();
   }
@@ -109,11 +110,19 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
   }

   @Override
-  public OnHeapHnswGraph getCompletedGraph() {
-    frozen = true;
+  public OnHeapHnswGraph getCompletedGraph() throws IOException {
+    if (frozen == false) {
+      // should already have been called in build(), but just in case
+      finish();
+      frozen = true;
+    }
     return getGraph();
   }

+  private void finish() throws IOException {
+    workers[0].finish();
+  }
+
   @Override
   public OnHeapHnswGraph getGraph() {
     return workers[0].getGraph();
@@ -18,8 +18,11 @@
 package org.apache.lucene.util.hnsw;

 import static java.lang.Math.log;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

 import java.io.IOException;
+import java.util.Comparator;
+import java.util.List;
 import java.util.Locale;
 import java.util.Objects;
 import java.util.SplittableRandom;
@@ -28,6 +31,7 @@ import org.apache.lucene.search.KnnCollector;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.InfoStream;
+import org.apache.lucene.util.hnsw.HnswUtil.Component;

 /**
  * Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
@@ -137,7 +141,7 @@ public class HnswGraphBuilder implements HnswBuilder {
       HnswGraphSearcher graphSearcher)
       throws IOException {
     if (M <= 0) {
-      throw new IllegalArgumentException("maxConn must be positive");
+      throw new IllegalArgumentException("M (max connections) must be positive");
     }
     if (beamWidth <= 0) {
       throw new IllegalArgumentException("beamWidth must be positive");
@@ -173,8 +177,11 @@ public class HnswGraphBuilder implements HnswBuilder {
   }

   @Override
-  public OnHeapHnswGraph getCompletedGraph() {
-    frozen = true;
+  public OnHeapHnswGraph getCompletedGraph() throws IOException {
+    if (!frozen) {
+      finish();
+      frozen = true;
+    }
     return getGraph();
   }

@@ -405,6 +412,93 @@ public class HnswGraphBuilder implements HnswBuilder {
     return ((int) (-log(randDouble) * ml));
   }

+  void finish() throws IOException {
+    connectComponents();
+  }
+
+  private void connectComponents() throws IOException {
+    long start = System.nanoTime();
+    for (int level = 0; level < hnsw.numLevels(); level++) {
+      if (connectComponents(level) == false) {
+        if (infoStream.isEnabled(HNSW_COMPONENT)) {
+          infoStream.message(HNSW_COMPONENT, "connectComponents failed on level " + level);
+        }
+      }
+    }
+    if (infoStream.isEnabled(HNSW_COMPONENT)) {
+      infoStream.message(
+          HNSW_COMPONENT, "connectComponents " + (System.nanoTime() - start) / 1_000_000 + " ms");
+    }
+  }
+
+  private boolean connectComponents(int level) throws IOException {
+    FixedBitSet notFullyConnected = new FixedBitSet(hnsw.size());
+    int maxConn = M;
+    if (level == 0) {
+      maxConn *= 2;
+    }
+    List<Component> components = HnswUtil.components(hnsw, level, notFullyConnected, maxConn);
+    boolean result = true;
+    if (components.size() > 1) {
+      // connect other components to the largest one
+      Component c0 = components.stream().max(Comparator.comparingInt(Component::size)).get();
+      if (c0.start() == NO_MORE_DOCS) {
+        // the component is already fully connected - no room for new connections
+        return false;
+      }
+      // try for more connections? We only do one since otherwise they may become full
+      // while linking
+      GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(1);
+      int[] eps = new int[1];
+      for (Component c : components) {
+        if (c != c0) {
+          beam.clear();
+          eps[0] = c0.start();
+          RandomVectorScorer scorer = scorerSupplier.scorer(c.start());
+          // find the closest node in the largest component to the lowest-numbered node in this
+          // component that has room to make a connection
+          graphSearcher.searchLevel(beam, scorer, 0, eps, hnsw, notFullyConnected);
+          boolean linked = false;
+          while (beam.size() > 0) {
+            float score = beam.minimumScore();
+            int c0node = beam.popNode();
+            assert notFullyConnected.get(c0node);
+            // link the nodes
+            link(level, c0node, c.start(), score, notFullyConnected);
+            linked = true;
+          }
+          if (!linked) {
+            result = false;
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  // Try to link two nodes bidirectionally; the forward connection will always be made.
+  // Update notFullyConnected.
+  private void link(int level, int n0, int n1, float score, FixedBitSet notFullyConnected) {
+    NeighborArray nbr0 = hnsw.getNeighbors(level, n0);
+    NeighborArray nbr1 = hnsw.getNeighbors(level, n1);
+    // must subtract 1 here since the nodes array is one larger than the configured
+    // max neighbors (M / 2M).
+    // We should have taken care of this check by searching for not-full nodes
+    int maxConn = nbr0.nodes().length - 1;
+    assert notFullyConnected.get(n0);
+    assert nbr0.size() < maxConn : "node " + n0 + " is full, has " + nbr0.size() + " friends";
+    nbr0.addOutOfOrder(n1, score);
+    if (nbr0.size() == maxConn) {
+      notFullyConnected.clear(n0);
+    }
+    if (nbr1.size() < maxConn) {
+      nbr1.addOutOfOrder(n0, score);
+      if (nbr1.size() == maxConn) {
+        notFullyConnected.clear(n1);
+      }
+    }
+  }
+
   /**
    * A restricted, specialized knnCollector that can be used when building a graph.
    *
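Failures and timing of this repair pass are reported through the builder's existing InfoStream under the HNSW_COMPONENT category. A small sketch of making those messages visible when driving a builder directly (imports omitted; PrintStreamInfoStream is the stock stdout implementation; the builder and maxOrd are assumed to be configured elsewhere):

  // Sketch: surface the connectComponents log lines while experimenting.
  static OnHeapHnswGraph buildWithLogging(HnswGraphBuilder builder, int maxOrd) throws IOException {
    builder.setInfoStream(new PrintStreamInfoStream(System.out));
    builder.build(maxOrd);
    // completing the graph runs connectComponents(), whose timing is logged
    return builder.getCompletedGraph();
  }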
@@ -0,0 +1,248 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.util.hnsw;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.FixedBitSet;

/** Utilities for use in tests involving HNSW graphs */
public class HnswUtil {

  // utility class; only has static methods
  private HnswUtil() {}

  /*
   For each level, check rooted components from previous level nodes, which are entry
   points with the goal that each node should be reachable from *some* entry point. For each entry
   point, compute a spanning tree, recording the nodes in a single shared bitset.

   Also record a bitset marking nodes that are not full to be used when reconnecting in order to
   limit the search to include non-full nodes only.
  */

  /** Returns true if every node on every level is reachable from node 0. */
  static boolean isRooted(HnswGraph knnValues) throws IOException {
    for (int level = 0; level < knnValues.numLevels(); level++) {
      if (components(knnValues, level, null, 0).size() > 1) {
        return false;
      }
    }
    return true;
  }

  /**
   * Returns the sizes of the distinct graph components on level 0. If the graph is fully-rooted the
   * list will have one entry. If it is empty, the returned list will be empty.
   */
  static List<Integer> componentSizes(HnswGraph hnsw) throws IOException {
    return componentSizes(hnsw, 0);
  }

  /**
   * Returns the sizes of the distinct graph components on the given level. The forest starting at
   * the entry points (nodes in the next highest level) is considered as a single component. If the
   * entire graph is rooted in the entry points, that is every node is reachable from at least one
   * entry point, the returned list will have a single entry. If the graph is empty, the returned
   * list will be empty.
   */
  static List<Integer> componentSizes(HnswGraph hnsw, int level) throws IOException {
    return components(hnsw, level, null, 0).stream().map(Component::size).toList();
  }

  // Finds orphaned components on the graph level.
  static List<Component> components(
      HnswGraph hnsw, int level, FixedBitSet notFullyConnected, int maxConn) throws IOException {
    List<Component> components = new ArrayList<>();
    FixedBitSet connectedNodes = new FixedBitSet(hnsw.size());
    assert hnsw.size() == hnsw.getNodesOnLevel(0).size();
    int total = 0;
    if (level >= hnsw.numLevels()) {
      throw new IllegalArgumentException(
          "Level " + level + " too large for graph with " + hnsw.numLevels() + " levels");
    }
    HnswGraph.NodesIterator entryPoints;
    // System.out.println("components level=" + level);
    if (level == hnsw.numLevels() - 1) {
      entryPoints = new HnswGraph.ArrayNodesIterator(new int[] {hnsw.entryNode()}, 1);
    } else {
      entryPoints = hnsw.getNodesOnLevel(level + 1);
    }
    while (entryPoints.hasNext()) {
      int entryPoint = entryPoints.nextInt();
      Component component =
          markRooted(hnsw, level, connectedNodes, notFullyConnected, maxConn, entryPoint);
      total += component.size();
    }
    int entryPoint;
    if (notFullyConnected != null) {
      entryPoint = notFullyConnected.nextSetBit(0);
    } else {
      entryPoint = connectedNodes.nextSetBit(0);
    }
    components.add(new Component(entryPoint, total));
    if (level == 0) {
      int nextClear = nextClearBit(connectedNodes, 0);
      while (nextClear != NO_MORE_DOCS) {
        Component component =
            markRooted(hnsw, level, connectedNodes, notFullyConnected, maxConn, nextClear);
        assert component.size() > 0;
        components.add(component);
        total += component.size();
        nextClear = nextClearBit(connectedNodes, component.start());
      }
    } else {
      HnswGraph.NodesIterator nodes = hnsw.getNodesOnLevel(level);
      while (nodes.hasNext()) {
        int nextClear = nodes.nextInt();
        if (connectedNodes.get(nextClear)) {
          continue;
        }
        Component component =
            markRooted(hnsw, level, connectedNodes, notFullyConnected, maxConn, nextClear);
        assert component.size() > 0;
        components.add(component);
        total += component.size();
      }
    }
    assert total == hnsw.getNodesOnLevel(level).size()
        : "total="
            + total
            + " level nodes on level "
            + level
            + " = "
            + hnsw.getNodesOnLevel(level).size();
    return components;
  }

  /**
   * Count the nodes in a rooted component of the graph and set the bits of its nodes in
   * connectedNodes bitset. Rooted means nodes that can be reached from a root node.
   *
   * @param hnswGraph the graph to check
   * @param level the level of the graph to check
   * @param connectedNodes a bitset the size of the entire graph with 1's indicating nodes that have
   *     been marked as connected. This method updates the bitset.
   * @param notFullyConnected a bitset the size of the entire graph. On output, we mark nodes
   *     visited having fewer than maxConn connections. May be null.
   * @param maxConn the maximum number of connections for any node (aka M).
   * @param entryPoint a node id to start at
   */
  private static Component markRooted(
      HnswGraph hnswGraph,
      int level,
      FixedBitSet connectedNodes,
      FixedBitSet notFullyConnected,
      int maxConn,
      int entryPoint)
      throws IOException {
    // Start at entry point and search all nodes on this level
    // System.out.println("markRooted level=" + level + " entryPoint=" + entryPoint);
    Deque<Integer> stack = new ArrayDeque<>();
    stack.push(entryPoint);
    int count = 0;
    while (!stack.isEmpty()) {
      int node = stack.pop();
      if (connectedNodes.get(node)) {
        continue;
      }
      count++;
      connectedNodes.set(node);
      hnswGraph.seek(level, node);
      int friendOrd;
      int friendCount = 0;
      while ((friendOrd = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) {
        ++friendCount;
        stack.push(friendOrd);
      }
      if (friendCount < maxConn && notFullyConnected != null) {
        notFullyConnected.set(node);
      }
    }
    return new Component(entryPoint, count);
  }

  private static int nextClearBit(FixedBitSet bits, int index) {
    // Does not depend on the ghost bits being clear!
    long[] barray = bits.getBits();
    assert index >= 0 && index < bits.length() : "index=" + index + ", numBits=" + bits.length();
    int i = index >> 6;
    long word = ~(barray[i] >> index); // skip all the bits to the right of index

    int next = NO_MORE_DOCS;
    if (word != 0) {
      next = index + Long.numberOfTrailingZeros(word);
    } else {
      while (++i < barray.length) {
        word = ~barray[i];
        if (word != 0) {
          next = (i << 6) + Long.numberOfTrailingZeros(word);
          break;
        }
      }
    }
    if (next >= bits.length()) {
      return NO_MORE_DOCS;
    } else {
      return next;
    }
  }

  /**
   * In graph theory, "connected components" are really defined only for undirected (ie
   * bidirectional) graphs. Our graphs are directed, because of pruning, but they are *mostly*
   * undirected. In this case we compute components starting from a single node so what we are
   * really measuring is whether the graph is a "rooted graph". TODO: measure whether the graph is
   * "strongly connected" ie there is a path from every node to every other node.
   */
  public static boolean graphIsRooted(IndexReader reader, String vectorField) throws IOException {
    for (LeafReaderContext ctx : reader.leaves()) {
      CodecReader codecReader = (CodecReader) FilterLeafReader.unwrap(ctx.reader());
      HnswGraph graph =
          ((HnswGraphProvider)
                  ((PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader())
                      .getFieldReader(vectorField))
              .getGraph(vectorField);
      if (isRooted(graph) == false) {
        return false;
      }
    }
    return true;
  }

  /**
   * A component (also "connected component") of an undirected graph is a collection of nodes that
   * are connected by neighbor links: every node in a connected component is reachable from every
   * other node in the component. See https://en.wikipedia.org/wiki/Component_(graph_theory). Such a
   * graph is said to be "fully connected" <i>iff</i> it has a single component, or it is empty.
   *
   * @param start the lowest-numbered node in the component
   * @param size the number of nodes in the component
   */
  record Component(int start, int size) {}
}
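The new HnswUtil serves both the builder (components) and tests (isRooted, componentSizes, graphIsRooted). A short test-side sketch of asserting full connectivity, assuming a test class living in org.apache.lucene.util.hnsw (the helpers below are package-private) and an OnHeapHnswGraph named hnsw produced by some builder; neither assumption is part of the patch:

  // Sketch: after this change, a freshly completed graph should form a single rooted component.
  void assertSingleComponent(HnswGraph hnsw) throws IOException {
    assertTrue(HnswUtil.isRooted(hnsw)); // every level reachable from its entry points
    assertEquals(List.of(hnsw.size()), HnswUtil.componentSizes(hnsw)); // one component spanning level 0
  }

For index-level checks from outside the package, graphIsRooted(IndexReader, String) is the public entry point, as the test changes further down show.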
@@ -90,7 +90,10 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
    * @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
    */
   public NeighborArray getNeighbors(int level, int node) {
-    assert graph[node][level] != null;
+    assert node < graph.length;
+    assert level < graph[node].length
+        : "level=" + level + ", node has only " + graph[node].length + " levels";
+    assert graph[node][level] != null : "node=" + node + ", level=" + level;
     return graph[node][level];
   }

@@ -41,7 +41,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.tests.util.LuceneTestCase;
-import org.apache.lucene.tests.util.hnsw.HnswTestUtil;
+import org.apache.lucene.util.hnsw.HnswUtil;

 @LuceneTestCase.SuppressCodecs("SimpleText")
 abstract class BaseVectorSimilarityQueryTestCase<
@@ -135,7 +135,7 @@ abstract class BaseVectorSimilarityQueryTestCase<
     try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
         IndexReader reader = DirectoryReader.open(indexStore)) {
       IndexSearcher searcher = newSearcher(reader);
-      assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField));
+      assumeTrue("graph is disconnected", HnswUtil.graphIsRooted(reader, vectorField));

       // All vectors are above -Infinity
       Query query1 =
@@ -171,7 +171,7 @@ abstract class BaseVectorSimilarityQueryTestCase<

     try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
         IndexReader reader = DirectoryReader.open(indexStore)) {
-      assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField));
+      assumeTrue("graph is disconnected", HnswUtil.graphIsRooted(reader, vectorField));
       IndexSearcher searcher = newSearcher(reader);

       Query query =
@@ -296,7 +296,7 @@ abstract class BaseVectorSimilarityQueryTestCase<
       w.commit();

       try (IndexReader reader = DirectoryReader.open(indexStore)) {
-        assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField));
+        assumeTrue("graph is disconnected", HnswUtil.graphIsRooted(reader, vectorField));
         IndexSearcher searcher = newSearcher(reader);

         Query query =
@@ -0,0 +1,312 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.lucene.util.hnsw;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.FixedBitSet;

public class TestHnswUtil extends LuceneTestCase {

  public void testTreeWithCycle() throws Exception {
    // test a graph that is a tree - this is rooted from its root node, not rooted
    // from any other node, and not strongly connected
    int[][][] nodes = {
      {
        {1, 2}, // node 0
        {3, 4}, // node 1
        {5, 6}, // node 2
        {}, {}, {}, {0}
      }
    };
    HnswGraph graph = new MockGraph(nodes);
    assertTrue(HnswUtil.isRooted(graph));
    assertEquals(List.of(7), HnswUtil.componentSizes(graph));
  }

  public void testBackLinking() throws Exception {
    // test a graph that is a tree - this is rooted from its root node, not rooted
    // from any other node, and not strongly connected
    int[][][] nodes = {
      {
        {1, 2}, // node 0
        {3, 4}, // node 1
        {0}, // node 2
        {1}, {1}, {1}, {1}
      }
    };
    HnswGraph graph = new MockGraph(nodes);
    assertFalse(HnswUtil.isRooted(graph));
    // [ {0, 1, 2, 3, 4}, {5}, {6}
    assertEquals(List.of(5, 1, 1), HnswUtil.componentSizes(graph));
  }

  public void testChain() throws Exception {
    // test a graph that is a chain - this is rooted from every node, thus strongly connected
    int[][][] nodes = {{{1}, {2}, {3}, {0}}};
    HnswGraph graph = new MockGraph(nodes);
    assertTrue(HnswUtil.isRooted(graph));
    assertEquals(List.of(4), HnswUtil.componentSizes(graph));
  }

  public void testTwoChains() throws Exception {
    // test a graph that is two chains
    int[][][] nodes = {{{2}, {3}, {0}, {1}}};
    HnswGraph graph = new MockGraph(nodes);
    assertFalse(HnswUtil.isRooted(graph));
    assertEquals(List.of(2, 2), HnswUtil.componentSizes(graph));
  }

  public void testLevels() throws Exception {
    // test a graph that has three levels
    int[][][] nodes = {
      {{1, 2}, {3}, {0}, {0}},
      {{2}, null, {0}, null},
      {{}, null, null, null}
    };
    HnswGraph graph = new MockGraph(nodes);
    // System.out.println(graph.toString());
    assertTrue(HnswUtil.isRooted(graph));
    assertEquals(List.of(4), HnswUtil.componentSizes(graph));
  }

  public void testLevelsNotRooted() throws Exception {
    // test a graph that has two levels with an orphaned node
    int[][][] nodes = {
      {{1}, {0}, {0}},
      {{}, null, null}
    };
    HnswGraph graph = new MockGraph(nodes);
    assertFalse(HnswUtil.isRooted(graph));
    assertEquals(List.of(2, 1), HnswUtil.componentSizes(graph));
  }

  public void testRandom() throws Exception {
    for (int i = 0; i < atLeast(10); i++) {
      // test on a random directed graph comparing against a brute force algorithm
      int numNodes = random().nextInt(1, 100);
      int numLevels = (int) Math.ceil(Math.log(numNodes));
      int[][][] nodes = new int[numLevels][][];
      for (int level = numLevels - 1; level >= 0; level--) {
        nodes[level] = new int[numNodes][];
        for (int node = 0; node < numNodes; node++) {
          if (level > 0) {
            if ((level == numLevels - 1 && node > 0)
                || (level < numLevels - 1 && nodes[level + 1][node] == null)) {
              if (random().nextFloat() > Math.pow(Math.E, -level)) {
                // skip some nodes, more on higher levels while ensuring every node present on a
                // given level is present on all lower levels. Also ensure node 0 is always present.
                continue;
              }
            }
          }
          int numNbrs = random().nextInt((numNodes + 7) / 8);
          if (level == 0) {
            numNbrs *= 2;
          }
          nodes[level][node] = new int[numNbrs];
          for (int nbr = 0; nbr < numNbrs; nbr++) {
            while (true) {
              int randomNbr = random().nextInt(numNodes);
              if (nodes[level][randomNbr] != null) {
                // allow self-linking; this doesn't arise in HNSW but it's valid more generally
                nodes[level][node][nbr] = randomNbr;
                break;
              }
              // nbr not on this level, try again
            }
          }
        }
      }
      MockGraph graph = new MockGraph(nodes);
      /*
      System.out.println("iter " + i);
      System.out.print(graph.toString());
      */
      assertEquals(isRooted(nodes), HnswUtil.isRooted(graph));
    }
  }

  private boolean isRooted(int[][][] nodes) {
    for (int level = nodes.length - 1; level >= 0; level--) {
      if (isRooted(nodes, level) == false) {
        return false;
      }
    }
    return true;
  }

  private boolean isRooted(int[][][] nodes, int level) {
    // check that the graph is rooted in the union of the entry nodes' trees
    // System.out.println("isRooted level=" + level);
    int entryPointLevel;
    if (level == nodes.length - 1) {
      entryPointLevel = level;
    } else {
      entryPointLevel = level + 1;
    }
    FixedBitSet connected = new FixedBitSet(nodes[level].length);
    int count = 0;
    for (int entryPoint = 0; entryPoint < nodes[entryPointLevel].length; entryPoint++) {
      if (nodes[entryPointLevel][entryPoint] == null) {
        // use nodes present on next higher level (or this level if top level) as entry points
        continue;
      }
      // System.out.println("  isRooted level=" + level + " entryPoint=" + entryPoint);
      ArrayDeque<Integer> stack = new ArrayDeque<>();
      stack.push(entryPoint);
      while (!stack.isEmpty()) {
        int node = stack.pop();
        if (connected.get(node)) {
          continue;
        }
        // System.out.println("    connected node=" + node);
        connected.set(node);
        count++;
        for (int nbr : nodes[level][node]) {
          stack.push(nbr);
        }
      }
    }
    return count == levelSize(nodes[level]);
  }

  static int levelSize(int[][] nodes) {
    int count = 0;
    for (int[] node : nodes) {
      if (node != null) {
        ++count;
      }
    }
    return count;
  }

  /** Empty graph value */
  static class MockGraph extends HnswGraph {

    private final int[][][] nodes;

    private int currentLevel;
    private int currentNode;
    private int currentNeighbor;

    MockGraph(int[][][] nodes) {
      this.nodes = nodes;
    }

    @Override
    public int nextNeighbor() {
      if (currentNeighbor >= nodes[currentLevel][currentNode].length) {
        return NO_MORE_DOCS;
      } else {
        return nodes[currentLevel][currentNode][currentNeighbor++];
      }
    }

    @Override
    public void seek(int level, int target) {
      assert level >= 0 && level < nodes.length;
      assert target >= 0 && target < nodes[level].length
          : "target out of range: "
              + target
              + " for level "
              + level
              + "; should be less than "
              + nodes[level].length;
      assert nodes[level][target] != null : "target " + target + " not on level " + level;
      currentLevel = level;
      currentNode = target;
      currentNeighbor = 0;
    }

    @Override
    public int size() {
      return nodes[0].length;
    }

    @Override
    public int numLevels() {
      return nodes.length;
    }

    @Override
    public int entryNode() {
      return 0;
    }

    @Override
    public String toString() {
      StringBuilder buf = new StringBuilder();
      for (int level = nodes.length - 1; level >= 0; level--) {
        buf.append("\nLEVEL ").append(level).append("\n");
        for (int node = 0; node < nodes[level].length; node++) {
          if (nodes[level][node] != null) {
            buf.append("  ")
                .append(node)
                .append(':')
                .append(Arrays.toString(nodes[level][node]))
                .append("\n");
          }
        }
      }
      return buf.toString();
    }

    @Override
    public NodesIterator getNodesOnLevel(int level) {

      int count = 0;
      for (int i = 0; i < nodes[level].length; i++) {
        if (nodes[level][i] != null) {
          count++;
        }
      }
      final int finalCount = count;

      return new NodesIterator(finalCount) {
        int cur = -1;
        int curCount = 0;

        @Override
        public boolean hasNext() {
          return curCount < finalCount;
        }

        @Override
        public int nextInt() {
          while (curCount < finalCount) {
            if (nodes[level][++cur] != null) {
              curCount++;
              return cur;
            }
          }
          throw new IllegalStateException("exhausted");
        }

        @Override
        public int consume(int[] dest) {
          throw new UnsupportedOperationException();
        }
      };
    }
  }
}
@@ -49,7 +49,6 @@ module org.apache.lucene.test_framework {
   exports org.apache.lucene.tests.store;
   exports org.apache.lucene.tests.util.automaton;
   exports org.apache.lucene.tests.util.fst;
-  exports org.apache.lucene.tests.util.hnsw;
   exports org.apache.lucene.tests.util;

   provides org.apache.lucene.codecs.Codec with
@@ -1,132 +0,0 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.tests.util.hnsw;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.hnsw.HnswGraph;

/** Utilities for use in tests involving HNSW graphs */
public class HnswTestUtil {

  /**
   * Returns true iff level 0 of the graph is fully connected - that is every node is reachable from
   * any entry point.
   */
  public static boolean isFullyConnected(HnswGraph knnValues) throws IOException {
    return componentSizes(knnValues).size() < 2;
  }

  /**
   * Returns the sizes of the distinct graph components on level 0. If the graph is fully-connected
   * there will only be a single component. If the graph is empty, the returned list will be empty.
   */
  public static List<Integer> componentSizes(HnswGraph hnsw) throws IOException {
    List<Integer> sizes = new ArrayList<>();
    FixedBitSet connectedNodes = new FixedBitSet(hnsw.size());
    assert hnsw.size() == hnsw.getNodesOnLevel(0).size();
    int total = 0;
    while (total < connectedNodes.length()) {
      int componentSize = traverseConnectedNodes(hnsw, connectedNodes);
      assert componentSize > 0;
      sizes.add(componentSize);
      total += componentSize;
    }
    return sizes;
  }

  // count the nodes in a connected component of the graph and set the bits of its nodes in
  // connectedNodes bitset
  private static int traverseConnectedNodes(HnswGraph hnswGraph, FixedBitSet connectedNodes)
      throws IOException {
    // Start at entry point and search all nodes on this level
    int entryPoint = nextClearBit(connectedNodes, 0);
    if (entryPoint == NO_MORE_DOCS) {
      return 0;
    }
    Deque<Integer> stack = new ArrayDeque<>();
    stack.push(entryPoint);
    int count = 0;
    while (!stack.isEmpty()) {
      int node = stack.pop();
      if (connectedNodes.get(node)) {
        continue;
      }
      count++;
      connectedNodes.set(node);
      hnswGraph.seek(0, node);
      int friendOrd;
      while ((friendOrd = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) {
        stack.push(friendOrd);
      }
    }
    return count;
  }

  private static int nextClearBit(FixedBitSet bits, int index) {
    // Does not depend on the ghost bits being clear!
    long[] barray = bits.getBits();
    assert index >= 0 && index < bits.length() : "index=" + index + ", numBits=" + bits.length();
    int i = index >> 6;
    long word = ~(barray[i] >> index); // skip all the bits to the right of index

    if (word != 0) {
      return index + Long.numberOfTrailingZeros(word);
    }

    while (++i < barray.length) {
      word = ~barray[i];
      if (word != 0) {
        int next = (i << 6) + Long.numberOfTrailingZeros(word);
        if (next >= bits.length()) {
          return NO_MORE_DOCS;
        } else {
          return next;
        }
      }
    }
    return NO_MORE_DOCS;
  }

  public static boolean graphIsConnected(IndexReader reader, String vectorField)
      throws IOException {
    for (LeafReaderContext ctx : reader.leaves()) {
      CodecReader codecReader = (CodecReader) FilterLeafReader.unwrap(ctx.reader());
      HnswGraph graph =
          ((HnswGraphProvider)
                  ((PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader())
                      .getFieldReader(vectorField))
              .getGraph(vectorField);
      if (isFullyConnected(graph) == false) {
        return false;
      }
    }
    return true;
  }
}