From 9fd45e3951d941edbe575d41d900af589bbbe5df Mon Sep 17 00:00:00 2001 From: Jack Wang <45954779+Jackyrie2@users.noreply.github.com> Date: Fri, 1 Sep 2023 11:05:49 -0700 Subject: [PATCH] Enhancement 11236 lazy compute similarity score (#12480) --- lucene/CHANGES.txt | 2 + .../lucene/util/hnsw/HnswGraphBuilder.java | 44 +++++++++-------- .../lucene/util/hnsw/NeighborArray.java | 39 ++++++++++++--- .../lucene/util/hnsw/TestNeighborArray.java | 49 ++++++++++++++++--- 4 files changed, 98 insertions(+), 36 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index f8f3b5ac8fb..11e70044b3c 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -184,6 +184,8 @@ Optimizations * GITHUB#12518: Use panama vector API to speed up l2norm calculations (Ben Trent) +* GITHUB#12371: Lazy computation of similarity score during initializeFromGraph (Jack Wang) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index ddeb5dd9535..bcfdde529eb 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -190,8 +190,6 @@ public final class HnswGraphBuilder { private void initializeFromGraph( HnswGraph initializerGraph, Map oldToNewOrdinalMap) throws IOException { assert hnsw.size() == 0; - float[] vectorValue = null; - byte[] binaryValue = null; for (int level = 0; level < initializerGraph.numLevels(); level++) { HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level); @@ -205,27 +203,14 @@ public final class HnswGraphBuilder { initializedNodes.add(newOrd); } - switch (this.vectorEncoding) { - case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(newOrd); - case BYTE -> binaryValue = (byte[]) vectors.vectorValue(newOrd); - } - NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd);
initializerGraph.seek(level, oldOrd); for (int oldNeighbor = initializerGraph.nextNeighbor(); oldNeighbor != NO_MORE_DOCS; oldNeighbor = initializerGraph.nextNeighbor()) { int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor); - float score = - switch (this.vectorEncoding) { - case FLOAT32 -> this.similarityFunction.compare( - vectorValue, (float[]) vectorsCopy.vectorValue(newNeighbor)); - case BYTE -> this.similarityFunction.compare( - binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor)); - }; - // we are not sure whether the previous graph contains - // unchecked nodes, so we have to assume they're all unchecked - newNeighbors.addOutOfOrder(newNeighbor, score); + // we will compute these scores later when we need to pop out the non-diverse nodes + newNeighbors.addOutOfOrder(newNeighbor, Float.NaN); } } } @@ -327,7 +312,7 @@ public final class HnswGraphBuilder { NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr); nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]); if (nbrsOfNbr.size() > maxConnOnLevel) { - int indexToRemove = findWorstNonDiverse(nbrsOfNbr); + int indexToRemove = findWorstNonDiverse(nbrsOfNbr, nbr); nbrsOfNbr.removeIndex(indexToRemove); } } @@ -409,8 +394,27 @@ public final class HnswGraphBuilder { * Find first non-diverse neighbour among the list of neighbors starting from the most distant * neighbours */ - private int findWorstNonDiverse(NeighborArray neighbors) throws IOException { - int[] uncheckedIndexes = neighbors.sort(); + private int findWorstNonDiverse(NeighborArray neighbors, int nodeOrd) throws IOException { + float[] vectorValue = null; + byte[] binaryValue = null; + switch (this.vectorEncoding) { + case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(nodeOrd); + case BYTE -> binaryValue = (byte[]) vectors.vectorValue(nodeOrd); + } + float[] finalVectorValue = vectorValue; + byte[] finalBinaryValue = binaryValue; + int[] uncheckedIndexes = + neighbors.sort( + nbrOrd -> { + float score = + switch (this.vectorEncoding) 
{ + case FLOAT32 -> this.similarityFunction.compare( + finalVectorValue, (float[]) vectorsCopy.vectorValue(nbrOrd)); + case BYTE -> this.similarityFunction.compare( + finalBinaryValue, (byte[]) vectorsCopy.vectorValue(nbrOrd)); + }; + return score; + }); if (uncheckedIndexes == null) { // all nodes are checked, we will directly return the most distant one return neighbors.size() - 1; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index b44f7da8b8a..d3fa753d32f 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -17,6 +17,7 @@ package org.apache.lucene.util.hnsw; +import java.io.IOException; import java.util.Arrays; import org.apache.lucene.util.ArrayUtil; @@ -31,7 +32,6 @@ import org.apache.lucene.util.ArrayUtil; public class NeighborArray { private final boolean scoresDescOrder; private int size; - float[] score; int[] node; private int sortedNodeSize; @@ -67,14 +67,15 @@ public class NeighborArray { ++sortedNodeSize; } - /** Add node and score but do not insert as sorted */ + /** Add node and newScore but do not insert as sorted */ public void addOutOfOrder(int newNode, float newScore) { if (size == node.length) { node = ArrayUtil.grow(node); score = ArrayUtil.growExact(score, node.length); } - node[size] = newNode; + score[size] = newScore; + node[size] = newNode; size++; } @@ -85,7 +86,7 @@ public class NeighborArray { * @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is * already fully sorted */ - public int[] sort() { + public int[] sort(ScoringFunction scoringFunction) throws IOException { if (size == sortedNodeSize) { // all nodes checked and sorted return null; @@ -94,7 +95,8 @@ public class NeighborArray { int[] uncheckedIndexes = new int[size - sortedNodeSize]; int count = 0; while (sortedNodeSize != 
size) { - uncheckedIndexes[count] = insertSortedInternal(); // sortedNodeSize is increased inside + uncheckedIndexes[count] = + insertSortedInternal(scoringFunction); // sortedNodeSize is increased inside for (int i = 0; i < count; i++) { if (uncheckedIndexes[i] >= uncheckedIndexes[count]) { // the previous inserted nodes has been shifted @@ -108,10 +110,15 @@ public class NeighborArray { } /** insert the first unsorted node into its sorted position */ - private int insertSortedInternal() { + private int insertSortedInternal(ScoringFunction scoringFunction) throws IOException { assert sortedNodeSize < size : "Call this method only when there's unsorted node"; int tmpNode = node[sortedNodeSize]; float tmpScore = score[sortedNodeSize]; + + if (Float.isNaN(tmpScore)) { + tmpScore = scoringFunction.computeScore(tmpNode); + } + int insertionPoint = scoresDescOrder ? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize) @@ -127,9 +134,9 @@ public class NeighborArray { } /** This method is for test only. */ - void insertSorted(int newNode, float newScore) { + void insertSorted(int newNode, float newScore) throws IOException { addOutOfOrder(newNode, newScore); - insertSortedInternal(); + insertSortedInternal(null); } public int size() { @@ -197,4 +204,20 @@ public class NeighborArray { } return start; } + + /** + * ScoringFunction is a lambda function created in HnswGraphBuilder to allow for lazy computation + * of distance score. + */ + interface ScoringFunction { + /** + * Computes the distance score between the given node ID and the root node of this + * NeighborArray. + * + * @param nodeId The ID of the node for which to compute the distance score. + * @return The distance score as a float value. + * @throws IOException If an I/O error occurs during computation. 
+ */ + float computeScore(int nodeId) throws IOException; + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java index c81077aa6da..257c72fb994 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java @@ -17,11 +17,12 @@ package org.apache.lucene.util.hnsw; +import java.io.IOException; import org.apache.lucene.tests.util.LuceneTestCase; public class TestNeighborArray extends LuceneTestCase { - public void testScoresDescOrder() { + public void testScoresDescOrder() throws IOException { NeighborArray neighbors = new NeighborArray(10, true); neighbors.addInOrder(0, 1); neighbors.addInOrder(1, 0.8f); @@ -70,7 +71,7 @@ public class TestNeighborArray extends LuceneTestCase { assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors); } - public void testScoresAscOrder() { + public void testScoresAscOrder() throws IOException { NeighborArray neighbors = new NeighborArray(10, false); neighbors.addInOrder(0, 0.1f); neighbors.addInOrder(1, 0.3f); @@ -119,7 +120,7 @@ public class TestNeighborArray extends LuceneTestCase { assertNodesEqual(new int[] {8, 0, 6, 7}, neighbors); } - public void testSortAsc() { + public void testSortAsc() throws IOException { NeighborArray neighbors = new NeighborArray(10, false); neighbors.addOutOfOrder(1, 2); // we disallow calling addInOrder after addOutOfOrder even if they're actual in order @@ -130,7 +131,7 @@ public class TestNeighborArray extends LuceneTestCase { neighbors.addOutOfOrder(7, 8); neighbors.addOutOfOrder(6, 7); neighbors.addOutOfOrder(4, 5); - int[] unchecked = neighbors.sort(); + int[] unchecked = neighbors.sort(null); assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors); assertScoresEqual(new float[] {2, 3, 4, 5, 6, 7, 8}, neighbors); @@ -143,13 
+144,13 @@ public class TestNeighborArray extends LuceneTestCase { neighbors2.addOutOfOrder(6, 7); neighbors2.addOutOfOrder(5, 6); neighbors2.addOutOfOrder(3, 4); - unchecked = neighbors2.sort(); + unchecked = neighbors2.sort(null); assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked); assertNodesEqual(new int[] {0, 1, 2, 3, 4, 5, 6}, neighbors2); assertScoresEqual(new float[] {1, 2, 3, 4, 5, 6, 7}, neighbors2); } - public void testSortDesc() { + public void testSortDesc() throws IOException { NeighborArray neighbors = new NeighborArray(10, true); neighbors.addOutOfOrder(1, 7); // we disallow calling addInOrder after addOutOfOrder even if they're actual in order @@ -160,7 +161,7 @@ public class TestNeighborArray extends LuceneTestCase { neighbors.addOutOfOrder(7, 1); neighbors.addOutOfOrder(6, 2); neighbors.addOutOfOrder(4, 4); - int[] unchecked = neighbors.sort(); + int[] unchecked = neighbors.sort(null); assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors); assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors); @@ -173,12 +174,44 @@ public class TestNeighborArray extends LuceneTestCase { neighbors2.addOutOfOrder(7, 1); neighbors2.addOutOfOrder(6, 2); neighbors2.addOutOfOrder(4, 4); - unchecked = neighbors2.sort(); + unchecked = neighbors2.sort(null); assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked); assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors2); assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors2); } + public void testAddwithScoringFunction() throws IOException { + NeighborArray neighbors = new NeighborArray(10, true); + neighbors.addOutOfOrder(1, Float.NaN); + expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2)); + neighbors.addOutOfOrder(2, Float.NaN); + neighbors.addOutOfOrder(5, Float.NaN); + neighbors.addOutOfOrder(3, Float.NaN); + neighbors.addOutOfOrder(7, Float.NaN); + neighbors.addOutOfOrder(6, Float.NaN); + 
neighbors.addOutOfOrder(4, Float.NaN); + int[] unchecked = neighbors.sort(nodeId -> 7 - nodeId + 1); + assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); + assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors); + assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors); + } + + public void testAddwithScoringFunctionLargeOrd() throws IOException { + NeighborArray neighbors = new NeighborArray(10, true); + neighbors.addOutOfOrder(11, Float.NaN); + expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2)); + neighbors.addOutOfOrder(12, Float.NaN); + neighbors.addOutOfOrder(15, Float.NaN); + neighbors.addOutOfOrder(13, Float.NaN); + neighbors.addOutOfOrder(17, Float.NaN); + neighbors.addOutOfOrder(16, Float.NaN); + neighbors.addOutOfOrder(14, Float.NaN); + int[] unchecked = neighbors.sort(nodeId -> 7 - nodeId + 11); + assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); + assertNodesEqual(new int[] {11, 12, 13, 14, 15, 16, 17}, neighbors); + assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors); + } + private void assertScoresEqual(float[] scores, NeighborArray neighbors) { for (int i = 0; i < scores.length; i++) { assertEquals(scores[i], neighbors.score[i], 0.01f);