Diversity check bugfix (#11781)

* Fixes bug in HNSW diversity checks introduced in LUCENE-10577
This commit is contained in:
Michael Sokolov 2022-09-19 11:48:59 -04:00 committed by GitHub
parent e69c48b8d9
commit 07af358f90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 94 additions and 21 deletions

View File

@ -109,7 +109,7 @@ public final class HnswGraphBuilder<T> {
this.M = M;
this.beamWidth = beamWidth;
// normalization factor for level generation; currently not configurable
this.ml = 1 / Math.log(1.0 * M);
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
@ -316,49 +316,49 @@ public final class HnswGraphBuilder<T> {
*/
// Returns the index, within `neighbors`, of the first entry flagged by isWorstNonDiverse when
// scanning from the last index down to index 1; index 0 itself is never tested by this loop.
// When no entry is flagged, the last index is returned.
// NOTE(review): this listing is a rendered diff without +/- markers, so the changed `if` appears
// twice. The three-argument call looks like the line the commit removes and the two-argument call
// like its replacement (it matches the two-parameter signature shown further down) -- confirm
// against the repository before treating either line as current.
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
for (int i = neighbors.size() - 1; i > 0; i--) {
if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
if (isWorstNonDiverse(i, neighbors)) {
return i;
}
}
// nothing was flagged: fall back to the last index
return neighbors.size() - 1;
}
// Dispatches on vectorEncoding to the BytesRef or float[] overload of isWorstNonDiverse.
// NOTE(review): rendered diff without +/- markers -- both the superseded and the replacement
// signature and `case` arms are shown; the pairing below is inferred from which lines are
// mutually consistent, so confirm against the repository.
// Replacement lines: candidateIndex is a position in `neighbors`; the node id stored at that
// position (neighbors.node[candidateIndex]) is what is handed to vectors.binaryValue /
// vectors.vectorValue, while the position itself is forwarded to the overload.
// Superseded lines: the same int (`candidate`) was used both as the position and as the argument
// of the vector lookup, and the threshold was passed in as minAcceptedSimilarity.
private boolean isWorstNonDiverse(
int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
throws IOException {
// translate the array position into the node id used for vector lookup
int candidateNode = neighbors.node[candidateIndex];
return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(
candidate, vectors.binaryValue(candidate), neighbors, minAcceptedSimilarity);
case BYTE -> isWorstNonDiverse(candidateIndex, vectors.binaryValue(candidateNode), neighbors);
case FLOAT32 -> isWorstNonDiverse(
candidate, vectors.vectorValue(candidate), neighbors, minAcceptedSimilarity);
candidateIndex, vectors.vectorValue(candidateNode), neighbors);
};
}
// float[] overload.
// NOTE(review): rendered diff without +/- markers -- every changed statement appears twice, and
// the old/new pairing described here is inferred from parameter names; confirm against the
// repository.
// Replacement lines (those using `candidateVector`): the threshold is the candidate's own stored
// score, neighbors.score[candidateIndex]. The candidate vector is compared with the vector of
// every neighbor at a lower index, down to and including index 0. The method returns true at the
// first neighbor whose similarity to the candidate reaches the threshold, and false if none does.
// Superseded lines (those using `candidate`): the threshold was a parameter, the loop condition
// `i > -0` (equivalent to `i > 0`) stopped before index 0, and the two return values were the
// other way round.
private boolean isWorstNonDiverse(
int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
throws IOException {
for (int i = candidateIndex - 1; i > -0; i--) {
int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
// node i is too similar to node j given its score relative to the base node
similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return false;
return true;
}
}
return true;
return false;
}
// BytesRef overload; same structure as the float[] overload but reads neighbor vectors through
// vectorsCopy.binaryValue.
// NOTE(review): rendered diff without +/- markers -- every changed statement appears twice, and
// the old/new pairing described here is inferred from parameter names; confirm against the
// repository.
// Replacement lines (those using `candidateVector`): the threshold is the candidate's own stored
// score, neighbors.score[candidateIndex]. The candidate vector is compared with the vector of
// every neighbor at a lower index, down to and including index 0. The method returns true at the
// first neighbor whose similarity to the candidate reaches the threshold, and false if none does.
// Superseded lines (those using `candidate`): the threshold was a parameter, the loop condition
// `i > -0` (equivalent to `i > 0`) stopped before index 0, and the two return values were the
// other way round.
private boolean isWorstNonDiverse(
int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
throws IOException {
for (int i = candidateIndex - 1; i > -0; i--) {
int candidateIndex, BytesRef candidateVector, NeighborArray neighbors) throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
// node i is too similar to node j given its score relative to the base node
similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return false;
return true;
}
}
return true;
return false;
}
private static int getRandomGraphLevel(double ml, SplittableRandom random) {

View File

@ -504,6 +504,7 @@ public class TestHnswGraph extends LuceneTestCase {
unitVector2d(0.9),
unitVector2d(0.8),
unitVector2d(0.77),
unitVector2d(0.6)
};
if (vectorEncoding == VectorEncoding.BYTE) {
for (float[] v : values) {
@ -555,6 +556,78 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
}
/**
 * Checks the fallback path of neighbor pruning: a newly added node enters node 0's neighbor list
 * while every existing neighbor is still diverse, so the farthest neighbor is the one dropped.
 */
public void testDiversityFallback() throws IOException {
  vectorEncoding = randomVectorEncoding();
  similarityFunction = VectorSimilarityFunction.EUCLIDEAN;

  // This scenario needs a third dimension: the newcomer must be close enough to the target to
  // displace an existing neighbor, while no existing neighbor is closer to the newcomer than it
  // is to the target. All neighbors then remain diverse and the farthest one is simply removed.
  float[][] points = {
    {0, 0, 0},
    {0, 10, 0},
    {0, 0, 20},
    {10, 0, 0},
    {0, 4, 0}
  };
  MockVectorValues vectorValues = new MockVectorValues(points);
  int seed = random().nextInt();
  HnswGraphBuilder<?> builder =
      HnswGraphBuilder.create(vectorValues, vectorEncoding, similarityFunction, 1, 10, seed);

  // The builder constructor has already inserted node 0; add nodes 1 and 2 on top of it.
  RandomAccessVectorValues reader = vectorValues.copy();
  for (int node = 1; node <= 2; node++) {
    builder.addGraphNode(node, reader);
  }
  assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
  // node 2 is closer to node 0 than to node 1, so node 1 rejects it as non-diverse
  assertLevel0Neighbors(builder.hnsw, 1, 0);
  // node 1 is closer to node 0 than to node 2, so node 2 rejects it as non-diverse
  assertLevel0Neighbors(builder.hnsw, 2, 0);

  builder.addGraphNode(3, reader);
  // the case under test: node 3 has pushed node 2 out of node 0's neighbor list
  assertLevel0Neighbors(builder.hnsw, 0, 1, 3);
  assertLevel0Neighbors(builder.hnsw, 1, 0);
  assertLevel0Neighbors(builder.hnsw, 2, 0);
  assertLevel0Neighbors(builder.hnsw, 3, 0);
}
/**
 * Checks that an existing neighbor which becomes non-diverse once a better neighbor arrives is
 * the one removed from the neighbor list.
 */
public void testDiversity3d() throws IOException {
  vectorEncoding = randomVectorEncoding();
  similarityFunction = VectorSimilarityFunction.EUCLIDEAN;

  // Node 3 lies right next to node 1 and is nearer to node 0 than node 1 is.
  float[][] points = {
    {0, 0, 0},
    {0, 10, 0},
    {0, 0, 20},
    {0, 9, 0}
  };
  MockVectorValues vectorValues = new MockVectorValues(points);
  int seed = random().nextInt();
  HnswGraphBuilder<?> builder =
      HnswGraphBuilder.create(vectorValues, vectorEncoding, similarityFunction, 1, 10, seed);

  // The builder constructor has already inserted node 0; add nodes 1 and 2 on top of it.
  RandomAccessVectorValues reader = vectorValues.copy();
  for (int node = 1; node <= 2; node++) {
    builder.addGraphNode(node, reader);
  }
  assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
  // node 2 is closer to node 0 than to node 1, so node 1 rejects it as non-diverse
  assertLevel0Neighbors(builder.hnsw, 1, 0);
  // node 1 is closer to node 0 than to node 2, so node 2 rejects it as non-diverse
  assertLevel0Neighbors(builder.hnsw, 2, 0);

  builder.addGraphNode(3, reader);
  // the case under test: node 3 has pushed node 1 out of node 0's neighbor list
  assertLevel0Neighbors(builder.hnsw, 0, 2, 3);
  assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
  assertLevel0Neighbors(builder.hnsw, 2, 0);
  assertLevel0Neighbors(builder.hnsw, 3, 0, 1);
}
private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) {
Arrays.sort(expected);
NeighborArray nn = graph.getNeighbors(0, node);