mirror of https://github.com/apache/lucene.git
Diversity check bugfix (#11781)
* Fixes bug in HNSW diversity checks introduced in LUCENE-10577
This commit is contained in:
parent
e69c48b8d9
commit
07af358f90
|
@ -109,7 +109,7 @@ public final class HnswGraphBuilder<T> {
|
|||
this.M = M;
|
||||
this.beamWidth = beamWidth;
|
||||
// normalization factor for level generation; currently not configurable
|
||||
this.ml = 1 / Math.log(1.0 * M);
|
||||
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
|
||||
this.random = new SplittableRandom(seed);
|
||||
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
||||
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
|
||||
|
@ -316,49 +316,49 @@ public final class HnswGraphBuilder<T> {
|
|||
*/
|
||||
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
|
||||
for (int i = neighbors.size() - 1; i > 0; i--) {
|
||||
if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
|
||||
if (isWorstNonDiverse(i, neighbors)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return neighbors.size() - 1;
|
||||
}
|
||||
|
||||
private boolean isWorstNonDiverse(
|
||||
int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
|
||||
private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
|
||||
throws IOException {
|
||||
int candidateNode = neighbors.node[candidateIndex];
|
||||
return switch (vectorEncoding) {
|
||||
case BYTE -> isWorstNonDiverse(
|
||||
candidate, vectors.binaryValue(candidate), neighbors, minAcceptedSimilarity);
|
||||
case BYTE -> isWorstNonDiverse(candidateIndex, vectors.binaryValue(candidateNode), neighbors);
|
||||
case FLOAT32 -> isWorstNonDiverse(
|
||||
candidate, vectors.vectorValue(candidate), neighbors, minAcceptedSimilarity);
|
||||
candidateIndex, vectors.vectorValue(candidateNode), neighbors);
|
||||
};
|
||||
}
|
||||
|
||||
private boolean isWorstNonDiverse(
|
||||
int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
|
||||
throws IOException {
|
||||
for (int i = candidateIndex - 1; i > -0; i--) {
|
||||
int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
|
||||
float minAcceptedSimilarity = neighbors.score[candidateIndex];
|
||||
for (int i = candidateIndex - 1; i >= 0; i--) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
|
||||
// node i is too similar to node j given its score relative to the base node
|
||||
similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i]));
|
||||
// candidate node is too similar to node i given its score relative to the base node
|
||||
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
private boolean isWorstNonDiverse(
|
||||
int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
|
||||
throws IOException {
|
||||
for (int i = candidateIndex - 1; i > -0; i--) {
|
||||
int candidateIndex, BytesRef candidateVector, NeighborArray neighbors) throws IOException {
|
||||
float minAcceptedSimilarity = neighbors.score[candidateIndex];
|
||||
for (int i = candidateIndex - 1; i >= 0; i--) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
|
||||
// node i is too similar to node j given its score relative to the base node
|
||||
similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i]));
|
||||
// candidate node is too similar to node i given its score relative to the base node
|
||||
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
private static int getRandomGraphLevel(double ml, SplittableRandom random) {
|
||||
|
|
|
@ -504,6 +504,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
unitVector2d(0.9),
|
||||
unitVector2d(0.8),
|
||||
unitVector2d(0.77),
|
||||
unitVector2d(0.6)
|
||||
};
|
||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||
for (float[] v : values) {
|
||||
|
@ -555,6 +556,78 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
|
||||
}
|
||||
|
||||
public void testDiversityFallback() throws IOException {
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||
// Some test cases can't be exercised in two dimensions;
|
||||
// in particular if a new neighbor displaces an existing neighbor
|
||||
// by being closer to the target, yet none of the existing neighbors is closer to the new vector
|
||||
// than to the target -- ie they all remain diverse, so we simply drop the farthest one.
|
||||
float[][] values = {
|
||||
{0, 0, 0},
|
||||
{0, 10, 0},
|
||||
{0, 0, 20},
|
||||
{10, 0, 0},
|
||||
{0, 4, 0}
|
||||
};
|
||||
MockVectorValues vectors = new MockVectorValues(values);
|
||||
// First add nodes until everybody gets a full neighbor list
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
|
||||
// node 0 is added by the builder constructor
|
||||
// builder.addGraphNode(vectors.vectorValue(0));
|
||||
RandomAccessVectorValues vectorsCopy = vectors.copy();
|
||||
builder.addGraphNode(1, vectorsCopy);
|
||||
builder.addGraphNode(2, vectorsCopy);
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||
// 2 is closer to 0 than 1, so it is excluded as non-diverse
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0);
|
||||
// 1 is closer to 0 than 2, so it is excluded as non-diverse
|
||||
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
||||
|
||||
builder.addGraphNode(3, vectorsCopy);
|
||||
// this is one case we are testing; 2 has been displaced by 3
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 3);
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0);
|
||||
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
||||
assertLevel0Neighbors(builder.hnsw, 3, 0);
|
||||
}
|
||||
|
||||
public void testDiversity3d() throws IOException {
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||
// test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives
|
||||
float[][] values = {
|
||||
{0, 0, 0},
|
||||
{0, 10, 0},
|
||||
{0, 0, 20},
|
||||
{0, 9, 0}
|
||||
};
|
||||
MockVectorValues vectors = new MockVectorValues(values);
|
||||
// First add nodes until everybody gets a full neighbor list
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
|
||||
// node 0 is added by the builder constructor
|
||||
// builder.addGraphNode(vectors.vectorValue(0));
|
||||
RandomAccessVectorValues vectorsCopy = vectors.copy();
|
||||
builder.addGraphNode(1, vectorsCopy);
|
||||
builder.addGraphNode(2, vectorsCopy);
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||
// 2 is closer to 0 than 1, so it is excluded as non-diverse
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0);
|
||||
// 1 is closer to 0 than 2, so it is excluded as non-diverse
|
||||
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
||||
|
||||
builder.addGraphNode(3, vectorsCopy);
|
||||
// this is one case we are testing; 1 has been displaced by 3
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 2, 3);
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
|
||||
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
||||
assertLevel0Neighbors(builder.hnsw, 3, 0, 1);
|
||||
}
|
||||
|
||||
private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) {
|
||||
Arrays.sort(expected);
|
||||
NeighborArray nn = graph.getNeighbors(0, node);
|
||||
|
|
Loading…
Reference in New Issue