Rename KnnGraphValues -> HnswGraph (#645)

This PR proposes some renames to clarify the code structure. The top-level
`KnnGraphValues` is renamed to `HnswGraph`, since it now represents a
hierarchical graph. It's also moved from `org.apache.lucene.index` to the
`hnsw` package.

Other renames:
* The old `HnswGraph` -> `OnHeapHnswGraph`
* `IndexedKnnGraphValues` -> `OffHeapHnswGraph` (to match
`OffHeapVectorValues`)
This commit is contained in:
Julie Tibshirani 2022-02-07 13:21:15 -08:00 committed by GitHub
parent e7546c2427
commit eb5bdd7d15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 341 additions and 367 deletions

View File

@ -30,7 +30,7 @@ import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.NeighborQueue;
/**
* Builder for HNSW graph. See {@link Lucene90HnswGraph} for a gloss on the algorithm and the
* Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the
* meaning of the hyperparameters.
*
* <p>This class is preserved here only for tests.
@ -53,7 +53,7 @@ public final class Lucene90HnswGraphBuilder {
private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random;
private final BoundsChecker bound;
final Lucene90HnswGraph hnsw;
final Lucene90OnHeapHnswGraph hnsw;
private InfoStream infoStream = InfoStream.getDefault();
@ -90,7 +90,7 @@ public final class Lucene90HnswGraphBuilder {
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.hnsw = new Lucene90HnswGraph(maxConn);
this.hnsw = new Lucene90OnHeapHnswGraph(maxConn);
bound = BoundsChecker.create(similarityFunction.reversed);
random = new SplittableRandom(seed);
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
@ -104,7 +104,7 @@ public final class Lucene90HnswGraphBuilder {
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* accessor for the vectors
*/
public Lucene90HnswGraph build(RandomAccessVectorValues vectors) throws IOException {
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@ -143,7 +143,7 @@ public final class Lucene90HnswGraphBuilder {
void addGraphNode(float[] value) throws IOException {
// We pass 'null' for acceptOrds because there are no deletions while building the graph
NeighborQueue candidates =
Lucene90HnswGraph.search(
Lucene90OnHeapHnswGraph.search(
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
int node = hnsw.addNode();

View File

@ -31,7 +31,6 @@ import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
@ -47,6 +46,7 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
/**
@ -243,7 +243,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
// use a seed that is fixed for the index so we get reproducible results for the same query
final SplittableRandom random = new SplittableRandom(checksumSeed);
NeighborQueue results =
Lucene90HnswGraph.search(
Lucene90OnHeapHnswGraph.search(
target,
k,
k,
@ -291,7 +291,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
/** Get knn graph values; used for testing */
public KnnGraphValues getGraphValues(String field) throws IOException {
public HnswGraph getGraphValues(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
@ -300,14 +300,14 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
if (entry != null && entry.indexDataLength > 0) {
return getGraphValues(entry);
} else {
return KnnGraphValues.EMPTY;
return HnswGraph.EMPTY;
}
}
private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
private HnswGraph getGraphValues(FieldEntry entry) throws IOException {
IndexInput bytesSlice =
vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
return new IndexedKnnGraphReader(entry, bytesSlice);
return new OffHeapHnswGraph(entry, bytesSlice);
}
@Override
@ -465,7 +465,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
/** Read the nearest-neighbors graph from the index input */
private static final class IndexedKnnGraphReader extends KnnGraphValues {
private static final class OffHeapHnswGraph extends HnswGraph {
final FieldEntry entry;
final IndexInput dataIn;
@ -474,7 +474,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
int arcUpTo;
int arc;
IndexedKnnGraphReader(FieldEntry entry, IndexInput dataIn) {
OffHeapHnswGraph(FieldEntry entry, IndexInput dataIn) {
this.entry = entry;
this.dataIn = dataIn;
}

View File

@ -23,42 +23,20 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.SplittableRandom;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.lucene.util.hnsw.BoundsChecker;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.NeighborQueue;
/**
* Navigable Small-world graph. Provides efficient approximate nearest neighbor search for high
* dimensional vectors. See <a href="https://doi.org/10.1016/j.is.2013.10.006">Approximate nearest
* neighbor algorithm based on navigable small world graphs [2014]</a> and <a
* href="https://arxiv.org/abs/1603.09320">this paper [2018]</a> for details.
*
* <p>The nomenclature is a bit different here from what's used in those papers:
*
* <h2>Hyperparameters</h2>
*
* <ul>
* <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2014 paper; it controls the
* number of random entry points to sample.
* <li><code>beamWidth</code> in {@link Lucene90HnswGraphBuilder} has the same meaning as <code>
* efConst </code> in the 2018 paper. It is the number of nearest neighbor candidates to track
* while searching the graph for each newly inserted node.
* <li><code>maxConn</code> has the same meaning as <code>M</code> in the later paper; it controls
* how many of the <code>efConst</code> neighbors are connected to the new node
* </ul>
*
* <p>Note: The graph may be searched by multiple threads concurrently, but updates are not
* thread-safe. Also note: there is no notion of deletions. Document searching built on top of this
* must do its own deletion-filtering.
*
* <p>Graph building logic is preserved here only for tests.
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
* construct the HNSW graph before it's written to the index.
*/
public final class Lucene90HnswGraph extends KnnGraphValues {
public final class Lucene90OnHeapHnswGraph extends HnswGraph {
private final int maxConn;
@ -71,7 +49,7 @@ public final class Lucene90HnswGraph extends KnnGraphValues {
private int upto;
private NeighborArray cur;
Lucene90HnswGraph(int maxConn) {
Lucene90OnHeapHnswGraph(int maxConn) {
graph = new ArrayList<>();
// Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
// about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
@ -100,7 +78,7 @@ public final class Lucene90HnswGraph extends KnnGraphValues {
int numSeed,
RandomAccessVectorValues vectors,
VectorSimilarityFunction similarityFunction,
KnnGraphValues graphValues,
HnswGraph graphValues,
Bits acceptOrds,
SplittableRandom random)
throws IOException {

View File

@ -241,7 +241,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
beamWidth,
Lucene90HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
Lucene90HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
Lucene90OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
for (int ord = 0; ord < offsets.length; ord++) {
// write graph

View File

@ -30,7 +30,6 @@ import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
@ -46,6 +45,7 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
@ -235,7 +235,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
k,
vectorValues,
fieldEntry.similarityFunction,
getGraphValues(fieldEntry),
getGraph(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry));
int i = 0;
@ -277,23 +277,23 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
/** Get knn graph values; used for testing */
public KnnGraphValues getGraphValues(String field) throws IOException {
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
}
FieldEntry entry = fields.get(field);
if (entry != null && entry.vectorIndexLength > 0) {
return getGraphValues(entry);
return getGraph(entry);
} else {
return KnnGraphValues.EMPTY;
return HnswGraph.EMPTY;
}
}
private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
private HnswGraph getGraph(FieldEntry entry) throws IOException {
IndexInput bytesSlice =
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
return new IndexedKnnGraphReader(entry, bytesSlice);
return new OffHeapHnswGraph(entry, bytesSlice);
}
@Override
@ -478,7 +478,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
/** Read the nearest-neighbors graph from the index input */
private static final class IndexedKnnGraphReader extends KnnGraphValues {
private static final class OffHeapHnswGraph extends HnswGraph {
final IndexInput dataIn;
final int[][] nodesByLevel;
@ -492,7 +492,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
int arcUpTo;
int arc;
IndexedKnnGraphReader(FieldEntry entry, IndexInput dataIn) {
OffHeapHnswGraph(FieldEntry entry, IndexInput dataIn) {
this.dataIn = dataIn;
this.nodesByLevel = entry.nodesByLevel;
this.numLevels = entry.numLevels;

View File

@ -26,7 +26,6 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnGraphValues.NodesIterator;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -36,9 +35,10 @@ import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
/**
* Writes vector values and knn graphs to index segments.
@ -141,7 +141,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
Lucene91HnswVectorsReader.OffHeapVectorValues offHeapVectors =
new Lucene91HnswVectorsReader.OffHeapVectorValues(
vectors.dimension(), docIds, vectorDataInput);
HnswGraph graph =
OnHeapHnswGraph graph =
offHeapVectors.size() == 0
? null
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
@ -197,7 +197,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
long vectorIndexOffset,
long vectorIndexLength,
int[] docIds,
HnswGraph graph)
OnHeapHnswGraph graph)
throws IOException {
meta.writeInt(field.number);
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
@ -232,7 +232,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
}
}
private HnswGraph writeGraph(
private OnHeapHnswGraph writeGraph(
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
@ -241,7 +241,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
new HnswGraphBuilder(
vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();

View File

@ -1,151 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.NoSuchElementException;
import java.util.PrimitiveIterator;
/**
* Access to per-document neighbor lists in a (hierarchical) knn search graph.
*
* @lucene.experimental
*/
public abstract class KnnGraphValues {
/** Sole constructor */
protected KnnGraphValues() {}
/**
* Move the pointer to exactly the given {@code level}'s {@code target}. After this method
* returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
*
* @param level level of the graph
* @param target ordinal of a node in the graph, must be &ge; 0 and &lt; {@link
* VectorValues#size()}.
*/
public abstract void seek(int level, int target) throws IOException;
/** Returns the number of nodes in the graph */
public abstract int size();
/**
* Iterates over the neighbor list. It is illegal to call this method after it returns
* NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
*
* @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
*/
public abstract int nextNeighbor() throws IOException;
/** Returns the number of levels of the graph */
public abstract int numLevels() throws IOException;
/** Returns graph's entry point on the top level * */
public abstract int entryNode() throws IOException;
/**
* Get all nodes on a given level as node 0th ordinals
*
* @param level level for which to get all nodes
* @return an iterator over nodes where {@code nextInt} returns a next node on the level
*/
public abstract NodesIterator getNodesOnLevel(int level) throws IOException;
/** Empty graph value */
public static KnnGraphValues EMPTY =
new KnnGraphValues() {
@Override
public int nextNeighbor() {
return NO_MORE_DOCS;
}
@Override
public void seek(int level, int target) {}
@Override
public int size() {
return 0;
}
@Override
public int numLevels() {
return 0;
}
@Override
public int entryNode() {
return 0;
}
@Override
public NodesIterator getNodesOnLevel(int level) {
return NodesIterator.EMPTY;
}
};
/**
* Iterator over the graph nodes on a certain level, Iterator also provides the size the total
* number of nodes to be iterated over.
*/
public static final class NodesIterator implements PrimitiveIterator.OfInt {
static NodesIterator EMPTY = new NodesIterator(0);
private final int[] nodes;
private final int size;
int cur = 0;
/** Constructor for iterator based on the nodes array up to the size */
public NodesIterator(int[] nodes, int size) {
assert nodes != null;
assert size <= nodes.length;
this.nodes = nodes;
this.size = size;
}
/** Constructor for iterator based on the size */
public NodesIterator(int size) {
this.nodes = null;
this.size = size;
}
@Override
public int nextInt() {
if (hasNext() == false) {
throw new NoSuchElementException();
}
if (nodes == null) {
return cur++;
} else {
return nodes[cur++];
}
}
@Override
public boolean hasNext() {
return cur < size;
}
/** The number of elements in this iterator * */
public int size() {
return size;
}
}
}

View File

@ -19,11 +19,10 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.util.ArrayUtil;
import java.io.IOException;
import java.util.NoSuchElementException;
import java.util.PrimitiveIterator;
import org.apache.lucene.index.VectorValues;
/**
* Hierarchical Navigable Small World graph. Provides efficient approximate nearest neighbor search
@ -47,142 +46,124 @@ import org.apache.lucene.util.ArrayUtil;
* thread-safe. The search method optionally takes a set of "accepted nodes", which can be used to
* exclude deleted documents.
*/
public final class HnswGraph extends KnnGraphValues {
public abstract class HnswGraph {
private final int maxConn;
private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level
// Nodes by level expressed as the level 0's nodes' ordinals.
// As level 0 contains all nodes, nodesByLevel.get(0) is null.
private final List<int[]> nodesByLevel;
// graph is a list of graph levels.
// Each level is represented as List<NeighborArray> nodes' connections on this level.
// Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
// added to HnswBuilder, and the node values are the ordinals of those vectors.
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
private final List<List<NeighborArray>> graph;
// KnnGraphValues iterator members
private int upto;
private NeighborArray cur;
HnswGraph(int maxConn, int levelOfFirstNode) {
this.maxConn = maxConn;
this.numLevels = levelOfFirstNode + 1;
this.graph = new ArrayList<>(numLevels);
this.entryNode = 0;
for (int i = 0; i < numLevels; i++) {
graph.add(new ArrayList<>());
// Typically with diversity criteria we see nodes not fully occupied;
// average fanout seems to be about 1/2 maxConn.
// There is some indexing time penalty for under-allocating, but saves RAM
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
}
this.nodesByLevel = new ArrayList<>(numLevels);
nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes
for (int l = 1; l < numLevels; l++) {
nodesByLevel.add(new int[] {0});
}
}
/** Sole constructor */
protected HnswGraph() {}
/**
* Returns the {@link NeighborQueue} connected to the given node.
* Move the pointer to exactly the given {@code level}'s {@code target}. After this method
* returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
*
* @param level level of the graph
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
* @param target ordinal of a node in the graph, must be &ge; 0 and &lt; {@link
* VectorValues#size()}.
*/
public NeighborArray getNeighbors(int level, int node) {
if (level == 0) {
return graph.get(level).get(node);
}
int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node);
assert nodeIndex >= 0;
return graph.get(level).get(nodeIndex);
}
public abstract void seek(int level, int target) throws IOException;
@Override
public int size() {
return graph.get(0).size(); // all nodes are located on the 0th level
}
/** Returns the number of nodes in the graph */
public abstract int size();
/**
* Add node on the given level
* Iterates over the neighbor list. It is illegal to call this method after it returns
* NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
*
* @param level level to add a node on
* @param node the node to add, represented as an ordinal on the level 0.
* @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
*/
public void addNode(int level, int node) {
if (level > 0) {
// if the new node introduces a new level, add more levels to the graph,
// and make this node the graph's new entry point
if (level >= numLevels) {
for (int i = numLevels; i <= level; i++) {
graph.add(new ArrayList<>());
nodesByLevel.add(new int[] {node});
}
numLevels = level + 1;
entryNode = node;
} else {
// Add this node id to this level's nodes
int[] nodes = nodesByLevel.get(level);
int idx = graph.get(level).size();
if (idx < nodes.length) {
nodes[idx] = node;
} else {
nodes = ArrayUtil.grow(nodes);
nodes[idx] = node;
nodesByLevel.set(level, nodes);
}
}
}
public abstract int nextNeighbor() throws IOException;
graph.get(level).add(new NeighborArray(maxConn + 1));
}
/** Returns the number of levels of the graph */
public abstract int numLevels() throws IOException;
@Override
public void seek(int level, int targetNode) {
cur = getNeighbors(level, targetNode);
upto = -1;
}
/** Returns graph's entry point on the top level * */
public abstract int entryNode() throws IOException;
/**
* Get all nodes on a given level as node 0th ordinals
*
* @param level level for which to get all nodes
* @return an iterator over nodes where {@code nextInt} returns a next node on the level
*/
public abstract NodesIterator getNodesOnLevel(int level) throws IOException;
/** Empty graph value */
public static HnswGraph EMPTY =
new HnswGraph() {
@Override
public int nextNeighbor() {
if (++upto < cur.size()) {
return cur.node[upto];
}
return NO_MORE_DOCS;
}
/**
* Returns the current number of levels in the graph
*
* @return the current number of levels in the graph
*/
@Override
public int numLevels() {
return numLevels;
public void seek(int level, int target) {}
@Override
public int size() {
return 0;
}
@Override
public int numLevels() {
return 0;
}
/**
* Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th
* level
*
* @return the graph's current entry node on the top level
*/
@Override
public int entryNode() {
return entryNode;
return 0;
}
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return NodesIterator.EMPTY;
}
};
/**
* Iterator over the graph nodes on a certain level, Iterator also provides the size the total
* number of nodes to be iterated over.
*/
public static final class NodesIterator implements PrimitiveIterator.OfInt {
static NodesIterator EMPTY = new NodesIterator(0);
private final int[] nodes;
private final int size;
int cur = 0;
/** Constructor for iterator based on the nodes array up to the size */
public NodesIterator(int[] nodes, int size) {
assert nodes != null;
assert size <= nodes.length;
this.nodes = nodes;
this.size = size;
}
/** Constructor for iterator based on the size */
public NodesIterator(int size) {
this.nodes = null;
this.size = size;
}
@Override
public int nextInt() {
if (hasNext() == false) {
throw new NoSuchElementException();
}
if (nodes == null) {
return cur++;
} else {
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
return nodes[cur++];
}
}
@Override
public boolean hasNext() {
return cur < size;
}
/** The number of elements in this iterator * */
public int size() {
return size;
}
}
}

View File

@ -54,7 +54,7 @@ public final class HnswGraphBuilder {
private final BoundsChecker bound;
private final HnswGraphSearcher graphSearcher;
final HnswGraph hnsw;
final OnHeapHnswGraph hnsw;
private InfoStream infoStream = InfoStream.getDefault();
@ -95,7 +95,7 @@ public final class HnswGraphBuilder {
this.ml = 1 / Math.log(1.0 * maxConn);
this.random = new SplittableRandom(seed);
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new HnswGraph(maxConn, levelOfFirstNode);
this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode);
this.graphSearcher =
new HnswGraphSearcher(
similarityFunction,
@ -113,7 +113,7 @@ public final class HnswGraphBuilder {
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* accessor for the vectors
*/
public HnswGraph build(RandomAccessVectorValues vectors) throws IOException {
public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");

View File

@ -20,7 +20,6 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BitSet;
@ -62,8 +61,8 @@ public final class HnswGraphSearcher {
* @param topK the number of nodes to be returned
* @param vectors the vector values
* @param similarityFunction the similarity function to compare vectors
* @param graphValues the graph values. May represent the entire graph, or a level in a
* hierarchical graph.
* @param graph the graph values. May represent the entire graph, or a level in a hierarchical
* graph.
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
* @return a priority queue holding the closest neighbors found
@ -73,7 +72,7 @@ public final class HnswGraphSearcher {
int topK,
RandomAccessVectorValues vectors,
VectorSimilarityFunction similarityFunction,
KnnGraphValues graphValues,
HnswGraph graph,
Bits acceptOrds)
throws IOException {
HnswGraphSearcher graphSearcher =
@ -82,12 +81,12 @@ public final class HnswGraphSearcher {
new NeighborQueue(topK, similarityFunction.reversed == false),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
int[] eps = new int[] {graphValues.entryNode()};
for (int level = graphValues.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graphValues, null);
int[] eps = new int[] {graph.entryNode()};
for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null);
eps[0] = results.pop();
}
results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graphValues, acceptOrds);
results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds);
return results;
}
@ -99,7 +98,7 @@ public final class HnswGraphSearcher {
* @param level level to search
* @param eps the entry points for search at this level expressed as level 0th ordinals
* @param vectors vector values
* @param graphValues the graph values
* @param graph the graph values
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
* @return a priority queue holding the closest neighbors found
@ -110,10 +109,10 @@ public final class HnswGraphSearcher {
int level,
final int[] eps,
RandomAccessVectorValues vectors,
KnnGraphValues graphValues,
HnswGraph graph,
Bits acceptOrds)
throws IOException {
int size = graphValues.size();
int size = graph.size();
NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
clearScratchState();
@ -140,9 +139,9 @@ public final class HnswGraphSearcher {
break;
}
int topCandidateNode = candidates.pop();
graphValues.seek(level, topCandidateNode);
graph.seek(level, topCandidateNode);
int friendOrd;
while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
continue;

View File

@ -0,0 +1,169 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.util.ArrayUtil;
/**
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
* construct the HNSW graph before it's written to the index.
*/
public final class OnHeapHnswGraph extends HnswGraph {
private final int maxConn;
private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level
// Nodes by level expressed as the level 0's nodes' ordinals.
// As level 0 contains all nodes, nodesByLevel.get(0) is null.
private final List<int[]> nodesByLevel;
// graph is a list of graph levels.
// Each level is represented as List<NeighborArray> nodes' connections on this level.
// Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
// added to HnswBuilder, and the node values are the ordinals of those vectors.
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
private final List<List<NeighborArray>> graph;
// KnnGraphValues iterator members
private int upto;
private NeighborArray cur;
OnHeapHnswGraph(int maxConn, int levelOfFirstNode) {
this.maxConn = maxConn;
this.numLevels = levelOfFirstNode + 1;
this.graph = new ArrayList<>(numLevels);
this.entryNode = 0;
for (int i = 0; i < numLevels; i++) {
graph.add(new ArrayList<>());
// Typically with diversity criteria we see nodes not fully occupied;
// average fanout seems to be about 1/2 maxConn.
// There is some indexing time penalty for under-allocating, but saves RAM
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
}
this.nodesByLevel = new ArrayList<>(numLevels);
nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes
for (int l = 1; l < numLevels; l++) {
nodesByLevel.add(new int[] {0});
}
}
/**
* Returns the {@link NeighborQueue} connected to the given node.
*
* @param level level of the graph
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
*/
public NeighborArray getNeighbors(int level, int node) {
if (level == 0) {
return graph.get(level).get(node);
}
int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node);
assert nodeIndex >= 0;
return graph.get(level).get(nodeIndex);
}
@Override
public int size() {
return graph.get(0).size(); // all nodes are located on the 0th level
}
/**
* Add node on the given level
*
* @param level level to add a node on
* @param node the node to add, represented as an ordinal on the level 0.
*/
public void addNode(int level, int node) {
if (level > 0) {
// if the new node introduces a new level, add more levels to the graph,
// and make this node the graph's new entry point
if (level >= numLevels) {
for (int i = numLevels; i <= level; i++) {
graph.add(new ArrayList<>());
nodesByLevel.add(new int[] {node});
}
numLevels = level + 1;
entryNode = node;
} else {
// Add this node id to this level's nodes
int[] nodes = nodesByLevel.get(level);
int idx = graph.get(level).size();
if (idx < nodes.length) {
nodes[idx] = node;
} else {
nodes = ArrayUtil.grow(nodes);
nodes[idx] = node;
nodesByLevel.set(level, nodes);
}
}
}
graph.get(level).add(new NeighborArray(maxConn + 1));
}
@Override
public void seek(int level, int targetNode) {
cur = getNeighbors(level, targetNode);
upto = -1;
}
@Override
public int nextNeighbor() {
if (++upto < cur.size()) {
return cur.node[upto];
}
return NO_MORE_DOCS;
}
/**
* Returns the current number of levels in the graph
*
* @return the current number of levels in the graph
*/
@Override
public int numLevels() {
return numLevels;
}
/**
* Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th
* level
*
* @return the graph's current entry node on the top level
*/
@Override
public int entryNode() {
return entryNode;
}
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
} else {
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
}
}
}

View File

@ -40,7 +40,6 @@ import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.KnnGraphValues.NodesIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.ScoreDoc;
@ -54,6 +53,8 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.junit.After;
import org.junit.Before;
@ -239,7 +240,7 @@ public class TestKnnGraph extends LuceneTestCase {
((CodecReader) getOnlyLeafReader(reader)).getVectorReader();
Lucene91HnswVectorsReader vectorReader =
(Lucene91HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
graph = copyGraph(vectorReader.getGraphValues(KNN_GRAPH_FIELD));
graph = copyGraph(vectorReader.getGraph(KNN_GRAPH_FIELD));
}
}
return graph;
@ -259,7 +260,7 @@ public class TestKnnGraph extends LuceneTestCase {
return values;
}
int[][][] copyGraph(KnnGraphValues graphValues) throws IOException {
int[][][] copyGraph(HnswGraph graphValues) throws IOException {
int[][][] graph = new int[graphValues.numLevels()][][];
int size = graphValues.size();
int[] scratch = new int[maxConn];
@ -439,7 +440,7 @@ public class TestKnnGraph extends LuceneTestCase {
if (vectorReader == null) {
continue;
}
KnnGraphValues graphValues = vectorReader.getGraphValues(vectorField);
HnswGraph graphValues = vectorReader.getGraph(vectorField);
VectorValues vectorValues = reader.getVectorValues(vectorField);
if (vectorValues == null) {
assert graphValues == null;

View File

@ -50,7 +50,6 @@ import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
@ -252,8 +251,7 @@ public class KnnGraphTester {
KnnVectorsReader vectorsReader =
((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader())
.getFieldReader(KNN_FIELD);
KnnGraphValues knnValues =
((Lucene91HnswVectorsReader) vectorsReader).getGraphValues(KNN_FIELD);
HnswGraph knnValues = ((Lucene91HnswVectorsReader) vectorsReader).getGraph(KNN_FIELD);
System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
printGraphFanout(knnValues, leafReader.maxDoc());
}
@ -274,7 +272,7 @@ public class KnnGraphTester {
}
}
private void dumpGraph(HnswGraph hnsw) {
private void dumpGraph(OnHeapHnswGraph hnsw) {
for (int i = 0; i < hnsw.size(); i++) {
NeighborArray neighbors = hnsw.getNeighbors(0, i);
System.out.printf(Locale.ROOT, "%5d", i);
@ -303,7 +301,7 @@ public class KnnGraphTester {
}
@SuppressForbidden(reason = "Prints stuff")
private void printGraphFanout(KnnGraphValues knnValues, int numDocs) throws IOException {
private void printGraphFanout(HnswGraph knnValues, int numDocs) throws IOException {
int min = Integer.MAX_VALUE, max = 0, total = 0;
int count = 0;
int[] leafHist = new int[numDocs];

View File

@ -37,8 +37,6 @@ import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.KnnGraphValues.NodesIterator;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
@ -51,6 +49,7 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
/** Tests HNSW KNN graphs */
public class TestHnswGraph extends LuceneTestCase {
@ -110,19 +109,19 @@ public class TestHnswGraph extends LuceneTestCase {
assertEquals(indexedDoc, ctx.reader().maxDoc());
assertEquals(indexedDoc, ctx.reader().numDocs());
assertVectorsEqual(v3, values);
KnnGraphValues graphValues =
HnswGraph graphValues =
((Lucene91HnswVectorsReader)
((PerFieldKnnVectorsFormat.FieldsReader)
((CodecReader) ctx.reader()).getVectorReader())
.getFieldReader("field"))
.getGraphValues("field");
.getGraph("field");
assertGraphEqual(hnsw, graphValues);
}
}
}
}
private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h) throws IOException {
private void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
@ -159,7 +158,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
HnswGraph hnsw = builder.build(vectors);
OnHeapHnswGraph hnsw = builder.build(vectors);
// run some searches
NeighborQueue nn =
HnswGraphSearcher.search(
@ -197,7 +196,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
HnswGraph hnsw = builder.build(vectors);
OnHeapHnswGraph hnsw = builder.build(vectors);
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
NeighborQueue nn =
@ -226,7 +225,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.EUCLIDEAN, 16, 100, random().nextInt());
HnswGraph hnsw = builder.build(vectors);
OnHeapHnswGraph hnsw = builder.build(vectors);
// Skip over half of the documents that are closest to the query vector
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
@ -354,7 +353,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
}
private void assertLevel0Neighbors(HnswGraph graph, int node, int... expected) {
private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) {
Arrays.sort(expected);
NeighborArray nn = graph.getNeighbors(0, node);
int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size());
@ -376,7 +375,7 @@ public class TestHnswGraph extends LuceneTestCase {
int topK = 5;
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, maxConn, 30, random().nextLong());
HnswGraph hnsw = builder.build(vectors);
OnHeapHnswGraph hnsw = builder.build(vectors);
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
int totalMatches = 0;
@ -505,7 +504,7 @@ public class TestHnswGraph extends LuceneTestCase {
return value;
}
private Set<Integer> getNeighborNodes(KnnGraphValues g) throws IOException {
private Set<Integer> getNeighborNodes(HnswGraph g) throws IOException {
Set<Integer> neighbors = new HashSet<>();
for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) {
neighbors.add(n);