From 0ea8035612344e861bbdeaffca2255e5e1480a6a Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Thu, 8 Sep 2022 16:54:29 -0400 Subject: [PATCH] LUCENE-10592 Better estimate memory for HNSW graph (#11743) Better estimate memory used for OnHeapHnswGraph, as well as add tests. Also don't overallocate arrays in NeighborArray Relates to #992 --- .../lucene94/Lucene94HnswVectorsWriter.java | 1 - .../lucene/util/hnsw/NeighborArray.java | 12 +++--- .../lucene/util/hnsw/OnHeapHnswGraph.java | 20 +++++++--- .../lucene/util/hnsw/TestHnswGraph.java | 39 +++++++++++++------ 4 files changed, 48 insertions(+), 24 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94HnswVectorsWriter.java index d57f3fc68a3..e54661552af 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -173,7 +173,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { case BYTE -> writeByteVectors(fieldData); case FLOAT32 -> writeFloat32Vectors(fieldData); } - ; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; // write graph diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index 78224ed2358..ec1b5ec3e89 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -46,8 +46,8 @@ public class NeighborArray { * nodes. */ public void add(int newNode, float newScore) { - if (size == node.length - 1) { - node = ArrayUtil.grow(node, (size + 1) * 3 / 2); + if (size == node.length) { + node = ArrayUtil.grow(node); score = ArrayUtil.growExact(score, node.length); } if (size > 0) { @@ -63,8 +63,8 @@ public class NeighborArray { /** Add a new node to the NeighborArray into a correct sort position according to its score. */ public void insertSorted(int newNode, float newScore) { - if (size == node.length - 1) { - node = ArrayUtil.grow(node, (size + 1) * 3 / 2); + if (size == node.length) { + node = ArrayUtil.grow(node); score = ArrayUtil.growExact(score, node.length); } int insertionPoint = @@ -104,8 +104,8 @@ public class NeighborArray { } public void removeIndex(int idx) { - System.arraycopy(node, idx + 1, node, idx, size - idx); - System.arraycopy(score, idx + 1, score, idx, size - idx); + System.arraycopy(node, idx + 1, node, idx, size - idx - 1); + System.arraycopy(score, idx + 1, score, idx, size - idx - 1); size--; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index 8cf6b54654c..78137c2a630 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -175,20 +175,28 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { long neighborArrayBytes0 = nsize0 * (Integer.BYTES + Float.BYTES) + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 - + RamUsageEstimator.NUM_BYTES_OBJECT_REF; + + RamUsageEstimator.NUM_BYTES_OBJECT_REF + + Integer.BYTES * 2; long neighborArrayBytes = nsize * (Integer.BYTES + Float.BYTES) + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 - + RamUsageEstimator.NUM_BYTES_OBJECT_REF; - + + RamUsageEstimator.NUM_BYTES_OBJECT_REF + + Integer.BYTES * 2; long total = 0; for (int l = 0; l < numLevels; l++) { int numNodesOnLevel = graph.get(l).size(); if (l == 0) { - total += numNodesOnLevel * neighborArrayBytes0; // for graph; + total += + numNodesOnLevel * neighborArrayBytes0 + + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph; } else { - total += numNodesOnLevel * Integer.BYTES; // for nodesByLevel - total += numNodesOnLevel * neighborArrayBytes; // for graph; + total += + nodesByLevel.get(l).length * Integer.BYTES + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for nodesByLevel + total += + numNodesOnLevel * neighborArrayBytes + + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph; } } return total; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java index 1816de726e4..7852e706157 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java @@ -17,9 +17,12 @@ package org.apache.lucene.util.hnsw; +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.tests.util.RamUsageTester.ramUsed; import static org.apache.lucene.util.VectorUtil.toBytesRef; +import com.carrotsearch.randomizedtesting.RandomizedTest; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -59,6 +62,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; import org.junit.Before; @@ -71,15 +75,8 @@ public class TestHnswGraph extends LuceneTestCase { @Before public void setup() { - similarityFunction = - VectorSimilarityFunction.values()[ - random().nextInt(VectorSimilarityFunction.values().length - 1) + 1]; - if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) { - vectorEncoding = - VectorEncoding.values()[random().nextInt(VectorEncoding.values().length - 1) + 1]; - } else { - vectorEncoding = VectorEncoding.FLOAT32; - } + similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values()); + vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values()); } // test writing out and reading in a graph gives the expected graph @@ -158,8 +155,7 @@ public class TestHnswGraph extends LuceneTestCase { int M = random().nextInt(10) + 5; int beamWidth = random().nextInt(10) + 5; VectorSimilarityFunction similarityFunction = - VectorSimilarityFunction.values()[ - random().nextInt(VectorSimilarityFunction.values().length - 1) + 1]; + RandomizedTest.randomFrom(VectorSimilarityFunction.values()); long seed = random().nextLong(); HnswGraphBuilder.randSeed = seed; IndexWriterConfig iwc = @@ -475,6 +471,27 @@ public class TestHnswGraph extends LuceneTestCase { 0)); } + public void testRamUsageEstimate() throws IOException { + int size = atLeast(2000); + int dim = randomIntBetween(100, 1024); + int M = randomIntBetween(4, 96); + + VectorSimilarityFunction similarityFunction = + RandomizedTest.randomFrom(VectorSimilarityFunction.values()); + VectorEncoding vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values()); + TestHnswGraph.RandomVectorValues vectors = + new TestHnswGraph.RandomVectorValues(size, dim, vectorEncoding, random()); + + HnswGraphBuilder builder = + HnswGraphBuilder.create( + vectors, vectorEncoding, similarityFunction, M, M * 2, random().nextLong()); + OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + long estimated = RamUsageEstimator.sizeOfObject(hnsw); + long actual = ramUsed(hnsw); + + assertEquals((double) actual, (double) estimated, (double) actual * 0.3); + } + @SuppressWarnings("unchecked") public void testDiversity() throws IOException { vectorEncoding = randomVectorEncoding();