Enhancement 11236 lazy compute similarity score (#12480)

This commit is contained in:
Jack Wang 2023-09-01 11:05:49 -07:00 committed by GitHub
parent d1c3531161
commit 9fd45e3951
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 36 deletions

View File

@ -184,6 +184,8 @@ Optimizations
* GITHUB#12518: Use panama vector API to speed up l2norm calculations (Ben Trent)
* GITHUB#12371: Lazy computation of similarity score during initializeFromGraph (Jack Wang)
Bug Fixes
---------------------

View File

@ -190,8 +190,6 @@ public final class HnswGraphBuilder<T> {
private void initializeFromGraph(
HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
assert hnsw.size() == 0;
float[] vectorValue = null;
byte[] binaryValue = null;
for (int level = 0; level < initializerGraph.numLevels(); level++) {
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
@ -205,27 +203,14 @@ public final class HnswGraphBuilder<T> {
initializedNodes.add(newOrd);
}
switch (this.vectorEncoding) {
case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(newOrd);
case BYTE -> binaryValue = (byte[]) vectors.vectorValue(newOrd);
}
NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd);
initializerGraph.seek(level, oldOrd);
for (int oldNeighbor = initializerGraph.nextNeighbor();
oldNeighbor != NO_MORE_DOCS;
oldNeighbor = initializerGraph.nextNeighbor()) {
int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor);
float score =
switch (this.vectorEncoding) {
case FLOAT32 -> this.similarityFunction.compare(
vectorValue, (float[]) vectorsCopy.vectorValue(newNeighbor));
case BYTE -> this.similarityFunction.compare(
binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor));
};
// we are not sure whether the previous graph contains
// unchecked nodes, so we have to assume they're all unchecked
newNeighbors.addOutOfOrder(newNeighbor, score);
// we will compute these scores later when we need to pop out the non-diverse nodes
newNeighbors.addOutOfOrder(newNeighbor, Float.NaN);
}
}
}
@ -327,7 +312,7 @@ public final class HnswGraphBuilder<T> {
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
if (nbrsOfNbr.size() > maxConnOnLevel) {
int indexToRemove = findWorstNonDiverse(nbrsOfNbr);
int indexToRemove = findWorstNonDiverse(nbrsOfNbr, nbr);
nbrsOfNbr.removeIndex(indexToRemove);
}
}
@ -409,8 +394,27 @@ public final class HnswGraphBuilder<T> {
* Find first non-diverse neighbour among the list of neighbors starting from the most distant
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
int[] uncheckedIndexes = neighbors.sort();
private int findWorstNonDiverse(NeighborArray neighbors, int nodeOrd) throws IOException {
float[] vectorValue = null;
byte[] binaryValue = null;
switch (this.vectorEncoding) {
case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(nodeOrd);
case BYTE -> binaryValue = (byte[]) vectors.vectorValue(nodeOrd);
}
float[] finalVectorValue = vectorValue;
byte[] finalBinaryValue = binaryValue;
int[] uncheckedIndexes =
neighbors.sort(
nbrOrd -> {
float score =
switch (this.vectorEncoding) {
case FLOAT32 -> this.similarityFunction.compare(
finalVectorValue, (float[]) vectorsCopy.vectorValue(nbrOrd));
case BYTE -> this.similarityFunction.compare(
finalBinaryValue, (byte[]) vectorsCopy.vectorValue(nbrOrd));
};
return score;
});
if (uncheckedIndexes == null) {
// all nodes are checked, we will directly return the most distant one
return neighbors.size() - 1;

View File

@ -17,6 +17,7 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.util.ArrayUtil;
@ -31,7 +32,6 @@ import org.apache.lucene.util.ArrayUtil;
public class NeighborArray {
private final boolean scoresDescOrder;
private int size;
float[] score;
int[] node;
private int sortedNodeSize;
@ -67,14 +67,15 @@ public class NeighborArray {
++sortedNodeSize;
}
/** Add node and score but do not insert as sorted */
/** Add node and newScore but do not insert as sorted */
public void addOutOfOrder(int newNode, float newScore) {
if (size == node.length) {
node = ArrayUtil.grow(node);
score = ArrayUtil.growExact(score, node.length);
}
node[size] = newNode;
score[size] = newScore;
node[size] = newNode;
size++;
}
@ -85,7 +86,7 @@ public class NeighborArray {
* @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is
* already fully sorted
*/
public int[] sort() {
public int[] sort(ScoringFunction scoringFunction) throws IOException {
if (size == sortedNodeSize) {
// all nodes checked and sorted
return null;
@ -94,7 +95,8 @@ public class NeighborArray {
int[] uncheckedIndexes = new int[size - sortedNodeSize];
int count = 0;
while (sortedNodeSize != size) {
uncheckedIndexes[count] = insertSortedInternal(); // sortedNodeSize is increased inside
uncheckedIndexes[count] =
insertSortedInternal(scoringFunction); // sortedNodeSize is increased inside
for (int i = 0; i < count; i++) {
if (uncheckedIndexes[i] >= uncheckedIndexes[count]) {
// the previous inserted nodes has been shifted
@ -108,10 +110,15 @@ public class NeighborArray {
}
/** insert the first unsorted node into its sorted position */
private int insertSortedInternal() {
private int insertSortedInternal(ScoringFunction scoringFunction) throws IOException {
assert sortedNodeSize < size : "Call this method only when there's unsorted node";
int tmpNode = node[sortedNodeSize];
float tmpScore = score[sortedNodeSize];
if (Float.isNaN(tmpScore)) {
tmpScore = scoringFunction.computeScore(tmpNode);
}
int insertionPoint =
scoresDescOrder
? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize)
@ -127,9 +134,9 @@ public class NeighborArray {
}
/** This method is for test only. */
void insertSorted(int newNode, float newScore) {
void insertSorted(int newNode, float newScore) throws IOException {
addOutOfOrder(newNode, newScore);
insertSortedInternal();
insertSortedInternal(null);
}
public int size() {
@ -197,4 +204,20 @@ public class NeighborArray {
}
return start;
}
/**
 * Callback that computes the score of a neighbor on demand.
 *
 * <p>{@code HnswGraphBuilder} supplies an implementation as a lambda so that scores stored as
 * {@link Float#NaN} placeholders via {@code addOutOfOrder} can be resolved lazily: the function
 * is only consulted while sorting, and only for entries whose stored score is NaN.
 */
@FunctionalInterface
interface ScoringFunction {
  /**
   * Computes the score between the given node and the node that owns this NeighborArray.
   *
   * @param nodeId the ordinal of the neighbor node whose score is required
   * @return the computed score as a float value
   * @throws IOException if an I/O error occurs while computing the score
   */
  float computeScore(int nodeId) throws IOException;
}
}

View File

@ -17,11 +17,12 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.tests.util.LuceneTestCase;
public class TestNeighborArray extends LuceneTestCase {
public void testScoresDescOrder() {
public void testScoresDescOrder() throws IOException {
NeighborArray neighbors = new NeighborArray(10, true);
neighbors.addInOrder(0, 1);
neighbors.addInOrder(1, 0.8f);
@ -70,7 +71,7 @@ public class TestNeighborArray extends LuceneTestCase {
assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
}
public void testScoresAscOrder() {
public void testScoresAscOrder() throws IOException {
NeighborArray neighbors = new NeighborArray(10, false);
neighbors.addInOrder(0, 0.1f);
neighbors.addInOrder(1, 0.3f);
@ -119,7 +120,7 @@ public class TestNeighborArray extends LuceneTestCase {
assertNodesEqual(new int[] {8, 0, 6, 7}, neighbors);
}
public void testSortAsc() {
public void testSortAsc() throws IOException {
NeighborArray neighbors = new NeighborArray(10, false);
neighbors.addOutOfOrder(1, 2);
// we disallow calling addInOrder after addOutOfOrder even if they're actual in order
@ -130,7 +131,7 @@ public class TestNeighborArray extends LuceneTestCase {
neighbors.addOutOfOrder(7, 8);
neighbors.addOutOfOrder(6, 7);
neighbors.addOutOfOrder(4, 5);
int[] unchecked = neighbors.sort();
int[] unchecked = neighbors.sort(null);
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
assertScoresEqual(new float[] {2, 3, 4, 5, 6, 7, 8}, neighbors);
@ -143,13 +144,13 @@ public class TestNeighborArray extends LuceneTestCase {
neighbors2.addOutOfOrder(6, 7);
neighbors2.addOutOfOrder(5, 6);
neighbors2.addOutOfOrder(3, 4);
unchecked = neighbors2.sort();
unchecked = neighbors2.sort(null);
assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked);
assertNodesEqual(new int[] {0, 1, 2, 3, 4, 5, 6}, neighbors2);
assertScoresEqual(new float[] {1, 2, 3, 4, 5, 6, 7}, neighbors2);
}
public void testSortDesc() {
public void testSortDesc() throws IOException {
NeighborArray neighbors = new NeighborArray(10, true);
neighbors.addOutOfOrder(1, 7);
// we disallow calling addInOrder after addOutOfOrder even if they're actual in order
@ -160,7 +161,7 @@ public class TestNeighborArray extends LuceneTestCase {
neighbors.addOutOfOrder(7, 1);
neighbors.addOutOfOrder(6, 2);
neighbors.addOutOfOrder(4, 4);
int[] unchecked = neighbors.sort();
int[] unchecked = neighbors.sort(null);
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
@ -173,12 +174,44 @@ public class TestNeighborArray extends LuceneTestCase {
neighbors2.addOutOfOrder(7, 1);
neighbors2.addOutOfOrder(6, 2);
neighbors2.addOutOfOrder(4, 4);
unchecked = neighbors2.sort();
unchecked = neighbors2.sort(null);
assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked);
assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors2);
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors2);
}
/**
 * Adds seven neighbors whose scores are NaN placeholders and checks that sorting with a scoring
 * function fills in the scores and orders the nodes by descending score.
 */
public void testAddwithScoringFunction() throws IOException {
  final NeighborArray lazyNeighbors = new NeighborArray(10, true);

  // First entry goes in unsorted with a placeholder score.
  lazyNeighbors.addOutOfOrder(1, Float.NaN);
  // addInOrder is expected to trip an assertion once an out-of-order entry exists.
  expectThrows(AssertionError.class, () -> lazyNeighbors.addInOrder(1, 2));

  // Remaining entries, deliberately not in ordinal order.
  for (int pendingOrd : new int[] {2, 5, 3, 7, 6, 4}) {
    lazyNeighbors.addOutOfOrder(pendingOrd, Float.NaN);
  }

  // Scoring maps ordinal 1 -> 7 down to ordinal 7 -> 1.
  final int[] uncheckedIdx = lazyNeighbors.sort(ord -> 8 - ord);

  final int[] expectedIdx = {0, 1, 2, 3, 4, 5, 6};
  final int[] expectedNodes = {1, 2, 3, 4, 5, 6, 7};
  final float[] expectedScores = {7, 6, 5, 4, 3, 2, 1};
  assertArrayEquals(expectedIdx, uncheckedIdx);
  assertNodesEqual(expectedNodes, lazyNeighbors);
  assertScoresEqual(expectedScores, lazyNeighbors);
}
/**
 * Same scenario as the lazy-scoring test above the array's initial capacity argument: node
 * ordinals (11..17) are larger than the value 10 passed to the constructor, and sorting with a
 * scoring function must still resolve the NaN placeholders and order by descending score.
 */
public void testAddwithScoringFunctionLargeOrd() throws IOException {
  final NeighborArray lazyNeighbors = new NeighborArray(10, true);

  // First entry goes in unsorted with a placeholder score.
  lazyNeighbors.addOutOfOrder(11, Float.NaN);
  // addInOrder is expected to trip an assertion once an out-of-order entry exists.
  expectThrows(AssertionError.class, () -> lazyNeighbors.addInOrder(1, 2));

  // Remaining entries, deliberately not in ordinal order.
  for (int pendingOrd : new int[] {12, 15, 13, 17, 16, 14}) {
    lazyNeighbors.addOutOfOrder(pendingOrd, Float.NaN);
  }

  // Scoring maps ordinal 11 -> 7 down to ordinal 17 -> 1.
  final int[] uncheckedIdx = lazyNeighbors.sort(ord -> 18 - ord);

  final int[] expectedIdx = {0, 1, 2, 3, 4, 5, 6};
  final int[] expectedNodes = {11, 12, 13, 14, 15, 16, 17};
  final float[] expectedScores = {7, 6, 5, 4, 3, 2, 1};
  assertArrayEquals(expectedIdx, uncheckedIdx);
  assertNodesEqual(expectedNodes, lazyNeighbors);
  assertScoresEqual(expectedScores, lazyNeighbors);
}
private void assertScoresEqual(float[] scores, NeighborArray neighbors) {
for (int i = 0; i < scores.length; i++) {
assertEquals(scores[i], neighbors.score[i], 0.01f);