HNSW graph builder: correct the non-diverse neighbor check and guard level
generation when M == 1.

HnswGraphBuilder.java, constructor: Math.log(1.0 * M) is 0 when M == 1, so the
old expression 1 / Math.log(...) evaluated to infinity; ml is now 1 in that case.

HnswGraphBuilder.java, isWorstNonDiverse(int, NeighborArray): the int argument
is an index into the neighbor array. The node ordinal is now read from
neighbors.node[candidateIndex] before its vector is fetched (the old code passed
the array index itself to vectors.binaryValue / vectors.vectorValue). The
similarity threshold is no longer a parameter; each per-encoding overload reads
it from neighbors.score[candidateIndex].

HnswGraphBuilder.java, per-encoding overloads: the scan over lower-indexed
neighbors now includes index 0 (the old loop bound was "i > -0"), and the result
is inverted: true as soon as some lower-indexed neighbor's similarity to the
candidate reaches the threshold, false otherwise.

TestHnswGraph.java: adds unitVector2d(0.6) to an existing fixture array and adds
two tests, testDiversityFallback and testDiversity3d, which build small
three-dimensional graphs and assert the resulting level-0 neighbor lists.

diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index d5ed059b0c2..cc8e330ffc8 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -109,7 +109,7 @@ public final class HnswGraphBuilder {
     this.M = M;
     this.beamWidth = beamWidth;
     // normalization factor for level generation; currently not configurable
-    this.ml = 1 / Math.log(1.0 * M);
+    this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
     this.random = new SplittableRandom(seed);
     int levelOfFirstNode = getRandomGraphLevel(ml, random);
     this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
@@ -316,49 +316,49 @@ public final class HnswGraphBuilder {
    */
   private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
     for (int i = neighbors.size() - 1; i > 0; i--) {
-      if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
+      if (isWorstNonDiverse(i, neighbors)) {
         return i;
       }
     }
     return neighbors.size() - 1;
   }
 
-  private boolean isWorstNonDiverse(
-      int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
+  private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
+      throws IOException {
+    int candidateNode = neighbors.node[candidateIndex];
     return switch (vectorEncoding) {
-      case BYTE -> isWorstNonDiverse(
-          candidate, vectors.binaryValue(candidate), neighbors, minAcceptedSimilarity);
+      case BYTE -> isWorstNonDiverse(candidateIndex, vectors.binaryValue(candidateNode), neighbors);
       case FLOAT32 -> isWorstNonDiverse(
-          candidate, vectors.vectorValue(candidate), neighbors, minAcceptedSimilarity);
+          candidateIndex, vectors.vectorValue(candidateNode), neighbors);
     };
   }
 
   private boolean isWorstNonDiverse(
-      int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
-      throws IOException {
-    for (int i = candidateIndex - 1; i > -0; i--) {
+      int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
+    float minAcceptedSimilarity = neighbors.score[candidateIndex];
+    for (int i = candidateIndex - 1; i >= 0; i--) {
       float neighborSimilarity =
-          similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
-      // node i is too similar to node j given its score relative to the base node
+          similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i]));
+      // candidate node is too similar to node i given its score relative to the base node
       if (neighborSimilarity >= minAcceptedSimilarity) {
-        return false;
+        return true;
       }
     }
-    return true;
+    return false;
   }
 
   private boolean isWorstNonDiverse(
-      int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
-      throws IOException {
-    for (int i = candidateIndex - 1; i > -0; i--) {
+      int candidateIndex, BytesRef candidateVector, NeighborArray neighbors) throws IOException {
+    float minAcceptedSimilarity = neighbors.score[candidateIndex];
+    for (int i = candidateIndex - 1; i >= 0; i--) {
       float neighborSimilarity =
-          similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
-      // node i is too similar to node j given its score relative to the base node
+          similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i]));
+      // candidate node is too similar to node i given its score relative to the base node
       if (neighborSimilarity >= minAcceptedSimilarity) {
-        return false;
+        return true;
       }
     }
-    return true;
+    return false;
   }
 
   private static int getRandomGraphLevel(double ml, SplittableRandom random) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index 7852e706157..16d1996820a 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -504,6 +504,7 @@ public class TestHnswGraph extends LuceneTestCase {
       unitVector2d(0.9),
       unitVector2d(0.8),
       unitVector2d(0.77),
+      unitVector2d(0.6)
     };
     if (vectorEncoding == VectorEncoding.BYTE) {
       for (float[] v : values) {
@@ -555,6 +556,78 @@ public class TestHnswGraph extends LuceneTestCase {
     assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
   }
 
+  public void testDiversityFallback() throws IOException {
+    vectorEncoding = randomVectorEncoding();
+    similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+    // Some test cases can't be exercised in two dimensions;
+    // in particular if a new neighbor displaces an existing neighbor
+    // by being closer to the target, yet none of the existing neighbors is closer to the new vector
+    // than to the target -- ie they all remain diverse, so we simply drop the farthest one.
+    float[][] values = {
+      {0, 0, 0},
+      {0, 10, 0},
+      {0, 0, 20},
+      {10, 0, 0},
+      {0, 4, 0}
+    };
+    MockVectorValues vectors = new MockVectorValues(values);
+    // First add nodes until everybody gets a full neighbor list
+    HnswGraphBuilder builder =
+        HnswGraphBuilder.create(
+            vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
+    // node 0 is added by the builder constructor
+    // builder.addGraphNode(vectors.vectorValue(0));
+    RandomAccessVectorValues vectorsCopy = vectors.copy();
+    builder.addGraphNode(1, vectorsCopy);
+    builder.addGraphNode(2, vectorsCopy);
+    assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
+    // 2 is closer to 0 than 1, so it is excluded as non-diverse
+    assertLevel0Neighbors(builder.hnsw, 1, 0);
+    // 1 is closer to 0 than 2, so it is excluded as non-diverse
+    assertLevel0Neighbors(builder.hnsw, 2, 0);
+
+    builder.addGraphNode(3, vectorsCopy);
+    // this is one case we are testing; 2 has been displaced by 3
+    assertLevel0Neighbors(builder.hnsw, 0, 1, 3);
+    assertLevel0Neighbors(builder.hnsw, 1, 0);
+    assertLevel0Neighbors(builder.hnsw, 2, 0);
+    assertLevel0Neighbors(builder.hnsw, 3, 0);
+  }
+
+  public void testDiversity3d() throws IOException {
+    vectorEncoding = randomVectorEncoding();
+    similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+    // test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives
+    float[][] values = {
+      {0, 0, 0},
+      {0, 10, 0},
+      {0, 0, 20},
+      {0, 9, 0}
+    };
+    MockVectorValues vectors = new MockVectorValues(values);
+    // First add nodes until everybody gets a full neighbor list
+    HnswGraphBuilder builder =
+        HnswGraphBuilder.create(
+            vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
+    // node 0 is added by the builder constructor
+    // builder.addGraphNode(vectors.vectorValue(0));
+    RandomAccessVectorValues vectorsCopy = vectors.copy();
+    builder.addGraphNode(1, vectorsCopy);
+    builder.addGraphNode(2, vectorsCopy);
+    assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
+    // 2 is closer to 0 than 1, so it is excluded as non-diverse
+    assertLevel0Neighbors(builder.hnsw, 1, 0);
+    // 1 is closer to 0 than 2, so it is excluded as non-diverse
+    assertLevel0Neighbors(builder.hnsw, 2, 0);
+
+    builder.addGraphNode(3, vectorsCopy);
+    // this is one case we are testing; 1 has been displaced by 3
+    assertLevel0Neighbors(builder.hnsw, 0, 2, 3);
+    assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
+    assertLevel0Neighbors(builder.hnsw, 2, 0);
+    assertLevel0Neighbors(builder.hnsw, 3, 0, 1);
+  }
+
   private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) {
     Arrays.sort(expected);
     NeighborArray nn = graph.getNeighbors(0, node);