diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index b308a86fa91..d1161798117 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -104,6 +104,12 @@ public final class HnswGraphBuilder { if (searchStrategy == VectorValues.SearchStrategy.NONE) { throw new IllegalStateException("No distance function"); } + if (maxConn <= 0) { + throw new IllegalArgumentException("maxConn must be positive"); + } + if (beamWidth <= 0) { + throw new IllegalArgumentException("beamWidth must be positive"); + } this.maxConn = maxConn; this.beamWidth = beamWidth; boundedVectors = new BoundedVectorValues(vectorValues); diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index d7f8fc91248..034b4e3f0d3 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -29,6 +29,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.hnsw.HnswGraphBuilder; +import org.junit.After; import org.junit.Before; import java.io.IOException; @@ -47,9 +48,20 @@ public class TestKnnGraph extends LuceneTestCase { private static final String KNN_GRAPH_FIELD = "vector"; + private static int maxConn; + @Before public void setup() { randSeed = random().nextLong(); + if (random().nextBoolean()) { + maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN; + HnswGraphBuilder.DEFAULT_MAX_CONN = random().nextInt(1000) + 1; + } + } + + @After + public void cleanup() { + HnswGraphBuilder.DEFAULT_MAX_CONN = maxConn; } /** @@ -114,6 +126,18 @@ public class TestKnnGraph extends LuceneTestCase { } } + private void dumpGraph(KnnGraphValues values, int size) throws IOException { + for (int node = 0; node < size; node++) { + int n; + System.out.print("" + node + ":"); + values.seek(node); + while ((n = values.nextNeighbor()) != NO_MORE_DOCS) { + System.out.print(" " + n); + } + System.out.println(); + } + } + // TODO: testSorted // TODO: testDeletions @@ -223,7 +247,6 @@ public class TestKnnGraph extends LuceneTestCase { int[][] graph = new int[reader.maxDoc()][]; boolean foundOrphan= false; int graphSize = 0; - int node = -1; for (int i = 0; i < reader.maxDoc(); i++) { int nextDocWithVectors = vectorValues.advance(i); //System.out.println("advanced to " + nextDocWithVectors); @@ -236,7 +259,7 @@ public class TestKnnGraph extends LuceneTestCase { break; } int id = Integer.parseInt(reader.document(i).get("id")); - graphValues.seek(++node); + graphValues.seek(graphSize); // documents with KnnGraphValues have the expected vectors float[] scratch = vectorValues.vectorValue(); assertArrayEquals("vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch), @@ -267,10 +290,14 @@ public class TestKnnGraph extends LuceneTestCase { } else { assertTrue("Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1); } - // assert that the graph in each leaf is connected and undirected (ie links are reciprocated) - // assertReciprocal(graph); - assertConnected(graph); - assertMaxConn(graph, HnswGraphBuilder.DEFAULT_MAX_CONN); + if (HnswGraphBuilder.DEFAULT_MAX_CONN > graphSize) { + // assert that the graph in each leaf is connected and undirected (ie links are reciprocated) + assertReciprocal(graph); + assertConnected(graph); + } else { + // assert that max-connections was respected + assertMaxConn(graph, HnswGraphBuilder.DEFAULT_MAX_CONN); + } totalGraphDocs += graphSize; } } @@ -333,6 +360,9 @@ public class TestKnnGraph extends LuceneTestCase { } } } + for (int i = 0; i < count; i++) { + assertTrue("Attempted to walk entire graph but never visited " + i, visited.contains(i)); + } // we visited each node exactly once assertEquals("Attempted to walk entire graph but only visited " + visited.size(), count, visited.size()); } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java index 5a1b7325b7b..8f50a1dac59 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java @@ -456,4 +456,10 @@ public class TestHnsw extends LuceneTestCase { assertTrue(min.check(f + 1e-5f)); // delta is zero initially } + public void testHnswGraphBuilderInvalid() { + expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, 0, 0, 0)); + expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0)); + expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0)); + } + }