mirror of https://github.com/apache/lucene.git
LUCENE-10592 Better estimate memory for HNSW graph (#11743)
Better estimate memory used for OnHeapHnswGraph, and add tests. Also don't overallocate arrays in NeighborArray. Relates to #992
This commit is contained in:
parent
49b596ef02
commit
0ea8035612
|
@ -173,7 +173,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
case BYTE -> writeByteVectors(fieldData);
|
||||
case FLOAT32 -> writeFloat32Vectors(fieldData);
|
||||
}
|
||||
;
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
|
||||
// write graph
|
||||
|
|
|
@ -46,8 +46,8 @@ public class NeighborArray {
|
|||
* nodes.
|
||||
*/
|
||||
public void add(int newNode, float newScore) {
|
||||
if (size == node.length - 1) {
|
||||
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
|
||||
if (size == node.length) {
|
||||
node = ArrayUtil.grow(node);
|
||||
score = ArrayUtil.growExact(score, node.length);
|
||||
}
|
||||
if (size > 0) {
|
||||
|
@ -63,8 +63,8 @@ public class NeighborArray {
|
|||
|
||||
/** Add a new node to the NeighborArray into a correct sort position according to its score. */
|
||||
public void insertSorted(int newNode, float newScore) {
|
||||
if (size == node.length - 1) {
|
||||
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
|
||||
if (size == node.length) {
|
||||
node = ArrayUtil.grow(node);
|
||||
score = ArrayUtil.growExact(score, node.length);
|
||||
}
|
||||
int insertionPoint =
|
||||
|
@ -104,8 +104,8 @@ public class NeighborArray {
|
|||
}
|
||||
|
||||
public void removeIndex(int idx) {
|
||||
System.arraycopy(node, idx + 1, node, idx, size - idx);
|
||||
System.arraycopy(score, idx + 1, score, idx, size - idx);
|
||||
System.arraycopy(node, idx + 1, node, idx, size - idx - 1);
|
||||
System.arraycopy(score, idx + 1, score, idx, size - idx - 1);
|
||||
size--;
|
||||
}
|
||||
|
||||
|
|
|
@ -175,20 +175,28 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
long neighborArrayBytes0 =
|
||||
nsize0 * (Integer.BYTES + Float.BYTES)
|
||||
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF;
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF
|
||||
+ Integer.BYTES * 2;
|
||||
long neighborArrayBytes =
|
||||
nsize * (Integer.BYTES + Float.BYTES)
|
||||
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF;
|
||||
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF
|
||||
+ Integer.BYTES * 2;
|
||||
long total = 0;
|
||||
for (int l = 0; l < numLevels; l++) {
|
||||
int numNodesOnLevel = graph.get(l).size();
|
||||
if (l == 0) {
|
||||
total += numNodesOnLevel * neighborArrayBytes0; // for graph;
|
||||
total +=
|
||||
numNodesOnLevel * neighborArrayBytes0
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
|
||||
} else {
|
||||
total += numNodesOnLevel * Integer.BYTES; // for nodesByLevel
|
||||
total += numNodesOnLevel * neighborArrayBytes; // for graph;
|
||||
total +=
|
||||
nodesByLevel.get(l).length * Integer.BYTES
|
||||
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for nodesByLevel
|
||||
total +=
|
||||
numNodesOnLevel * neighborArrayBytes
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
|
||||
}
|
||||
}
|
||||
return total;
|
||||
|
|
|
@ -17,9 +17,12 @@
|
|||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.apache.lucene.tests.util.RamUsageTester.ramUsed;
|
||||
import static org.apache.lucene.util.VectorUtil.toBytesRef;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.RandomizedTest;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -59,6 +62,7 @@ import org.apache.lucene.util.BitSet;
|
|||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.junit.Before;
|
||||
|
@ -71,15 +75,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
|
||||
@Before
|
||||
public void setup() {
|
||||
similarityFunction =
|
||||
VectorSimilarityFunction.values()[
|
||||
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
|
||||
if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
|
||||
vectorEncoding =
|
||||
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length - 1) + 1];
|
||||
} else {
|
||||
vectorEncoding = VectorEncoding.FLOAT32;
|
||||
}
|
||||
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||
vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
|
||||
}
|
||||
|
||||
// test writing out and reading in a graph gives the expected graph
|
||||
|
@ -158,8 +155,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
int M = random().nextInt(10) + 5;
|
||||
int beamWidth = random().nextInt(10) + 5;
|
||||
VectorSimilarityFunction similarityFunction =
|
||||
VectorSimilarityFunction.values()[
|
||||
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
|
||||
RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||
long seed = random().nextLong();
|
||||
HnswGraphBuilder.randSeed = seed;
|
||||
IndexWriterConfig iwc =
|
||||
|
@ -475,6 +471,27 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
0));
|
||||
}
|
||||
|
||||
public void testRamUsageEstimate() throws IOException {
|
||||
int size = atLeast(2000);
|
||||
int dim = randomIntBetween(100, 1024);
|
||||
int M = randomIntBetween(4, 96);
|
||||
|
||||
VectorSimilarityFunction similarityFunction =
|
||||
RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||
VectorEncoding vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
|
||||
TestHnswGraph.RandomVectorValues vectors =
|
||||
new TestHnswGraph.RandomVectorValues(size, dim, vectorEncoding, random());
|
||||
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, M, M * 2, random().nextLong());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
long estimated = RamUsageEstimator.sizeOfObject(hnsw);
|
||||
long actual = ramUsed(hnsw);
|
||||
|
||||
assertEquals((double) actual, (double) estimated, (double) actual * 0.3);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testDiversity() throws IOException {
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
|
|
Loading…
Reference in New Issue