From eb5bdd7d155bd9fd29d388521e148ef9c6d67db2 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Mon, 7 Feb 2022 13:21:15 -0800 Subject: [PATCH] Rename KnnGraphValues -> HnswGraph (#645) This PR proposes some renames to clarify the code structure. The top-level `KnnGraphValues` is renamed to `HnswGraph`, since it now represents a hierarchical graph. It's also moved from `org.apache.lucene.index` to the `org.apache.lucene.util.hnsw` package. Other renames: * The old `HnswGraph` -> `OnHeapHnswGraph` * `IndexedKnnGraphReader` -> `OffHeapHnswGraph` (to match `OffHeapVectorValues`) --- .../lucene90/Lucene90HnswGraphBuilder.java | 10 +- .../lucene90/Lucene90HnswVectorsReader.java | 16 +- ...raph.java => Lucene90OnHeapHnswGraph.java} | 34 +-- .../lucene90/Lucene90HnswVectorsWriter.java | 2 +- .../lucene91/Lucene91HnswVectorsReader.java | 18 +- .../lucene91/Lucene91HnswVectorsWriter.java | 12 +- .../apache/lucene/index/KnnGraphValues.java | 151 ------------ .../apache/lucene/util/hnsw/HnswGraph.java | 227 ++++++++---------- .../lucene/util/hnsw/HnswGraphBuilder.java | 6 +- .../lucene/util/hnsw/HnswGraphSearcher.java | 25 +- .../lucene/util/hnsw/OnHeapHnswGraph.java | 169 +++++++++++++ .../org/apache/lucene/index/TestKnnGraph.java | 9 +- .../lucene/util/hnsw/KnnGraphTester.java | 8 +- .../lucene/util/hnsw/TestHnswGraph.java | 21 +- 14 files changed, 341 insertions(+), 367 deletions(-) rename lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/{Lucene90HnswGraph.java => Lucene90OnHeapHnswGraph.java} (80%) delete mode 100644 lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java index 7fda65e23c9..97fdd1c3999 100644 --- 
a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java @@ -30,7 +30,7 @@ import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.NeighborQueue; /** - * Builder for HNSW graph. See {@link Lucene90HnswGraph} for a gloss on the algorithm and the + * Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the * meaning of the hyperparameters. * *

This class is preserved here only for tests. @@ -53,7 +53,7 @@ public final class Lucene90HnswGraphBuilder { private final RandomAccessVectorValues vectorValues; private final SplittableRandom random; private final BoundsChecker bound; - final Lucene90HnswGraph hnsw; + final Lucene90OnHeapHnswGraph hnsw; private InfoStream infoStream = InfoStream.getDefault(); @@ -90,7 +90,7 @@ public final class Lucene90HnswGraphBuilder { } this.maxConn = maxConn; this.beamWidth = beamWidth; - this.hnsw = new Lucene90HnswGraph(maxConn); + this.hnsw = new Lucene90OnHeapHnswGraph(maxConn); bound = BoundsChecker.create(similarityFunction.reversed); random = new SplittableRandom(seed); scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1)); @@ -104,7 +104,7 @@ public final class Lucene90HnswGraphBuilder { * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet * accessor for the vectors */ - public Lucene90HnswGraph build(RandomAccessVectorValues vectors) throws IOException { + public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException { if (vectors == vectorValues) { throw new IllegalArgumentException( "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); @@ -143,7 +143,7 @@ public final class Lucene90HnswGraphBuilder { void addGraphNode(float[] value) throws IOException { // We pass 'null' for acceptOrds because there are no deletions while building the graph NeighborQueue candidates = - Lucene90HnswGraph.search( + Lucene90OnHeapHnswGraph.search( value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random); int node = hnsw.addNode(); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 7669a8d38ff..85ecf5106dc 100644 --- 
a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -31,7 +31,6 @@ import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.IndexFileNames; -import org.apache.lucene.index.KnnGraphValues; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.SegmentReadState; @@ -47,6 +46,7 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.NeighborQueue; /** @@ -243,7 +243,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { // use a seed that is fixed for the index so we get reproducible results for the same query final SplittableRandom random = new SplittableRandom(checksumSeed); NeighborQueue results = - Lucene90HnswGraph.search( + Lucene90OnHeapHnswGraph.search( target, k, k, @@ -291,7 +291,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { } /** Get knn graph values; used for testing */ - public KnnGraphValues getGraphValues(String field) throws IOException { + public HnswGraph getGraphValues(String field) throws IOException { FieldInfo info = fieldInfos.fieldInfo(field); if (info == null) { throw new IllegalArgumentException("No such field '" + field + "'"); @@ -300,14 +300,14 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { if (entry != null && entry.indexDataLength > 0) { return getGraphValues(entry); } else { - return KnnGraphValues.EMPTY; + return HnswGraph.EMPTY; } } - private KnnGraphValues getGraphValues(FieldEntry entry) 
throws IOException { + private HnswGraph getGraphValues(FieldEntry entry) throws IOException { IndexInput bytesSlice = vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength); - return new IndexedKnnGraphReader(entry, bytesSlice); + return new OffHeapHnswGraph(entry, bytesSlice); } @Override @@ -465,7 +465,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { } /** Read the nearest-neighbors graph from the index input */ - private static final class IndexedKnnGraphReader extends KnnGraphValues { + private static final class OffHeapHnswGraph extends HnswGraph { final FieldEntry entry; final IndexInput dataIn; @@ -474,7 +474,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { int arcUpTo; int arc; - IndexedKnnGraphReader(FieldEntry entry, IndexInput dataIn) { + OffHeapHnswGraph(FieldEntry entry, IndexInput dataIn) { this.entry = entry; this.dataIn = dataIn; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java similarity index 80% rename from lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraph.java rename to lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java index d8d28eca16b..340bcf24199 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraph.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java @@ -23,42 +23,20 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.SplittableRandom; -import org.apache.lucene.index.KnnGraphValues; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.Bits; import 
org.apache.lucene.util.SparseFixedBitSet; import org.apache.lucene.util.hnsw.BoundsChecker; +import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.NeighborQueue; /** - * Navigable Small-world graph. Provides efficient approximate nearest neighbor search for high - * dimensional vectors. See Approximate nearest - * neighbor algorithm based on navigable small world graphs [2014] and this paper [2018] for details. - * - *

The nomenclature is a bit different here from what's used in those papers: - * - *

Hyperparameters

- * - * - * - *

Note: The graph may be searched by multiple threads concurrently, but updates are not - * thread-safe. Also note: there is no notion of deletions. Document searching built on top of this - * must do its own deletion-filtering. - * - *

Graph building logic is preserved here only for tests. + * An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to + * construct the HNSW graph before it's written to the index. */ -public final class Lucene90HnswGraph extends KnnGraphValues { +public final class Lucene90OnHeapHnswGraph extends HnswGraph { private final int maxConn; @@ -71,7 +49,7 @@ public final class Lucene90HnswGraph extends KnnGraphValues { private int upto; private NeighborArray cur; - Lucene90HnswGraph(int maxConn) { + Lucene90OnHeapHnswGraph(int maxConn) { graph = new ArrayList<>(); // Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be // about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM @@ -100,7 +78,7 @@ public final class Lucene90HnswGraph extends KnnGraphValues { int numSeed, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, - KnnGraphValues graphValues, + HnswGraph graphValues, Bits acceptOrds, SplittableRandom random) throws IOException { diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java index d76e5efb635..a71f5efb14f 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java @@ -241,7 +241,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter { beamWidth, Lucene90HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); - Lucene90HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess()); + Lucene90OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess()); for (int ord = 0; ord < offsets.length; ord++) { // write 
graph diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene91/Lucene91HnswVectorsReader.java index f196160bc8a..85247df49f0 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene91/Lucene91HnswVectorsReader.java @@ -30,7 +30,6 @@ import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.IndexFileNames; -import org.apache.lucene.index.KnnGraphValues; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.SegmentReadState; @@ -46,6 +45,7 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; import org.apache.lucene.util.hnsw.NeighborQueue; @@ -235,7 +235,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { k, vectorValues, fieldEntry.similarityFunction, - getGraphValues(fieldEntry), + getGraph(fieldEntry), getAcceptOrds(acceptDocs, fieldEntry)); int i = 0; @@ -277,23 +277,23 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { } /** Get knn graph values; used for testing */ - public KnnGraphValues getGraphValues(String field) throws IOException { + public HnswGraph getGraph(String field) throws IOException { FieldInfo info = fieldInfos.fieldInfo(field); if (info == null) { throw new IllegalArgumentException("No such field '" + field + "'"); } FieldEntry entry = fields.get(field); if (entry != null && entry.vectorIndexLength > 0) { - return getGraphValues(entry); + return getGraph(entry); } else { - return 
KnnGraphValues.EMPTY; + return HnswGraph.EMPTY; } } - private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException { + private HnswGraph getGraph(FieldEntry entry) throws IOException { IndexInput bytesSlice = vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength); - return new IndexedKnnGraphReader(entry, bytesSlice); + return new OffHeapHnswGraph(entry, bytesSlice); } @Override @@ -478,7 +478,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { } /** Read the nearest-neighbors graph from the index input */ - private static final class IndexedKnnGraphReader extends KnnGraphValues { + private static final class OffHeapHnswGraph extends HnswGraph { final IndexInput dataIn; final int[][] nodesByLevel; @@ -492,7 +492,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { int arcUpTo; int arc; - IndexedKnnGraphReader(FieldEntry entry, IndexInput dataIn) { + OffHeapHnswGraph(FieldEntry entry, IndexInput dataIn) { this.dataIn = dataIn; this.nodesByLevel = entry.nodesByLevel; this.numLevels = entry.numLevels; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene91/Lucene91HnswVectorsWriter.java index f56f1d72e25..2b3eeb448a3 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene91/Lucene91HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -26,7 +26,6 @@ import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexFileNames; -import org.apache.lucene.index.KnnGraphValues.NodesIterator; import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; @@ -36,9 +35,10 @@ import 
org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.NeighborArray; +import org.apache.lucene.util.hnsw.OnHeapHnswGraph; /** * Writes vector values and knn graphs to index segments. @@ -141,7 +141,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter { Lucene91HnswVectorsReader.OffHeapVectorValues offHeapVectors = new Lucene91HnswVectorsReader.OffHeapVectorValues( vectors.dimension(), docIds, vectorDataInput); - HnswGraph graph = + OnHeapHnswGraph graph = offHeapVectors.size() == 0 ? null : writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction()); @@ -197,7 +197,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter { long vectorIndexOffset, long vectorIndexLength, int[] docIds, - HnswGraph graph) + OnHeapHnswGraph graph) throws IOException { meta.writeInt(field.number); meta.writeInt(field.getVectorSimilarityFunction().ordinal()); @@ -232,7 +232,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter { } } - private HnswGraph writeGraph( + private OnHeapHnswGraph writeGraph( RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction) throws IOException { @@ -241,7 +241,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter { new HnswGraphBuilder( vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); - HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess()); + OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess()); // write vectors' neighbours on each level into the vectorIndex file int countOnLevel0 = graph.size(); diff --git 
a/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java deleted file mode 100644 index 07feb2a8136..00000000000 --- a/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.index; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -import java.io.IOException; -import java.util.NoSuchElementException; -import java.util.PrimitiveIterator; - -/** - * Access to per-document neighbor lists in a (hierarchical) knn search graph. - * - * @lucene.experimental - */ -public abstract class KnnGraphValues { - - /** Sole constructor */ - protected KnnGraphValues() {} - - /** - * Move the pointer to exactly the given {@code level}'s {@code target}. After this method - * returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals. - * - * @param level level of the graph - * @param target ordinal of a node in the graph, must be ≥ 0 and < {@link - * VectorValues#size()}. 
- */ - public abstract void seek(int level, int target) throws IOException; - - /** Returns the number of nodes in the graph */ - public abstract int size(); - - /** - * Iterates over the neighbor list. It is illegal to call this method after it returns - * NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator. - * - * @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete. - */ - public abstract int nextNeighbor() throws IOException; - - /** Returns the number of levels of the graph */ - public abstract int numLevels() throws IOException; - - /** Returns graph's entry point on the top level * */ - public abstract int entryNode() throws IOException; - - /** - * Get all nodes on a given level as node 0th ordinals - * - * @param level level for which to get all nodes - * @return an iterator over nodes where {@code nextInt} returns a next node on the level - */ - public abstract NodesIterator getNodesOnLevel(int level) throws IOException; - - /** Empty graph value */ - public static KnnGraphValues EMPTY = - new KnnGraphValues() { - - @Override - public int nextNeighbor() { - return NO_MORE_DOCS; - } - - @Override - public void seek(int level, int target) {} - - @Override - public int size() { - return 0; - } - - @Override - public int numLevels() { - return 0; - } - - @Override - public int entryNode() { - return 0; - } - - @Override - public NodesIterator getNodesOnLevel(int level) { - return NodesIterator.EMPTY; - } - }; - - /** - * Iterator over the graph nodes on a certain level, Iterator also provides the size – the total - * number of nodes to be iterated over. 
- */ - public static final class NodesIterator implements PrimitiveIterator.OfInt { - static NodesIterator EMPTY = new NodesIterator(0); - - private final int[] nodes; - private final int size; - int cur = 0; - - /** Constructor for iterator based on the nodes array up to the size */ - public NodesIterator(int[] nodes, int size) { - assert nodes != null; - assert size <= nodes.length; - this.nodes = nodes; - this.size = size; - } - - /** Constructor for iterator based on the size */ - public NodesIterator(int size) { - this.nodes = null; - this.size = size; - } - - @Override - public int nextInt() { - if (hasNext() == false) { - throw new NoSuchElementException(); - } - if (nodes == null) { - return cur++; - } else { - return nodes[cur++]; - } - } - - @Override - public boolean hasNext() { - return cur < size; - } - - /** The number of elements in this iterator * */ - public int size() { - return size; - } - } -} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java index 58693956c19..2a222d23e07 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java @@ -19,11 +19,10 @@ package org.apache.lucene.util.hnsw; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import org.apache.lucene.index.KnnGraphValues; -import org.apache.lucene.util.ArrayUtil; +import java.io.IOException; +import java.util.NoSuchElementException; +import java.util.PrimitiveIterator; +import org.apache.lucene.index.VectorValues; /** * Hierarchical Navigable Small World graph. Provides efficient approximate nearest neighbor search @@ -47,142 +46,124 @@ import org.apache.lucene.util.ArrayUtil; * thread-safe. The search method optionally takes a set of "accepted nodes", which can be used to * exclude deleted documents. 
*/ -public final class HnswGraph extends KnnGraphValues { +public abstract class HnswGraph { - private final int maxConn; - private int numLevels; // the current number of levels in the graph - private int entryNode; // the current graph entry node on the top level - - // Nodes by level expressed as the level 0's nodes' ordinals. - // As level 0 contains all nodes, nodesByLevel.get(0) is null. - private final List nodesByLevel; - - // graph is a list of graph levels. - // Each level is represented as List – nodes' connections on this level. - // Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors - // added to HnswBuilder, and the node values are the ordinals of those vectors. - // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals. - private final List> graph; - - // KnnGraphValues iterator members - private int upto; - private NeighborArray cur; - - HnswGraph(int maxConn, int levelOfFirstNode) { - this.maxConn = maxConn; - this.numLevels = levelOfFirstNode + 1; - this.graph = new ArrayList<>(numLevels); - this.entryNode = 0; - for (int i = 0; i < numLevels; i++) { - graph.add(new ArrayList<>()); - // Typically with diversity criteria we see nodes not fully occupied; - // average fanout seems to be about 1/2 maxConn. - // There is some indexing time penalty for under-allocating, but saves RAM - graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4))); - } - - this.nodesByLevel = new ArrayList<>(numLevels); - nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes - for (int l = 1; l < numLevels; l++) { - nodesByLevel.add(new int[] {0}); - } - } + /** Sole constructor */ + protected HnswGraph() {} /** - * Returns the {@link NeighborQueue} connected to the given node. + * Move the pointer to exactly the given {@code level}'s {@code target}. After this method + * returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals. 
* * @param level level of the graph - * @param node the node whose neighbors are returned, represented as an ordinal on the level 0. + * @param target ordinal of a node in the graph, must be ≥ 0 and < {@link + * VectorValues#size()}. */ - public NeighborArray getNeighbors(int level, int node) { - if (level == 0) { - return graph.get(level).get(node); - } - int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node); - assert nodeIndex >= 0; - return graph.get(level).get(nodeIndex); - } + public abstract void seek(int level, int target) throws IOException; - @Override - public int size() { - return graph.get(0).size(); // all nodes are located on the 0th level - } + /** Returns the number of nodes in the graph */ + public abstract int size(); /** - * Add node on the given level + * Iterates over the neighbor list. It is illegal to call this method after it returns + * NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator. * - * @param level level to add a node on - * @param node the node to add, represented as an ordinal on the level 0. + * @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete. 
*/ - public void addNode(int level, int node) { - if (level > 0) { - // if the new node introduces a new level, add more levels to the graph, - // and make this node the graph's new entry point - if (level >= numLevels) { - for (int i = numLevels; i <= level; i++) { - graph.add(new ArrayList<>()); - nodesByLevel.add(new int[] {node}); + public abstract int nextNeighbor() throws IOException; + + /** Returns the number of levels of the graph */ + public abstract int numLevels() throws IOException; + + /** Returns graph's entry point on the top level * */ + public abstract int entryNode() throws IOException; + + /** + * Get all nodes on a given level as node 0th ordinals + * + * @param level level for which to get all nodes + * @return an iterator over nodes where {@code nextInt} returns a next node on the level + */ + public abstract NodesIterator getNodesOnLevel(int level) throws IOException; + + /** Empty graph value */ + public static HnswGraph EMPTY = + new HnswGraph() { + + @Override + public int nextNeighbor() { + return NO_MORE_DOCS; } - numLevels = level + 1; - entryNode = node; + + @Override + public void seek(int level, int target) {} + + @Override + public int size() { + return 0; + } + + @Override + public int numLevels() { + return 0; + } + + @Override + public int entryNode() { + return 0; + } + + @Override + public NodesIterator getNodesOnLevel(int level) { + return NodesIterator.EMPTY; + } + }; + + /** + * Iterator over the graph nodes on a certain level, Iterator also provides the size – the total + * number of nodes to be iterated over. 
+ */ + public static final class NodesIterator implements PrimitiveIterator.OfInt { + static NodesIterator EMPTY = new NodesIterator(0); + + private final int[] nodes; + private final int size; + int cur = 0; + + /** Constructor for iterator based on the nodes array up to the size */ + public NodesIterator(int[] nodes, int size) { + assert nodes != null; + assert size <= nodes.length; + this.nodes = nodes; + this.size = size; + } + + /** Constructor for iterator based on the size */ + public NodesIterator(int size) { + this.nodes = null; + this.size = size; + } + + @Override + public int nextInt() { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + if (nodes == null) { + return cur++; } else { - // Add this node id to this level's nodes - int[] nodes = nodesByLevel.get(level); - int idx = graph.get(level).size(); - if (idx < nodes.length) { - nodes[idx] = node; - } else { - nodes = ArrayUtil.grow(nodes); - nodes[idx] = node; - nodesByLevel.set(level, nodes); - } + return nodes[cur++]; } } - graph.get(level).add(new NeighborArray(maxConn + 1)); - } - - @Override - public void seek(int level, int targetNode) { - cur = getNeighbors(level, targetNode); - upto = -1; - } - - @Override - public int nextNeighbor() { - if (++upto < cur.size()) { - return cur.node[upto]; + @Override + public boolean hasNext() { + return cur < size; } - return NO_MORE_DOCS; - } - /** - * Returns the current number of levels in the graph - * - * @return the current number of levels in the graph - */ - @Override - public int numLevels() { - return numLevels; - } - - /** - * Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th - * level - * - * @return the graph's current entry node on the top level - */ - @Override - public int entryNode() { - return entryNode; - } - - @Override - public NodesIterator getNodesOnLevel(int level) { - if (level == 0) { - return new NodesIterator(size()); - } else { - return new 
NodesIterator(nodesByLevel.get(level), graph.get(level).size()); + /** The number of elements in this iterator * */ + public int size() { + return size; } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 43efcb314c8..cc74d01bb53 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -54,7 +54,7 @@ public final class HnswGraphBuilder { private final BoundsChecker bound; private final HnswGraphSearcher graphSearcher; - final HnswGraph hnsw; + final OnHeapHnswGraph hnsw; private InfoStream infoStream = InfoStream.getDefault(); @@ -95,7 +95,7 @@ public final class HnswGraphBuilder { this.ml = 1 / Math.log(1.0 * maxConn); this.random = new SplittableRandom(seed); int levelOfFirstNode = getRandomGraphLevel(ml, random); - this.hnsw = new HnswGraph(maxConn, levelOfFirstNode); + this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode); this.graphSearcher = new HnswGraphSearcher( similarityFunction, @@ -113,7 +113,7 @@ public final class HnswGraphBuilder { * @param vectors the vectors for which to build a nearest neighbors graph. 
Must be an independet * accessor for the vectors */ - public HnswGraph build(RandomAccessVectorValues vectors) throws IOException { + public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException { if (vectors == vectorValues) { throw new IllegalArgumentException( "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 71475c1a773..23806a8d5d9 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,7 +20,6 @@ package org.apache.lucene.util.hnsw; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; -import org.apache.lucene.index.KnnGraphValues; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.BitSet; @@ -62,8 +61,8 @@ public final class HnswGraphSearcher { * @param topK the number of nodes to be returned * @param vectors the vector values * @param similarityFunction the similarity function to compare vectors - * @param graphValues the graph values. May represent the entire graph, or a level in a - * hierarchical graph. + * @param graph the graph values. May represent the entire graph, or a level in a hierarchical + * graph. * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or * {@code null} if they are all allowed to match. 
* @return a priority queue holding the closest neighbors found @@ -73,7 +72,7 @@ public final class HnswGraphSearcher { int topK, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, - KnnGraphValues graphValues, + HnswGraph graph, Bits acceptOrds) throws IOException { HnswGraphSearcher graphSearcher = @@ -82,12 +81,12 @@ public final class HnswGraphSearcher { new NeighborQueue(topK, similarityFunction.reversed == false), new SparseFixedBitSet(vectors.size())); NeighborQueue results; - int[] eps = new int[] {graphValues.entryNode()}; - for (int level = graphValues.numLevels() - 1; level >= 1; level--) { - results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graphValues, null); + int[] eps = new int[] {graph.entryNode()}; + for (int level = graph.numLevels() - 1; level >= 1; level--) { + results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null); eps[0] = results.pop(); } - results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graphValues, acceptOrds); + results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds); return results; } @@ -99,7 +98,7 @@ public final class HnswGraphSearcher { * @param level level to search * @param eps the entry points for search at this level expressed as level 0th ordinals * @param vectors vector values - * @param graphValues the graph values + * @param graph the graph values * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or * {@code null} if they are all allowed to match. 
* @return a priority queue holding the closest neighbors found @@ -110,10 +109,10 @@ public final class HnswGraphSearcher { int level, final int[] eps, RandomAccessVectorValues vectors, - KnnGraphValues graphValues, + HnswGraph graph, Bits acceptOrds) throws IOException { - int size = graphValues.size(); + int size = graph.size(); NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed); clearScratchState(); @@ -140,9 +139,9 @@ public final class HnswGraphSearcher { break; } int topCandidateNode = candidates.pop(); - graphValues.seek(level, topCandidateNode); + graph.seek(level, topCandidateNode); int friendOrd; - while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) { + while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) { assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; if (visited.getAndSet(friendOrd)) { continue; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java new file mode 100644 index 00000000000..09f8afa7aa7 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.lucene.util.ArrayUtil; + +/** + * An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to + * construct the HNSW graph before it's written to the index. + */ +public final class OnHeapHnswGraph extends HnswGraph { + + private final int maxConn; + private int numLevels; // the current number of levels in the graph + private int entryNode; // the current graph entry node on the top level + + // Nodes by level expressed as the level 0's nodes' ordinals. + // As level 0 contains all nodes, nodesByLevel.get(0) is null. + private final List nodesByLevel; + + // graph is a list of graph levels. + // Each level is represented as List<NeighborArray> – nodes' connections on this level. + // Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors + // added to HnswGraphBuilder, and the node values are the ordinals of those vectors. + // Thus, on all levels, neighbors are expressed as the level 0's nodes' ordinals. + private final List> graph; + + // HnswGraph iterator members + private int upto; + private NeighborArray cur; + + OnHeapHnswGraph(int maxConn, int levelOfFirstNode) { + this.maxConn = maxConn; + this.numLevels = levelOfFirstNode + 1; + this.graph = new ArrayList<>(numLevels); + this.entryNode = 0; + for (int i = 0; i < numLevels; i++) { + graph.add(new ArrayList<>()); + // Typically with diversity criteria we see nodes not fully occupied; + // average fanout seems to be about 1/2 maxConn.
+ // There is some indexing time penalty for under-allocating, but saves RAM + graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4))); + } + + this.nodesByLevel = new ArrayList<>(numLevels); + nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes + for (int l = 1; l < numLevels; l++) { + nodesByLevel.add(new int[] {0}); + } + } + + /** + * Returns the {@link NeighborArray} connected to the given node. + * + * @param level level of the graph + * @param node the node whose neighbors are returned, represented as an ordinal on the level 0. + */ + public NeighborArray getNeighbors(int level, int node) { + if (level == 0) { + return graph.get(level).get(node); + } + int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node); + assert nodeIndex >= 0; + return graph.get(level).get(nodeIndex); + } + + @Override + public int size() { + return graph.get(0).size(); // all nodes are located on the 0th level + } + + /** + * Add node on the given level + * + * @param level level to add a node on + * @param node the node to add, represented as an ordinal on the level 0.
+ */ + public void addNode(int level, int node) { + if (level > 0) { + // if the new node introduces a new level, add more levels to the graph, + // and make this node the graph's new entry point + if (level >= numLevels) { + for (int i = numLevels; i <= level; i++) { + graph.add(new ArrayList<>()); + nodesByLevel.add(new int[] {node}); + } + numLevels = level + 1; + entryNode = node; + } else { + // Add this node id to this level's nodes + int[] nodes = nodesByLevel.get(level); + int idx = graph.get(level).size(); + if (idx < nodes.length) { + nodes[idx] = node; + } else { + nodes = ArrayUtil.grow(nodes); + nodes[idx] = node; + nodesByLevel.set(level, nodes); + } + } + } + + graph.get(level).add(new NeighborArray(maxConn + 1)); + } + + @Override + public void seek(int level, int targetNode) { + cur = getNeighbors(level, targetNode); + upto = -1; + } + + @Override + public int nextNeighbor() { + if (++upto < cur.size()) { + return cur.node[upto]; + } + return NO_MORE_DOCS; + } + + /** + * Returns the current number of levels in the graph + * + * @return the current number of levels in the graph + */ + @Override + public int numLevels() { + return numLevels; + } + + /** + * Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th + * level + * + * @return the graph's current entry node on the top level + */ + @Override + public int entryNode() { + return entryNode; + } + + @Override + public NodesIterator getNodesOnLevel(int level) { + if (level == 0) { + return new NodesIterator(size()); + } else { + return new NodesIterator(nodesByLevel.get(level), graph.get(level).size()); + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 05900243cc3..640b1b6d57b 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -40,7 +40,6 @@ import 
org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.document.SortedDocValuesField; import org.apache.lucene.document.StringField; -import org.apache.lucene.index.KnnGraphValues.NodesIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.ScoreDoc; @@ -54,6 +53,8 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.junit.After; import org.junit.Before; @@ -239,7 +240,7 @@ public class TestKnnGraph extends LuceneTestCase { ((CodecReader) getOnlyLeafReader(reader)).getVectorReader(); Lucene91HnswVectorsReader vectorReader = (Lucene91HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD); - graph = copyGraph(vectorReader.getGraphValues(KNN_GRAPH_FIELD)); + graph = copyGraph(vectorReader.getGraph(KNN_GRAPH_FIELD)); } } return graph; @@ -259,7 +260,7 @@ public class TestKnnGraph extends LuceneTestCase { return values; } - int[][][] copyGraph(KnnGraphValues graphValues) throws IOException { + int[][][] copyGraph(HnswGraph graphValues) throws IOException { int[][][] graph = new int[graphValues.numLevels()][][]; int size = graphValues.size(); int[] scratch = new int[maxConn]; @@ -439,7 +440,7 @@ public class TestKnnGraph extends LuceneTestCase { if (vectorReader == null) { continue; } - KnnGraphValues graphValues = vectorReader.getGraphValues(vectorField); + HnswGraph graphValues = vectorReader.getGraph(vectorField); VectorValues vectorValues = reader.getVectorValues(vectorField); if (vectorValues == null) { assert graphValues == null; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java 
b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java index be600ae40c5..3a30724d527 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java @@ -50,7 +50,6 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.KnnGraphValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.RandomAccessVectorValues; @@ -252,8 +251,7 @@ public class KnnGraphTester { KnnVectorsReader vectorsReader = ((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader()) .getFieldReader(KNN_FIELD); - KnnGraphValues knnValues = - ((Lucene91HnswVectorsReader) vectorsReader).getGraphValues(KNN_FIELD); + HnswGraph knnValues = ((Lucene91HnswVectorsReader) vectorsReader).getGraph(KNN_FIELD); System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc()); printGraphFanout(knnValues, leafReader.maxDoc()); } @@ -274,7 +272,7 @@ public class KnnGraphTester { } } - private void dumpGraph(HnswGraph hnsw) { + private void dumpGraph(OnHeapHnswGraph hnsw) { for (int i = 0; i < hnsw.size(); i++) { NeighborArray neighbors = hnsw.getNeighbors(0, i); System.out.printf(Locale.ROOT, "%5d", i); @@ -303,7 +301,7 @@ public class KnnGraphTester { } @SuppressForbidden(reason = "Prints stuff") - private void printGraphFanout(KnnGraphValues knnValues, int numDocs) throws IOException { + private void printGraphFanout(HnswGraph knnValues, int numDocs) throws IOException { int min = Integer.MAX_VALUE, max = 0, total = 0; int count = 0; int[] leafHist = new int[numDocs]; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java index 
5142b78b9a7..ff1e3e22167 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java @@ -37,8 +37,6 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.KnnGraphValues; -import org.apache.lucene.index.KnnGraphValues.NodesIterator; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; @@ -51,6 +49,7 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; /** Tests HNSW KNN graphs */ public class TestHnswGraph extends LuceneTestCase { @@ -110,19 +109,19 @@ public class TestHnswGraph extends LuceneTestCase { assertEquals(indexedDoc, ctx.reader().maxDoc()); assertEquals(indexedDoc, ctx.reader().numDocs()); assertVectorsEqual(v3, values); - KnnGraphValues graphValues = + HnswGraph graphValues = ((Lucene91HnswVectorsReader) ((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) ctx.reader()).getVectorReader()) .getFieldReader("field")) - .getGraphValues("field"); + .getGraph("field"); assertGraphEqual(hnsw, graphValues); } } } } - private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h) throws IOException { + private void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException { assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels()); assertEquals("the number of nodes in the graphs are different!", g.size(), h.size()); @@ -159,7 +158,7 @@ public class TestHnswGraph extends LuceneTestCase { HnswGraphBuilder builder = new HnswGraphBuilder( vectors, 
VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt()); - HnswGraph hnsw = builder.build(vectors); + OnHeapHnswGraph hnsw = builder.build(vectors); // run some searches NeighborQueue nn = HnswGraphSearcher.search( @@ -197,7 +196,7 @@ public class TestHnswGraph extends LuceneTestCase { HnswGraphBuilder builder = new HnswGraphBuilder( vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt()); - HnswGraph hnsw = builder.build(vectors); + OnHeapHnswGraph hnsw = builder.build(vectors); // the first 10 docs must not be deleted to ensure the expected recall Bits acceptOrds = createRandomAcceptOrds(10, vectors.size); NeighborQueue nn = @@ -226,7 +225,7 @@ public class TestHnswGraph extends LuceneTestCase { HnswGraphBuilder builder = new HnswGraphBuilder( vectors, VectorSimilarityFunction.EUCLIDEAN, 16, 100, random().nextInt()); - HnswGraph hnsw = builder.build(vectors); + OnHeapHnswGraph hnsw = builder.build(vectors); // Skip over half of the documents that are closest to the query vector FixedBitSet acceptOrds = new FixedBitSet(nDoc); @@ -354,7 +353,7 @@ public class TestHnswGraph extends LuceneTestCase { assertLevel0Neighbors(builder.hnsw, 5, 1, 4); } - private void assertLevel0Neighbors(HnswGraph graph, int node, int... expected) { + private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) { Arrays.sort(expected); NeighborArray nn = graph.getNeighbors(0, node); int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size()); @@ -376,7 +375,7 @@ public class TestHnswGraph extends LuceneTestCase { int topK = 5; HnswGraphBuilder builder = new HnswGraphBuilder(vectors, similarityFunction, maxConn, 30, random().nextLong()); - HnswGraph hnsw = builder.build(vectors); + OnHeapHnswGraph hnsw = builder.build(vectors); Bits acceptOrds = random().nextBoolean() ? 
null : createRandomAcceptOrds(0, size); int totalMatches = 0; @@ -505,7 +504,7 @@ public class TestHnswGraph extends LuceneTestCase { return value; } - private Set getNeighborNodes(KnnGraphValues g) throws IOException { + private Set getNeighborNodes(HnswGraph g) throws IOException { Set neighbors = new HashSet<>(); for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) { neighbors.add(n);