mirror of https://github.com/apache/lucene.git
Enhancement 11236 lazy compute similarity score (#12480)
This commit is contained in:
parent
d1c3531161
commit
9fd45e3951
|
@ -184,6 +184,8 @@ Optimizations
|
|||
|
||||
* GITHUB#12518: Use panama vector API to speed up l2norm calculations (Ben Trent)
|
||||
|
||||
* GITHUB#12371: Lazy computation of similarity score during initializeFromGraph (Jack Wang)
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -190,8 +190,6 @@ public final class HnswGraphBuilder<T> {
|
|||
private void initializeFromGraph(
|
||||
HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
|
||||
assert hnsw.size() == 0;
|
||||
float[] vectorValue = null;
|
||||
byte[] binaryValue = null;
|
||||
for (int level = 0; level < initializerGraph.numLevels(); level++) {
|
||||
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
|
||||
|
||||
|
@ -205,27 +203,14 @@ public final class HnswGraphBuilder<T> {
|
|||
initializedNodes.add(newOrd);
|
||||
}
|
||||
|
||||
switch (this.vectorEncoding) {
|
||||
case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(newOrd);
|
||||
case BYTE -> binaryValue = (byte[]) vectors.vectorValue(newOrd);
|
||||
}
|
||||
|
||||
NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd);
|
||||
initializerGraph.seek(level, oldOrd);
|
||||
for (int oldNeighbor = initializerGraph.nextNeighbor();
|
||||
oldNeighbor != NO_MORE_DOCS;
|
||||
oldNeighbor = initializerGraph.nextNeighbor()) {
|
||||
int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor);
|
||||
float score =
|
||||
switch (this.vectorEncoding) {
|
||||
case FLOAT32 -> this.similarityFunction.compare(
|
||||
vectorValue, (float[]) vectorsCopy.vectorValue(newNeighbor));
|
||||
case BYTE -> this.similarityFunction.compare(
|
||||
binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor));
|
||||
};
|
||||
// we are not sure whether the previous graph contains
|
||||
// unchecked nodes, so we have to assume they're all unchecked
|
||||
newNeighbors.addOutOfOrder(newNeighbor, score);
|
||||
// we will compute these scores later when we need to pop out the non-diverse nodes
|
||||
newNeighbors.addOutOfOrder(newNeighbor, Float.NaN);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -327,7 +312,7 @@ public final class HnswGraphBuilder<T> {
|
|||
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
|
||||
nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
|
||||
if (nbrsOfNbr.size() > maxConnOnLevel) {
|
||||
int indexToRemove = findWorstNonDiverse(nbrsOfNbr);
|
||||
int indexToRemove = findWorstNonDiverse(nbrsOfNbr, nbr);
|
||||
nbrsOfNbr.removeIndex(indexToRemove);
|
||||
}
|
||||
}
|
||||
|
@ -409,8 +394,27 @@ public final class HnswGraphBuilder<T> {
|
|||
* Find first non-diverse neighbour among the list of neighbors starting from the most distant
|
||||
* neighbours
|
||||
*/
|
||||
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
|
||||
int[] uncheckedIndexes = neighbors.sort();
|
||||
private int findWorstNonDiverse(NeighborArray neighbors, int nodeOrd) throws IOException {
|
||||
float[] vectorValue = null;
|
||||
byte[] binaryValue = null;
|
||||
switch (this.vectorEncoding) {
|
||||
case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(nodeOrd);
|
||||
case BYTE -> binaryValue = (byte[]) vectors.vectorValue(nodeOrd);
|
||||
}
|
||||
float[] finalVectorValue = vectorValue;
|
||||
byte[] finalBinaryValue = binaryValue;
|
||||
int[] uncheckedIndexes =
|
||||
neighbors.sort(
|
||||
nbrOrd -> {
|
||||
float score =
|
||||
switch (this.vectorEncoding) {
|
||||
case FLOAT32 -> this.similarityFunction.compare(
|
||||
finalVectorValue, (float[]) vectorsCopy.vectorValue(nbrOrd));
|
||||
case BYTE -> this.similarityFunction.compare(
|
||||
finalBinaryValue, (byte[]) vectorsCopy.vectorValue(nbrOrd));
|
||||
};
|
||||
return score;
|
||||
});
|
||||
if (uncheckedIndexes == null) {
|
||||
// all nodes are checked, we will directly return the most distant one
|
||||
return neighbors.size() - 1;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
|
||||
|
@ -31,7 +32,6 @@ import org.apache.lucene.util.ArrayUtil;
|
|||
public class NeighborArray {
|
||||
private final boolean scoresDescOrder;
|
||||
private int size;
|
||||
|
||||
float[] score;
|
||||
int[] node;
|
||||
private int sortedNodeSize;
|
||||
|
@ -67,14 +67,15 @@ public class NeighborArray {
|
|||
++sortedNodeSize;
|
||||
}
|
||||
|
||||
/** Add node and score but do not insert as sorted */
|
||||
/** Add node and newScore but do not insert as sorted */
|
||||
public void addOutOfOrder(int newNode, float newScore) {
|
||||
if (size == node.length) {
|
||||
node = ArrayUtil.grow(node);
|
||||
score = ArrayUtil.growExact(score, node.length);
|
||||
}
|
||||
node[size] = newNode;
|
||||
|
||||
score[size] = newScore;
|
||||
node[size] = newNode;
|
||||
size++;
|
||||
}
|
||||
|
||||
|
@ -85,7 +86,7 @@ public class NeighborArray {
|
|||
* @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is
|
||||
* already fully sorted
|
||||
*/
|
||||
public int[] sort() {
|
||||
public int[] sort(ScoringFunction scoringFunction) throws IOException {
|
||||
if (size == sortedNodeSize) {
|
||||
// all nodes checked and sorted
|
||||
return null;
|
||||
|
@ -94,7 +95,8 @@ public class NeighborArray {
|
|||
int[] uncheckedIndexes = new int[size - sortedNodeSize];
|
||||
int count = 0;
|
||||
while (sortedNodeSize != size) {
|
||||
uncheckedIndexes[count] = insertSortedInternal(); // sortedNodeSize is increased inside
|
||||
uncheckedIndexes[count] =
|
||||
insertSortedInternal(scoringFunction); // sortedNodeSize is increased inside
|
||||
for (int i = 0; i < count; i++) {
|
||||
if (uncheckedIndexes[i] >= uncheckedIndexes[count]) {
|
||||
// the previously inserted nodes have been shifted
|
||||
|
@ -108,10 +110,15 @@ public class NeighborArray {
|
|||
}
|
||||
|
||||
/** insert the first unsorted node into its sorted position */
|
||||
private int insertSortedInternal() {
|
||||
private int insertSortedInternal(ScoringFunction scoringFunction) throws IOException {
|
||||
assert sortedNodeSize < size : "Call this method only when there's unsorted node";
|
||||
int tmpNode = node[sortedNodeSize];
|
||||
float tmpScore = score[sortedNodeSize];
|
||||
|
||||
if (Float.isNaN(tmpScore)) {
|
||||
tmpScore = scoringFunction.computeScore(tmpNode);
|
||||
}
|
||||
|
||||
int insertionPoint =
|
||||
scoresDescOrder
|
||||
? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize)
|
||||
|
@ -127,9 +134,9 @@ public class NeighborArray {
|
|||
}
|
||||
|
||||
/** This method is for test only. */
|
||||
void insertSorted(int newNode, float newScore) {
|
||||
void insertSorted(int newNode, float newScore) throws IOException {
|
||||
addOutOfOrder(newNode, newScore);
|
||||
insertSortedInternal();
|
||||
insertSortedInternal(null);
|
||||
}
|
||||
|
||||
public int size() {
|
||||
|
@ -197,4 +204,20 @@ public class NeighborArray {
|
|||
}
|
||||
return start;
|
||||
}
|
||||
|
||||
/**
|
||||
* ScoringFunction is a lambda function created in HnswGraphBuilder to allow for lazy computation
|
||||
* of distance score.
|
||||
*/
|
||||
interface ScoringFunction {
|
||||
/**
|
||||
* Computes the distance score between the given node ID and the root node of this
|
||||
* NeighborArray.
|
||||
*
|
||||
* @param nodeId The ID of the node for which to compute the distance score.
|
||||
* @return The distance score as a float value.
|
||||
* @throws IOException If an I/O error occurs during computation.
|
||||
*/
|
||||
float computeScore(int nodeId) throws IOException;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,11 +17,12 @@
|
|||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
|
||||
public class TestNeighborArray extends LuceneTestCase {
|
||||
|
||||
public void testScoresDescOrder() {
|
||||
public void testScoresDescOrder() throws IOException {
|
||||
NeighborArray neighbors = new NeighborArray(10, true);
|
||||
neighbors.addInOrder(0, 1);
|
||||
neighbors.addInOrder(1, 0.8f);
|
||||
|
@ -70,7 +71,7 @@ public class TestNeighborArray extends LuceneTestCase {
|
|||
assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
|
||||
}
|
||||
|
||||
public void testScoresAscOrder() {
|
||||
public void testScoresAscOrder() throws IOException {
|
||||
NeighborArray neighbors = new NeighborArray(10, false);
|
||||
neighbors.addInOrder(0, 0.1f);
|
||||
neighbors.addInOrder(1, 0.3f);
|
||||
|
@ -119,7 +120,7 @@ public class TestNeighborArray extends LuceneTestCase {
|
|||
assertNodesEqual(new int[] {8, 0, 6, 7}, neighbors);
|
||||
}
|
||||
|
||||
public void testSortAsc() {
|
||||
public void testSortAsc() throws IOException {
|
||||
NeighborArray neighbors = new NeighborArray(10, false);
|
||||
neighbors.addOutOfOrder(1, 2);
|
||||
// we disallow calling addInOrder after addOutOfOrder even if they're actually in order
|
||||
|
@ -130,7 +131,7 @@ public class TestNeighborArray extends LuceneTestCase {
|
|||
neighbors.addOutOfOrder(7, 8);
|
||||
neighbors.addOutOfOrder(6, 7);
|
||||
neighbors.addOutOfOrder(4, 5);
|
||||
int[] unchecked = neighbors.sort();
|
||||
int[] unchecked = neighbors.sort(null);
|
||||
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
|
||||
assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
|
||||
assertScoresEqual(new float[] {2, 3, 4, 5, 6, 7, 8}, neighbors);
|
||||
|
@ -143,13 +144,13 @@ public class TestNeighborArray extends LuceneTestCase {
|
|||
neighbors2.addOutOfOrder(6, 7);
|
||||
neighbors2.addOutOfOrder(5, 6);
|
||||
neighbors2.addOutOfOrder(3, 4);
|
||||
unchecked = neighbors2.sort();
|
||||
unchecked = neighbors2.sort(null);
|
||||
assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked);
|
||||
assertNodesEqual(new int[] {0, 1, 2, 3, 4, 5, 6}, neighbors2);
|
||||
assertScoresEqual(new float[] {1, 2, 3, 4, 5, 6, 7}, neighbors2);
|
||||
}
|
||||
|
||||
public void testSortDesc() {
|
||||
public void testSortDesc() throws IOException {
|
||||
NeighborArray neighbors = new NeighborArray(10, true);
|
||||
neighbors.addOutOfOrder(1, 7);
|
||||
// we disallow calling addInOrder after addOutOfOrder even if they're actually in order
|
||||
|
@ -160,7 +161,7 @@ public class TestNeighborArray extends LuceneTestCase {
|
|||
neighbors.addOutOfOrder(7, 1);
|
||||
neighbors.addOutOfOrder(6, 2);
|
||||
neighbors.addOutOfOrder(4, 4);
|
||||
int[] unchecked = neighbors.sort();
|
||||
int[] unchecked = neighbors.sort(null);
|
||||
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
|
||||
assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
|
||||
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
|
||||
|
@ -173,12 +174,44 @@ public class TestNeighborArray extends LuceneTestCase {
|
|||
neighbors2.addOutOfOrder(7, 1);
|
||||
neighbors2.addOutOfOrder(6, 2);
|
||||
neighbors2.addOutOfOrder(4, 4);
|
||||
unchecked = neighbors2.sort();
|
||||
unchecked = neighbors2.sort(null);
|
||||
assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked);
|
||||
assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors2);
|
||||
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors2);
|
||||
}
|
||||
|
||||
public void testAddwithScoringFunction() throws IOException {
|
||||
NeighborArray neighbors = new NeighborArray(10, true);
|
||||
neighbors.addOutOfOrder(1, Float.NaN);
|
||||
expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2));
|
||||
neighbors.addOutOfOrder(2, Float.NaN);
|
||||
neighbors.addOutOfOrder(5, Float.NaN);
|
||||
neighbors.addOutOfOrder(3, Float.NaN);
|
||||
neighbors.addOutOfOrder(7, Float.NaN);
|
||||
neighbors.addOutOfOrder(6, Float.NaN);
|
||||
neighbors.addOutOfOrder(4, Float.NaN);
|
||||
int[] unchecked = neighbors.sort(nodeId -> 7 - nodeId + 1);
|
||||
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
|
||||
assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
|
||||
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
|
||||
}
|
||||
|
||||
public void testAddwithScoringFunctionLargeOrd() throws IOException {
|
||||
NeighborArray neighbors = new NeighborArray(10, true);
|
||||
neighbors.addOutOfOrder(11, Float.NaN);
|
||||
expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2));
|
||||
neighbors.addOutOfOrder(12, Float.NaN);
|
||||
neighbors.addOutOfOrder(15, Float.NaN);
|
||||
neighbors.addOutOfOrder(13, Float.NaN);
|
||||
neighbors.addOutOfOrder(17, Float.NaN);
|
||||
neighbors.addOutOfOrder(16, Float.NaN);
|
||||
neighbors.addOutOfOrder(14, Float.NaN);
|
||||
int[] unchecked = neighbors.sort(nodeId -> 7 - nodeId + 11);
|
||||
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
|
||||
assertNodesEqual(new int[] {11, 12, 13, 14, 15, 16, 17}, neighbors);
|
||||
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
|
||||
}
|
||||
|
||||
private void assertScoresEqual(float[] scores, NeighborArray neighbors) {
|
||||
for (int i = 0; i < scores.length; i++) {
|
||||
assertEquals(scores[i], neighbors.score[i], 0.01f);
|
||||
|
|
Loading…
Reference in New Issue