LUCENE-10040 Correct TestHnswGraph.testSearchWithAcceptOrds (#277)

If we set numSeed = 10, this test fails sometimes  because it may mark
expected results docs (from 0 to 9) as deleted which don't end up
being retrieved, resulting in a low recall

- set numSeed to 10 to ensure 10 results are returned
- add startIndex paramenter to createRandomAcceptOrds that allows
  documents before startIndex to be NOT deleted
- use startIndex equal to 10 for createRandomAcceptOrds

Relates to #239
This commit is contained in:
Mayya Sharipova 2021-09-06 06:56:15 -04:00 committed by GitHub
parent 4df8d641ac
commit bc161e6dcc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -136,25 +136,29 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraph.search(
new float[] {1, 0},
10,
5,
10,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
null,
random());
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
int sum = 0;
for (int node : nn.nodes()) {
for (int node : nodes) {
sum += node;
}
// We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) =
// 45
// We expect to get approximately 100% recall;
// the lowest docIds are closest to zero; sum(0,9) = 45
assertTrue("sum(result docs)=" + sum, sum < 75);
for (int i = 0; i < nDoc; i++) {
NeighborArray neighbors = hnsw.getNeighbors(i);
int[] nodes = neighbors.node;
int[] nnodes = neighbors.node;
for (int j = 0; j < neighbors.size(); j++) {
// all neighbors should be valid node ids.
assertTrue(nodes[j] < nDoc);
assertTrue(nnodes[j] < nDoc);
}
}
}
@ -167,24 +171,27 @@ public class TestHnswGraph extends LuceneTestCase {
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
HnswGraph hnsw = builder.build(vectors);
Bits acceptOrds = createRandomAcceptOrds(vectors.size);
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
NeighborQueue nn =
HnswGraph.search(
new float[] {1, 0},
10,
5,
10,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
acceptOrds,
random());
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
int sum = 0;
for (int node : nn.nodes()) {
for (int node : nodes) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
sum += node;
}
// We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) =
// 45
// We expect to get approximately 100% recall;
// the lowest docIds are closest to zero; sum(0,9) = 45
assertTrue("sum(result docs)=" + sum, sum < 75);
}
@ -311,7 +318,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
HnswGraph hnsw = builder.build(vectors);
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(size);
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
int totalMatches = 0;
for (int i = 0; i < 100; i++) {
@ -492,10 +499,18 @@ public class TestHnswGraph extends LuceneTestCase {
}
}
/** Generate a random bitset where each entry has a 2/3 probability of being set. */
private static Bits createRandomAcceptOrds(int length) {
/**
* Generate a random bitset where before startIndex all bits are set, and after startIndex each
* entry has a 2/3 probability of being set.
*/
private static Bits createRandomAcceptOrds(int startIndex, int length) {
FixedBitSet bits = new FixedBitSet(length);
for (int i = 0; i < bits.length(); i++) {
// all bits are set before startIndex
for (int i = 0; i < startIndex; i++) {
bits.set(i);
}
// after startIndex, bits are set with 2/3 probability
for (int i = startIndex; i < bits.length(); i++) {
if (random().nextFloat() < 0.667f) {
bits.set(i);
}