LUCENE-9610: fix TestKnnGraph.testMerge

This commit is contained in:
Michael Sokolov 2020-11-14 10:31:09 -05:00
parent 52f581e351
commit 09f78e2927
3 changed files with 48 additions and 6 deletions

View File

@ -104,6 +104,12 @@ public final class HnswGraphBuilder {
if (searchStrategy == VectorValues.SearchStrategy.NONE) { if (searchStrategy == VectorValues.SearchStrategy.NONE) {
throw new IllegalStateException("No distance function"); throw new IllegalStateException("No distance function");
} }
if (maxConn <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
}
if (beamWidth <= 0) {
throw new IllegalArgumentException("beamWidth must be positive");
}
this.maxConn = maxConn; this.maxConn = maxConn;
this.beamWidth = beamWidth; this.beamWidth = beamWidth;
boundedVectors = new BoundedVectorValues(vectorValues); boundedVectors = new BoundedVectorValues(vectorValues);

View File

@ -29,6 +29,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.junit.After;
import org.junit.Before; import org.junit.Before;
import java.io.IOException; import java.io.IOException;
@ -47,9 +48,20 @@ public class TestKnnGraph extends LuceneTestCase {
private static final String KNN_GRAPH_FIELD = "vector"; private static final String KNN_GRAPH_FIELD = "vector";
private static int maxConn;
@Before @Before
public void setup() { public void setup() {
randSeed = random().nextLong(); randSeed = random().nextLong();
if (random().nextBoolean()) {
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
HnswGraphBuilder.DEFAULT_MAX_CONN = random().nextInt(1000) + 1;
}
}
@After
public void cleanup() {
HnswGraphBuilder.DEFAULT_MAX_CONN = maxConn;
} }
/** /**
@ -114,6 +126,18 @@ public class TestKnnGraph extends LuceneTestCase {
} }
} }
private void dumpGraph(KnnGraphValues values, int size) throws IOException {
for (int node = 0; node < size; node++) {
int n;
System.out.print("" + node + ":");
values.seek(node);
while ((n = values.nextNeighbor()) != NO_MORE_DOCS) {
System.out.print(" " + n);
}
System.out.println();
}
}
// TODO: testSorted // TODO: testSorted
// TODO: testDeletions // TODO: testDeletions
@ -223,7 +247,6 @@ public class TestKnnGraph extends LuceneTestCase {
int[][] graph = new int[reader.maxDoc()][]; int[][] graph = new int[reader.maxDoc()][];
boolean foundOrphan= false; boolean foundOrphan= false;
int graphSize = 0; int graphSize = 0;
int node = -1;
for (int i = 0; i < reader.maxDoc(); i++) { for (int i = 0; i < reader.maxDoc(); i++) {
int nextDocWithVectors = vectorValues.advance(i); int nextDocWithVectors = vectorValues.advance(i);
//System.out.println("advanced to " + nextDocWithVectors); //System.out.println("advanced to " + nextDocWithVectors);
@ -236,7 +259,7 @@ public class TestKnnGraph extends LuceneTestCase {
break; break;
} }
int id = Integer.parseInt(reader.document(i).get("id")); int id = Integer.parseInt(reader.document(i).get("id"));
graphValues.seek(++node); graphValues.seek(graphSize);
// documents with KnnGraphValues have the expected vectors // documents with KnnGraphValues have the expected vectors
float[] scratch = vectorValues.vectorValue(); float[] scratch = vectorValues.vectorValue();
assertArrayEquals("vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch), assertArrayEquals("vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
@ -267,10 +290,14 @@ public class TestKnnGraph extends LuceneTestCase {
} else { } else {
assertTrue("Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1); assertTrue("Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1);
} }
if (HnswGraphBuilder.DEFAULT_MAX_CONN > graphSize) {
// assert that the graph in each leaf is connected and undirected (ie links are reciprocated) // assert that the graph in each leaf is connected and undirected (ie links are reciprocated)
// assertReciprocal(graph); assertReciprocal(graph);
assertConnected(graph); assertConnected(graph);
} else {
// assert that max-connections was respected
assertMaxConn(graph, HnswGraphBuilder.DEFAULT_MAX_CONN); assertMaxConn(graph, HnswGraphBuilder.DEFAULT_MAX_CONN);
}
totalGraphDocs += graphSize; totalGraphDocs += graphSize;
} }
} }
@ -333,6 +360,9 @@ public class TestKnnGraph extends LuceneTestCase {
} }
} }
} }
for (int i = 0; i < count; i++) {
assertTrue("Attempted to walk entire graph but never visited " + i, visited.contains(i));
}
// we visited each node exactly once // we visited each node exactly once
assertEquals("Attempted to walk entire graph but only visited " + visited.size(), count, visited.size()); assertEquals("Attempted to walk entire graph but only visited " + visited.size(), count, visited.size());
} }

View File

@ -456,4 +456,10 @@ public class TestHnsw extends LuceneTestCase {
assertTrue(min.check(f + 1e-5f)); // delta is zero initially assertTrue(min.check(f + 1e-5f)); // delta is zero initially
} }
public void testHnswGraphBuilderInvalid() {
expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, 0, 0, 0));
expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0));
expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0));
}
} }