LUCENE-10592 Better estimate memory for HNSW graph (#11743)

Better estimate memory used for OnHeapHnswGraph,
as well as add tests.

Also don't overallocate arrays in NeighborArray

Relates to #992
Mayya Sharipova 2022-09-08 16:54:29 -04:00 committed by GitHub
parent 49b596ef02
commit 0ea8035612
4 changed files with 48 additions and 24 deletions

Lucene94HnswVectorsWriter.java

@@ -173,7 +173,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
       case BYTE -> writeByteVectors(fieldData);
       case FLOAT32 -> writeFloat32Vectors(fieldData);
     }
-    ;
     long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
     // write graph

NeighborArray.java

@@ -46,8 +46,8 @@ public class NeighborArray {
    * nodes.
    */
   public void add(int newNode, float newScore) {
-    if (size == node.length - 1) {
-      node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
+    if (size == node.length) {
+      node = ArrayUtil.grow(node);
       score = ArrayUtil.growExact(score, node.length);
     }
     if (size > 0) {
@@ -63,8 +63,8 @@ public class NeighborArray {
   /** Add a new node to the NeighborArray into a correct sort position according to its score. */
   public void insertSorted(int newNode, float newScore) {
-    if (size == node.length - 1) {
-      node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
+    if (size == node.length) {
+      node = ArrayUtil.grow(node);
       score = ArrayUtil.growExact(score, node.length);
     }
     int insertionPoint =
@@ -104,8 +104,8 @@ public class NeighborArray {
   }
   public void removeIndex(int idx) {
-    System.arraycopy(node, idx + 1, node, idx, size - idx);
-    System.arraycopy(score, idx + 1, score, idx, size - idx);
+    System.arraycopy(node, idx + 1, node, idx, size - idx - 1);
+    System.arraycopy(score, idx + 1, score, idx, size - idx - 1);
     size--;
   }
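
Note on the NeighborArray change: this is the "don't overallocate" half of the commit. add() and insertSorted() used to reallocate one slot early (size == node.length - 1) and request (size + 1) * 3 / 2 elements, so an array that was already pre-sized for the maximum neighbor count still got copied and enlarged on its last insert; now they reallocate only when the array is genuinely full and let ArrayUtil.grow pick the next size. The removeIndex() change is a separate off-by-one fix: copying size - idx elements dragged in one stale slot from beyond the live range, while size - idx - 1 copies exactly the elements after idx. Below is a minimal, self-contained sketch of the growth difference; the helper names and the oversize factor are illustrative stand-ins, not Lucene's ArrayUtil.

import java.util.Arrays;

/** Illustrative sketch only, not Lucene code: contrasts the old and new growth checks. */
public class NeighborArrayGrowthDemo {

  /** Stand-in for ArrayUtil.grow(arr, (size + 1) * 3 / 2) used by the old code. */
  private static int[] grow1_5x(int[] arr, int size) {
    return Arrays.copyOf(arr, Math.max(arr.length + 1, (size + 1) * 3 / 2));
  }

  /** Stand-in for ArrayUtil.grow(arr): a modest, implementation-chosen oversize. */
  private static int[] growOversized(int[] arr) {
    return Arrays.copyOf(arr, arr.length + Math.max(1, arr.length >> 3));
  }

  public static void main(String[] args) {
    int capacity = 17; // NeighborArrays are pre-sized for the maximum neighbor count

    // Old check: grow one slot early, asking for roughly 1.5x the current size.
    int[] oldNodes = new int[capacity];
    for (int size = 0; size < capacity; size++) {
      if (size == oldNodes.length - 1) {
        oldNodes = grow1_5x(oldNodes, size); // fires even though one slot is still free
      }
      oldNodes[size] = size;
    }

    // New check: grow only when the array is actually full.
    int[] newNodes = new int[capacity];
    for (int size = 0; size < capacity; size++) {
      if (size == newNodes.length) {
        newNodes = growOversized(newNodes); // never fires for a correctly pre-sized array
      }
      newNodes[size] = size;
    }

    System.out.println("old check, final capacity: " + oldNodes.length); // 25
    System.out.println("new check, final capacity: " + newNodes.length); // 17
  }
}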

OnHeapHnswGraph.java

@@ -175,20 +175,28 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
     long neighborArrayBytes0 =
         nsize0 * (Integer.BYTES + Float.BYTES)
             + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
-            + RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+            + RamUsageEstimator.NUM_BYTES_OBJECT_REF
+            + Integer.BYTES * 2;
     long neighborArrayBytes =
         nsize * (Integer.BYTES + Float.BYTES)
             + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
-            + RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+            + RamUsageEstimator.NUM_BYTES_OBJECT_REF
+            + Integer.BYTES * 2;
     long total = 0;
     for (int l = 0; l < numLevels; l++) {
       int numNodesOnLevel = graph.get(l).size();
       if (l == 0) {
-        total += numNodesOnLevel * neighborArrayBytes0; // for graph;
+        total +=
+            numNodesOnLevel * neighborArrayBytes0
+                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
       } else {
-        total += numNodesOnLevel * Integer.BYTES; // for nodesByLevel
-        total += numNodesOnLevel * neighborArrayBytes; // for graph;
+        total +=
+            nodesByLevel.get(l).length * Integer.BYTES
+                + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for nodesByLevel
+        total +=
+            numNodesOnLevel * neighborArrayBytes
+                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
       }
     }
     return total;
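
To see what the new per-node accounting adds up to, the sketch below evaluates the level-0 term with the same constants the estimate uses. It is an illustrative example, not part of the commit: nsize0 = 32 is an arbitrary example capacity (the real value comes from the graph's max-connections setting), it needs lucene-core on the classpath, and the byte counts in the comment assume a 64-bit JVM with compressed oops.

import org.apache.lucene.util.RamUsageEstimator;

public class HnswRamEstimateExample {
  public static void main(String[] args) {
    int nsize0 = 32; // example level-0 neighbor capacity, not the value Lucene derives
    long neighborArrayBytes0 =
        nsize0 * (Integer.BYTES + Float.BYTES) // the int[] node and float[] score payloads
            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 // their two array headers
            + RamUsageEstimator.NUM_BYTES_OBJECT_REF // one object reference
            + Integer.BYTES * 2; // two int-sized fields, the new term in the estimate above
    // Per the loop above, each level-0 node also costs one more reference in the graph list.
    long perLevel0Node = neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_OBJECT_REF;
    // With compressed oops this prints roughly 300 and 304 bytes, before counting the vectors.
    System.out.println("neighborArrayBytes0 = " + neighborArrayBytes0);
    System.out.println("per level-0 node    = " + perLevel0Node);
  }
}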

TestHnswGraph.java

@@ -17,9 +17,12 @@
 package org.apache.lucene.util.hnsw;
+import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.tests.util.RamUsageTester.ramUsed;
 import static org.apache.lucene.util.VectorUtil.toBytesRef;
+import com.carrotsearch.randomizedtesting.RandomizedTest;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -59,6 +62,7 @@ import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.FixedBitSet;
+import org.apache.lucene.util.RamUsageEstimator;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
 import org.junit.Before;
@@ -71,15 +75,8 @@ public class TestHnswGraph extends LuceneTestCase {
   @Before
   public void setup() {
-    similarityFunction =
-        VectorSimilarityFunction.values()[
-            random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
-    if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
-      vectorEncoding =
-          VectorEncoding.values()[random().nextInt(VectorEncoding.values().length - 1) + 1];
-    } else {
-      vectorEncoding = VectorEncoding.FLOAT32;
-    }
+    similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
+    vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
   }
   // test writing out and reading in a graph gives the expected graph
@@ -158,8 +155,7 @@
     int M = random().nextInt(10) + 5;
     int beamWidth = random().nextInt(10) + 5;
     VectorSimilarityFunction similarityFunction =
-        VectorSimilarityFunction.values()[
-            random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
+        RandomizedTest.randomFrom(VectorSimilarityFunction.values());
     long seed = random().nextLong();
     HnswGraphBuilder.randSeed = seed;
     IndexWriterConfig iwc =
@@ -475,6 +471,27 @@
             0));
   }
+  public void testRamUsageEstimate() throws IOException {
+    int size = atLeast(2000);
+    int dim = randomIntBetween(100, 1024);
+    int M = randomIntBetween(4, 96);
+    VectorSimilarityFunction similarityFunction =
+        RandomizedTest.randomFrom(VectorSimilarityFunction.values());
+    VectorEncoding vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
+    TestHnswGraph.RandomVectorValues vectors =
+        new TestHnswGraph.RandomVectorValues(size, dim, vectorEncoding, random());
+    HnswGraphBuilder<?> builder =
+        HnswGraphBuilder.create(
+            vectors, vectorEncoding, similarityFunction, M, M * 2, random().nextLong());
+    OnHeapHnswGraph hnsw = builder.build(vectors.copy());
+    long estimated = RamUsageEstimator.sizeOfObject(hnsw);
+    long actual = ramUsed(hnsw);
+    assertEquals((double) actual, (double) estimated, (double) actual * 0.3);
+  }
   @SuppressWarnings("unchecked")
   public void testDiversity() throws IOException {
     vectorEncoding = randomVectorEncoding();
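
The new testRamUsageEstimate compares two different measurements: RamUsageEstimator.sizeOfObject() trusts an Accountable's own ramBytesUsed() (the estimate improved above), while RamUsageTester.ramUsed() walks the object graph reflectively, which is why the test only asserts agreement within 30% of the measured value. Below is a small illustrative sketch of that same comparison on a toy Accountable; the class is hypothetical, not from the commit, and it needs lucene-core plus the Lucene test framework on the classpath.

import org.apache.lucene.tests.util.RamUsageTester;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;

public class EstimateVsMeasuredDemo {

  /** Toy Accountable whose estimate deliberately counts only the array payload and header. */
  static final class IntBuffer implements Accountable {
    final int[] values = new int[1024];

    @Override
    public long ramBytesUsed() {
      return (long) values.length * Integer.BYTES + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
    }
  }

  public static void main(String[] args) {
    IntBuffer buffer = new IntBuffer();
    long estimated = RamUsageEstimator.sizeOfObject(buffer); // delegates to ramBytesUsed()
    long measured = RamUsageTester.ramUsed(buffer); // reflective, also sees the wrapper object
    System.out.println("estimated=" + estimated + " measured=" + measured);
  }
}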