1
0
mirror of https://github.com/apache/lucene.git synced 2025-02-21 01:18:45 +00:00

LUCENE-10592 Better estimate memory for HNSW graph ()

Better estimate memory used for OnHeapHnswGraph,
as well as add tests.

Also don't overallocate arrays in NeighborArray

Relates to 
This commit is contained in:
Mayya Sharipova 2022-09-08 16:54:29 -04:00 committed by GitHub
parent 49b596ef02
commit 0ea8035612
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 24 deletions
lucene/core/src
java/org/apache/lucene
test/org/apache/lucene/util/hnsw

@ -173,7 +173,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
case BYTE -> writeByteVectors(fieldData);
case FLOAT32 -> writeFloat32Vectors(fieldData);
}
;
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph

@ -46,8 +46,8 @@ public class NeighborArray {
* nodes.
*/
public void add(int newNode, float newScore) {
if (size == node.length - 1) {
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
if (size == node.length) {
node = ArrayUtil.grow(node);
score = ArrayUtil.growExact(score, node.length);
}
if (size > 0) {
@ -63,8 +63,8 @@ public class NeighborArray {
/** Add a new node to the NeighborArray into a correct sort position according to its score. */
public void insertSorted(int newNode, float newScore) {
if (size == node.length - 1) {
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
if (size == node.length) {
node = ArrayUtil.grow(node);
score = ArrayUtil.growExact(score, node.length);
}
int insertionPoint =
@ -104,8 +104,8 @@ public class NeighborArray {
}
public void removeIndex(int idx) {
System.arraycopy(node, idx + 1, node, idx, size - idx);
System.arraycopy(score, idx + 1, score, idx, size - idx);
System.arraycopy(node, idx + 1, node, idx, size - idx - 1);
System.arraycopy(score, idx + 1, score, idx, size - idx - 1);
size--;
}

@ -175,20 +175,28 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
long neighborArrayBytes0 =
nsize0 * (Integer.BYTES + Float.BYTES)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF
+ Integer.BYTES * 2;
long neighborArrayBytes =
nsize * (Integer.BYTES + Float.BYTES)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF
+ Integer.BYTES * 2;
long total = 0;
for (int l = 0; l < numLevels; l++) {
int numNodesOnLevel = graph.get(l).size();
if (l == 0) {
total += numNodesOnLevel * neighborArrayBytes0; // for graph;
total +=
numNodesOnLevel * neighborArrayBytes0
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
} else {
total += numNodesOnLevel * Integer.BYTES; // for nodesByLevel
total += numNodesOnLevel * neighborArrayBytes; // for graph;
total +=
nodesByLevel.get(l).length * Integer.BYTES
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for nodesByLevel
total +=
numNodesOnLevel * neighborArrayBytes
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
}
}
return total;

@ -17,9 +17,12 @@
package org.apache.lucene.util.hnsw;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.tests.util.RamUsageTester.ramUsed;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import com.carrotsearch.randomizedtesting.RandomizedTest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@ -59,6 +62,7 @@ import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.junit.Before;
@ -71,15 +75,8 @@ public class TestHnswGraph extends LuceneTestCase {
@Before
public void setup() {
similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
vectorEncoding =
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length - 1) + 1];
} else {
vectorEncoding = VectorEncoding.FLOAT32;
}
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
}
// test writing out and reading in a graph gives the expected graph
@ -158,8 +155,7 @@ public class TestHnswGraph extends LuceneTestCase {
int M = random().nextInt(10) + 5;
int beamWidth = random().nextInt(10) + 5;
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
RandomizedTest.randomFrom(VectorSimilarityFunction.values());
long seed = random().nextLong();
HnswGraphBuilder.randSeed = seed;
IndexWriterConfig iwc =
@ -475,6 +471,27 @@ public class TestHnswGraph extends LuceneTestCase {
0));
}
public void testRamUsageEstimate() throws IOException {
int size = atLeast(2000);
int dim = randomIntBetween(100, 1024);
int M = randomIntBetween(4, 96);
VectorSimilarityFunction similarityFunction =
RandomizedTest.randomFrom(VectorSimilarityFunction.values());
VectorEncoding vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
TestHnswGraph.RandomVectorValues vectors =
new TestHnswGraph.RandomVectorValues(size, dim, vectorEncoding, random());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, M, M * 2, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
long estimated = RamUsageEstimator.sizeOfObject(hnsw);
long actual = ramUsed(hnsw);
assertEquals((double) actual, (double) estimated, (double) actual * 0.3);
}
@SuppressWarnings("unchecked")
public void testDiversity() throws IOException {
vectorEncoding = randomVectorEncoding();