mirror of https://github.com/apache/lucene.git
Rename KnnGraphValues -> HnswGraph (#645)
This PR proposes some renames to clarify the code structure. The top-level `KnnGraphValues` is renamed to `HnswGraph`, since it now represents a hierarchical graph. It is also moved from `org.apache.lucene.index` to the `org.apache.lucene.util.hnsw` package. Other renames:
* The old `HnswGraph` -> `OnHeapHnswGraph`
* `IndexedKnnGraphReader` -> `OffHeapHnswGraph` (to match `OffHeapVectorValues`)
This commit is contained in:
parent
e7546c2427
commit
eb5bdd7d15
|
@ -30,7 +30,7 @@ import org.apache.lucene.util.hnsw.NeighborArray;
|
|||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
||||
/**
|
||||
* Builder for HNSW graph. See {@link Lucene90HnswGraph} for a gloss on the algorithm and the
|
||||
* Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the
|
||||
* meaning of the hyperparameters.
|
||||
*
|
||||
* <p>This class is preserved here only for tests.
|
||||
|
@ -53,7 +53,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
private final RandomAccessVectorValues vectorValues;
|
||||
private final SplittableRandom random;
|
||||
private final BoundsChecker bound;
|
||||
final Lucene90HnswGraph hnsw;
|
||||
final Lucene90OnHeapHnswGraph hnsw;
|
||||
|
||||
private InfoStream infoStream = InfoStream.getDefault();
|
||||
|
||||
|
@ -90,7 +90,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
}
|
||||
this.maxConn = maxConn;
|
||||
this.beamWidth = beamWidth;
|
||||
this.hnsw = new Lucene90HnswGraph(maxConn);
|
||||
this.hnsw = new Lucene90OnHeapHnswGraph(maxConn);
|
||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||
random = new SplittableRandom(seed);
|
||||
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
|
||||
|
@ -104,7 +104,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
|
||||
* accessor for the vectors
|
||||
*/
|
||||
public Lucene90HnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||
if (vectors == vectorValues) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||
|
@ -143,7 +143,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
void addGraphNode(float[] value) throws IOException {
|
||||
// We pass 'null' for acceptOrds because there are no deletions while building the graph
|
||||
NeighborQueue candidates =
|
||||
Lucene90HnswGraph.search(
|
||||
Lucene90OnHeapHnswGraph.search(
|
||||
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
|
||||
|
||||
int node = hnsw.addNode();
|
||||
|
|
|
@ -31,7 +31,6 @@ import org.apache.lucene.index.CorruptIndexException;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnGraphValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
|
@ -47,6 +46,7 @@ import org.apache.lucene.util.Bits;
|
|||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
||||
/**
|
||||
|
@ -243,7 +243,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
// use a seed that is fixed for the index so we get reproducible results for the same query
|
||||
final SplittableRandom random = new SplittableRandom(checksumSeed);
|
||||
NeighborQueue results =
|
||||
Lucene90HnswGraph.search(
|
||||
Lucene90OnHeapHnswGraph.search(
|
||||
target,
|
||||
k,
|
||||
k,
|
||||
|
@ -291,7 +291,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
/** Get knn graph values; used for testing */
|
||||
public KnnGraphValues getGraphValues(String field) throws IOException {
|
||||
public HnswGraph getGraphValues(String field) throws IOException {
|
||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
if (info == null) {
|
||||
throw new IllegalArgumentException("No such field '" + field + "'");
|
||||
|
@ -300,14 +300,14 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
if (entry != null && entry.indexDataLength > 0) {
|
||||
return getGraphValues(entry);
|
||||
} else {
|
||||
return KnnGraphValues.EMPTY;
|
||||
return HnswGraph.EMPTY;
|
||||
}
|
||||
}
|
||||
|
||||
private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
|
||||
private HnswGraph getGraphValues(FieldEntry entry) throws IOException {
|
||||
IndexInput bytesSlice =
|
||||
vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
|
||||
return new IndexedKnnGraphReader(entry, bytesSlice);
|
||||
return new OffHeapHnswGraph(entry, bytesSlice);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -465,7 +465,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
/** Read the nearest-neighbors graph from the index input */
|
||||
private static final class IndexedKnnGraphReader extends KnnGraphValues {
|
||||
private static final class OffHeapHnswGraph extends HnswGraph {
|
||||
|
||||
final FieldEntry entry;
|
||||
final IndexInput dataIn;
|
||||
|
@ -474,7 +474,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
int arcUpTo;
|
||||
int arc;
|
||||
|
||||
IndexedKnnGraphReader(FieldEntry entry, IndexInput dataIn) {
|
||||
OffHeapHnswGraph(FieldEntry entry, IndexInput dataIn) {
|
||||
this.entry = entry;
|
||||
this.dataIn = dataIn;
|
||||
}
|
||||
|
|
|
@ -23,42 +23,20 @@ import java.io.IOException;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.index.KnnGraphValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.SparseFixedBitSet;
|
||||
import org.apache.lucene.util.hnsw.BoundsChecker;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
||||
/**
|
||||
* Navigable Small-world graph. Provides efficient approximate nearest neighbor search for high
|
||||
* dimensional vectors. See <a href="https://doi.org/10.1016/j.is.2013.10.006">Approximate nearest
|
||||
* neighbor algorithm based on navigable small world graphs [2014]</a> and <a
|
||||
* href="https://arxiv.org/abs/1603.09320">this paper [2018]</a> for details.
|
||||
*
|
||||
* <p>The nomenclature is a bit different here from what's used in those papers:
|
||||
*
|
||||
* <h2>Hyperparameters</h2>
|
||||
*
|
||||
* <ul>
|
||||
* <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2014 paper; it controls the
|
||||
* number of random entry points to sample.
|
||||
* <li><code>beamWidth</code> in {@link Lucene90HnswGraphBuilder} has the same meaning as <code>
|
||||
* efConst </code> in the 2018 paper. It is the number of nearest neighbor candidates to track
|
||||
* while searching the graph for each newly inserted node.
|
||||
* <li><code>maxConn</code> has the same meaning as <code>M</code> in the later paper; it controls
|
||||
* how many of the <code>efConst</code> neighbors are connected to the new node
|
||||
* </ul>
|
||||
*
|
||||
* <p>Note: The graph may be searched by multiple threads concurrently, but updates are not
|
||||
* thread-safe. Also note: there is no notion of deletions. Document searching built on top of this
|
||||
* must do its own deletion-filtering.
|
||||
*
|
||||
* <p>Graph building logic is preserved here only for tests.
|
||||
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
|
||||
* construct the HNSW graph before it's written to the index.
|
||||
*/
|
||||
public final class Lucene90HnswGraph extends KnnGraphValues {
|
||||
public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
||||
|
||||
private final int maxConn;
|
||||
|
||||
|
@ -71,7 +49,7 @@ public final class Lucene90HnswGraph extends KnnGraphValues {
|
|||
private int upto;
|
||||
private NeighborArray cur;
|
||||
|
||||
Lucene90HnswGraph(int maxConn) {
|
||||
Lucene90OnHeapHnswGraph(int maxConn) {
|
||||
graph = new ArrayList<>();
|
||||
// Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
|
||||
// about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
|
||||
|
@ -100,7 +78,7 @@ public final class Lucene90HnswGraph extends KnnGraphValues {
|
|||
int numSeed,
|
||||
RandomAccessVectorValues vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
KnnGraphValues graphValues,
|
||||
HnswGraph graphValues,
|
||||
Bits acceptOrds,
|
||||
SplittableRandom random)
|
||||
throws IOException {
|
|
@ -241,7 +241,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
|
|||
beamWidth,
|
||||
Lucene90HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
Lucene90HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
Lucene90OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
|
||||
for (int ord = 0; ord < offsets.length; ord++) {
|
||||
// write graph
|
||||
|
|
|
@ -30,7 +30,6 @@ import org.apache.lucene.index.CorruptIndexException;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnGraphValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
|
@ -46,6 +45,7 @@ import org.apache.lucene.util.Bits;
|
|||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
||||
|
@ -235,7 +235,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
k,
|
||||
vectorValues,
|
||||
fieldEntry.similarityFunction,
|
||||
getGraphValues(fieldEntry),
|
||||
getGraph(fieldEntry),
|
||||
getAcceptOrds(acceptDocs, fieldEntry));
|
||||
|
||||
int i = 0;
|
||||
|
@ -277,23 +277,23 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
/** Get knn graph values; used for testing */
|
||||
public KnnGraphValues getGraphValues(String field) throws IOException {
|
||||
public HnswGraph getGraph(String field) throws IOException {
|
||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
if (info == null) {
|
||||
throw new IllegalArgumentException("No such field '" + field + "'");
|
||||
}
|
||||
FieldEntry entry = fields.get(field);
|
||||
if (entry != null && entry.vectorIndexLength > 0) {
|
||||
return getGraphValues(entry);
|
||||
return getGraph(entry);
|
||||
} else {
|
||||
return KnnGraphValues.EMPTY;
|
||||
return HnswGraph.EMPTY;
|
||||
}
|
||||
}
|
||||
|
||||
private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
|
||||
private HnswGraph getGraph(FieldEntry entry) throws IOException {
|
||||
IndexInput bytesSlice =
|
||||
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
|
||||
return new IndexedKnnGraphReader(entry, bytesSlice);
|
||||
return new OffHeapHnswGraph(entry, bytesSlice);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -478,7 +478,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
/** Read the nearest-neighbors graph from the index input */
|
||||
private static final class IndexedKnnGraphReader extends KnnGraphValues {
|
||||
private static final class OffHeapHnswGraph extends HnswGraph {
|
||||
|
||||
final IndexInput dataIn;
|
||||
final int[][] nodesByLevel;
|
||||
|
@ -492,7 +492,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
int arcUpTo;
|
||||
int arc;
|
||||
|
||||
IndexedKnnGraphReader(FieldEntry entry, IndexInput dataIn) {
|
||||
OffHeapHnswGraph(FieldEntry entry, IndexInput dataIn) {
|
||||
this.dataIn = dataIn;
|
||||
this.nodesByLevel = entry.nodesByLevel;
|
||||
this.numLevels = entry.numLevels;
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.apache.lucene.codecs.KnnVectorsReader;
|
|||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnGraphValues.NodesIterator;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -36,9 +35,10 @@ import org.apache.lucene.store.IndexOutput;
|
|||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||
|
||||
/**
|
||||
* Writes vector values and knn graphs to index segments.
|
||||
|
@ -141,7 +141,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
|||
Lucene91HnswVectorsReader.OffHeapVectorValues offHeapVectors =
|
||||
new Lucene91HnswVectorsReader.OffHeapVectorValues(
|
||||
vectors.dimension(), docIds, vectorDataInput);
|
||||
HnswGraph graph =
|
||||
OnHeapHnswGraph graph =
|
||||
offHeapVectors.size() == 0
|
||||
? null
|
||||
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
|
||||
|
@ -197,7 +197,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
|||
long vectorIndexOffset,
|
||||
long vectorIndexLength,
|
||||
int[] docIds,
|
||||
HnswGraph graph)
|
||||
OnHeapHnswGraph graph)
|
||||
throws IOException {
|
||||
meta.writeInt(field.number);
|
||||
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
||||
|
@ -232,7 +232,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
private HnswGraph writeGraph(
|
||||
private OnHeapHnswGraph writeGraph(
|
||||
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
|
||||
|
@ -241,7 +241,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
|||
new HnswGraphBuilder(
|
||||
vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
|
||||
// write vectors' neighbours on each level into the vectorIndex file
|
||||
int countOnLevel0 = graph.size();
|
||||
|
|
|
@ -1,151 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.NoSuchElementException;
|
||||
import java.util.PrimitiveIterator;
|
||||
|
||||
/**
|
||||
* Access to per-document neighbor lists in a (hierarchical) knn search graph.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class KnnGraphValues {
|
||||
|
||||
/** Sole constructor */
|
||||
protected KnnGraphValues() {}
|
||||
|
||||
/**
|
||||
* Move the pointer to exactly the given {@code level}'s {@code target}. After this method
|
||||
* returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
|
||||
*
|
||||
* @param level level of the graph
|
||||
* @param target ordinal of a node in the graph, must be &ge; 0 and &lt; {@link
|
||||
* VectorValues#size()}.
|
||||
*/
|
||||
public abstract void seek(int level, int target) throws IOException;
|
||||
|
||||
/** Returns the number of nodes in the graph */
|
||||
public abstract int size();
|
||||
|
||||
/**
|
||||
* Iterates over the neighbor list. It is illegal to call this method after it returns
|
||||
* NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
|
||||
*
|
||||
* @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
|
||||
*/
|
||||
public abstract int nextNeighbor() throws IOException;
|
||||
|
||||
/** Returns the number of levels of the graph */
|
||||
public abstract int numLevels() throws IOException;
|
||||
|
||||
/** Returns graph's entry point on the top level * */
|
||||
public abstract int entryNode() throws IOException;
|
||||
|
||||
/**
|
||||
* Get all nodes on a given level as node 0th ordinals
|
||||
*
|
||||
* @param level level for which to get all nodes
|
||||
* @return an iterator over nodes where {@code nextInt} returns a next node on the level
|
||||
*/
|
||||
public abstract NodesIterator getNodesOnLevel(int level) throws IOException;
|
||||
|
||||
/** Empty graph value */
|
||||
public static KnnGraphValues EMPTY =
|
||||
new KnnGraphValues() {
|
||||
|
||||
@Override
|
||||
public int nextNeighbor() {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void seek(int level, int target) {}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numLevels() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int entryNode() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NodesIterator getNodesOnLevel(int level) {
|
||||
return NodesIterator.EMPTY;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Iterator over the graph nodes on a certain level, Iterator also provides the size – the total
|
||||
* number of nodes to be iterated over.
|
||||
*/
|
||||
public static final class NodesIterator implements PrimitiveIterator.OfInt {
|
||||
static NodesIterator EMPTY = new NodesIterator(0);
|
||||
|
||||
private final int[] nodes;
|
||||
private final int size;
|
||||
int cur = 0;
|
||||
|
||||
/** Constructor for iterator based on the nodes array up to the size */
|
||||
public NodesIterator(int[] nodes, int size) {
|
||||
assert nodes != null;
|
||||
assert size <= nodes.length;
|
||||
this.nodes = nodes;
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
/** Constructor for iterator based on the size */
|
||||
public NodesIterator(int size) {
|
||||
this.nodes = null;
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextInt() {
|
||||
if (hasNext() == false) {
|
||||
throw new NoSuchElementException();
|
||||
}
|
||||
if (nodes == null) {
|
||||
return cur++;
|
||||
} else {
|
||||
return nodes[cur++];
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return cur < size;
|
||||
}
|
||||
|
||||
/** The number of elements in this iterator * */
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -19,11 +19,10 @@ package org.apache.lucene.util.hnsw;
|
|||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.index.KnnGraphValues;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import java.io.IOException;
|
||||
import java.util.NoSuchElementException;
|
||||
import java.util.PrimitiveIterator;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
|
||||
/**
|
||||
* Hierarchical Navigable Small World graph. Provides efficient approximate nearest neighbor search
|
||||
|
@ -47,142 +46,124 @@ import org.apache.lucene.util.ArrayUtil;
|
|||
* thread-safe. The search method optionally takes a set of "accepted nodes", which can be used to
|
||||
* exclude deleted documents.
|
||||
*/
|
||||
public final class HnswGraph extends KnnGraphValues {
|
||||
public abstract class HnswGraph {
|
||||
|
||||
private final int maxConn;
|
||||
private int numLevels; // the current number of levels in the graph
|
||||
private int entryNode; // the current graph entry node on the top level
|
||||
|
||||
// Nodes by level expressed as the level 0's nodes' ordinals.
|
||||
// As level 0 contains all nodes, nodesByLevel.get(0) is null.
|
||||
private final List<int[]> nodesByLevel;
|
||||
|
||||
// graph is a list of graph levels.
|
||||
// Each level is represented as List<NeighborArray> – nodes' connections on this level.
|
||||
// Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
|
||||
// added to HnswBuilder, and the node values are the ordinals of those vectors.
|
||||
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
|
||||
private final List<List<NeighborArray>> graph;
|
||||
|
||||
// KnnGraphValues iterator members
|
||||
private int upto;
|
||||
private NeighborArray cur;
|
||||
|
||||
HnswGraph(int maxConn, int levelOfFirstNode) {
|
||||
this.maxConn = maxConn;
|
||||
this.numLevels = levelOfFirstNode + 1;
|
||||
this.graph = new ArrayList<>(numLevels);
|
||||
this.entryNode = 0;
|
||||
for (int i = 0; i < numLevels; i++) {
|
||||
graph.add(new ArrayList<>());
|
||||
// Typically with diversity criteria we see nodes not fully occupied;
|
||||
// average fanout seems to be about 1/2 maxConn.
|
||||
// There is some indexing time penalty for under-allocating, but saves RAM
|
||||
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
|
||||
}
|
||||
|
||||
this.nodesByLevel = new ArrayList<>(numLevels);
|
||||
nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes
|
||||
for (int l = 1; l < numLevels; l++) {
|
||||
nodesByLevel.add(new int[] {0});
|
||||
}
|
||||
}
|
||||
/** Sole constructor */
|
||||
protected HnswGraph() {}
|
||||
|
||||
/**
|
||||
* Returns the {@link NeighborQueue} connected to the given node.
|
||||
* Move the pointer to exactly the given {@code level}'s {@code target}. After this method
|
||||
* returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
|
||||
*
|
||||
* @param level level of the graph
|
||||
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
|
||||
* @param target ordinal of a node in the graph, must be &ge; 0 and &lt; {@link
|
||||
* VectorValues#size()}.
|
||||
*/
|
||||
public NeighborArray getNeighbors(int level, int node) {
|
||||
if (level == 0) {
|
||||
return graph.get(level).get(node);
|
||||
}
|
||||
int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node);
|
||||
assert nodeIndex >= 0;
|
||||
return graph.get(level).get(nodeIndex);
|
||||
}
|
||||
public abstract void seek(int level, int target) throws IOException;
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return graph.get(0).size(); // all nodes are located on the 0th level
|
||||
}
|
||||
/** Returns the number of nodes in the graph */
|
||||
public abstract int size();
|
||||
|
||||
/**
|
||||
* Add node on the given level
|
||||
* Iterates over the neighbor list. It is illegal to call this method after it returns
|
||||
* NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
|
||||
*
|
||||
* @param level level to add a node on
|
||||
* @param node the node to add, represented as an ordinal on the level 0.
|
||||
* @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
|
||||
*/
|
||||
public void addNode(int level, int node) {
|
||||
if (level > 0) {
|
||||
// if the new node introduces a new level, add more levels to the graph,
|
||||
// and make this node the graph's new entry point
|
||||
if (level >= numLevels) {
|
||||
for (int i = numLevels; i <= level; i++) {
|
||||
graph.add(new ArrayList<>());
|
||||
nodesByLevel.add(new int[] {node});
|
||||
public abstract int nextNeighbor() throws IOException;
|
||||
|
||||
/** Returns the number of levels of the graph */
|
||||
public abstract int numLevels() throws IOException;
|
||||
|
||||
/** Returns graph's entry point on the top level * */
|
||||
public abstract int entryNode() throws IOException;
|
||||
|
||||
/**
|
||||
* Get all nodes on a given level as node 0th ordinals
|
||||
*
|
||||
* @param level level for which to get all nodes
|
||||
* @return an iterator over nodes where {@code nextInt} returns a next node on the level
|
||||
*/
|
||||
public abstract NodesIterator getNodesOnLevel(int level) throws IOException;
|
||||
|
||||
/** Empty graph value */
|
||||
public static HnswGraph EMPTY =
|
||||
new HnswGraph() {
|
||||
|
||||
@Override
|
||||
public int nextNeighbor() {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
numLevels = level + 1;
|
||||
entryNode = node;
|
||||
|
||||
@Override
|
||||
public void seek(int level, int target) {}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numLevels() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int entryNode() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NodesIterator getNodesOnLevel(int level) {
|
||||
return NodesIterator.EMPTY;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Iterator over the graph nodes on a certain level, Iterator also provides the size – the total
|
||||
* number of nodes to be iterated over.
|
||||
*/
|
||||
public static final class NodesIterator implements PrimitiveIterator.OfInt {
|
||||
static NodesIterator EMPTY = new NodesIterator(0);
|
||||
|
||||
private final int[] nodes;
|
||||
private final int size;
|
||||
int cur = 0;
|
||||
|
||||
/** Constructor for iterator based on the nodes array up to the size */
|
||||
public NodesIterator(int[] nodes, int size) {
|
||||
assert nodes != null;
|
||||
assert size <= nodes.length;
|
||||
this.nodes = nodes;
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
/** Constructor for iterator based on the size */
|
||||
public NodesIterator(int size) {
|
||||
this.nodes = null;
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextInt() {
|
||||
if (hasNext() == false) {
|
||||
throw new NoSuchElementException();
|
||||
}
|
||||
if (nodes == null) {
|
||||
return cur++;
|
||||
} else {
|
||||
// Add this node id to this level's nodes
|
||||
int[] nodes = nodesByLevel.get(level);
|
||||
int idx = graph.get(level).size();
|
||||
if (idx < nodes.length) {
|
||||
nodes[idx] = node;
|
||||
} else {
|
||||
nodes = ArrayUtil.grow(nodes);
|
||||
nodes[idx] = node;
|
||||
nodesByLevel.set(level, nodes);
|
||||
}
|
||||
return nodes[cur++];
|
||||
}
|
||||
}
|
||||
|
||||
graph.get(level).add(new NeighborArray(maxConn + 1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void seek(int level, int targetNode) {
|
||||
cur = getNeighbors(level, targetNode);
|
||||
upto = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextNeighbor() {
|
||||
if (++upto < cur.size()) {
|
||||
return cur.node[upto];
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return cur < size;
|
||||
}
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the current number of levels in the graph
|
||||
*
|
||||
* @return the current number of levels in the graph
|
||||
*/
|
||||
@Override
|
||||
public int numLevels() {
|
||||
return numLevels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th
|
||||
* level
|
||||
*
|
||||
* @return the graph's current entry node on the top level
|
||||
*/
|
||||
@Override
|
||||
public int entryNode() {
|
||||
return entryNode;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NodesIterator getNodesOnLevel(int level) {
|
||||
if (level == 0) {
|
||||
return new NodesIterator(size());
|
||||
} else {
|
||||
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
|
||||
/** The number of elements in this iterator * */
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ public final class HnswGraphBuilder {
|
|||
private final BoundsChecker bound;
|
||||
private final HnswGraphSearcher graphSearcher;
|
||||
|
||||
final HnswGraph hnsw;
|
||||
final OnHeapHnswGraph hnsw;
|
||||
|
||||
private InfoStream infoStream = InfoStream.getDefault();
|
||||
|
||||
|
@ -95,7 +95,7 @@ public final class HnswGraphBuilder {
|
|||
this.ml = 1 / Math.log(1.0 * maxConn);
|
||||
this.random = new SplittableRandom(seed);
|
||||
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
||||
this.hnsw = new HnswGraph(maxConn, levelOfFirstNode);
|
||||
this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode);
|
||||
this.graphSearcher =
|
||||
new HnswGraphSearcher(
|
||||
similarityFunction,
|
||||
|
@ -113,7 +113,7 @@ public final class HnswGraphBuilder {
|
|||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
|
||||
* accessor for the vectors
|
||||
*/
|
||||
public HnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||
public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||
if (vectors == vectorValues) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||
|
|
|
@ -20,7 +20,6 @@ package org.apache.lucene.util.hnsw;
|
|||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.KnnGraphValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
|
@ -62,8 +61,8 @@ public final class HnswGraphSearcher {
|
|||
* @param topK the number of nodes to be returned
|
||||
* @param vectors the vector values
|
||||
* @param similarityFunction the similarity function to compare vectors
|
||||
* @param graphValues the graph values. May represent the entire graph, or a level in a
|
||||
* hierarchical graph.
|
||||
* @param graph the graph values. May represent the entire graph, or a level in a hierarchical
|
||||
* graph.
|
||||
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
|
||||
* {@code null} if they are all allowed to match.
|
||||
* @return a priority queue holding the closest neighbors found
|
||||
|
@ -73,7 +72,7 @@ public final class HnswGraphSearcher {
|
|||
int topK,
|
||||
RandomAccessVectorValues vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
KnnGraphValues graphValues,
|
||||
HnswGraph graph,
|
||||
Bits acceptOrds)
|
||||
throws IOException {
|
||||
HnswGraphSearcher graphSearcher =
|
||||
|
@ -82,12 +81,12 @@ public final class HnswGraphSearcher {
|
|||
new NeighborQueue(topK, similarityFunction.reversed == false),
|
||||
new SparseFixedBitSet(vectors.size()));
|
||||
NeighborQueue results;
|
||||
int[] eps = new int[] {graphValues.entryNode()};
|
||||
for (int level = graphValues.numLevels() - 1; level >= 1; level--) {
|
||||
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graphValues, null);
|
||||
int[] eps = new int[] {graph.entryNode()};
|
||||
for (int level = graph.numLevels() - 1; level >= 1; level--) {
|
||||
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null);
|
||||
eps[0] = results.pop();
|
||||
}
|
||||
results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graphValues, acceptOrds);
|
||||
results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds);
|
||||
return results;
|
||||
}
|
||||
|
||||
|
@ -99,7 +98,7 @@ public final class HnswGraphSearcher {
|
|||
* @param level level to search
|
||||
* @param eps the entry points for search at this level expressed as level 0th ordinals
|
||||
* @param vectors vector values
|
||||
* @param graphValues the graph values
|
||||
* @param graph the graph values
|
||||
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
|
||||
* {@code null} if they are all allowed to match.
|
||||
* @return a priority queue holding the closest neighbors found
|
||||
|
@ -110,10 +109,10 @@ public final class HnswGraphSearcher {
|
|||
int level,
|
||||
final int[] eps,
|
||||
RandomAccessVectorValues vectors,
|
||||
KnnGraphValues graphValues,
|
||||
HnswGraph graph,
|
||||
Bits acceptOrds)
|
||||
throws IOException {
|
||||
int size = graphValues.size();
|
||||
int size = graph.size();
|
||||
NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
|
||||
clearScratchState();
|
||||
|
||||
|
@ -140,9 +139,9 @@ public final class HnswGraphSearcher {
|
|||
break;
|
||||
}
|
||||
int topCandidateNode = candidates.pop();
|
||||
graphValues.seek(level, topCandidateNode);
|
||||
graph.seek(level, topCandidateNode);
|
||||
int friendOrd;
|
||||
while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
|
||||
while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
|
||||
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
|
||||
if (visited.getAndSet(friendOrd)) {
|
||||
continue;
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
|
||||
/**
|
||||
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
|
||||
* construct the HNSW graph before it's written to the index.
|
||||
*/
|
||||
public final class OnHeapHnswGraph extends HnswGraph {
|
||||
|
||||
private final int maxConn;
|
||||
private int numLevels; // the current number of levels in the graph
|
||||
private int entryNode; // the current graph entry node on the top level
|
||||
|
||||
// Nodes by level expressed as the level 0's nodes' ordinals.
|
||||
// As level 0 contains all nodes, nodesByLevel.get(0) is null.
|
||||
private final List<int[]> nodesByLevel;
|
||||
|
||||
// graph is a list of graph levels.
|
||||
// Each level is represented as List<NeighborArray> – nodes' connections on this level.
|
||||
// Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
|
||||
// added to HnswBuilder, and the node values are the ordinals of those vectors.
|
||||
// Thus, on all levels, neighbors are expressed as the level 0's nodes' ordinals.
|
||||
private final List<List<NeighborArray>> graph;
|
||||
|
||||
// HnswGraph iterator members
|
||||
private int upto;
|
||||
private NeighborArray cur;
|
||||
|
||||
OnHeapHnswGraph(int maxConn, int levelOfFirstNode) {
|
||||
this.maxConn = maxConn;
|
||||
this.numLevels = levelOfFirstNode + 1;
|
||||
this.graph = new ArrayList<>(numLevels);
|
||||
this.entryNode = 0;
|
||||
for (int i = 0; i < numLevels; i++) {
|
||||
graph.add(new ArrayList<>());
|
||||
// Typically with diversity criteria we see nodes not fully occupied;
|
||||
// average fanout seems to be about 1/2 maxConn.
|
||||
// There is some indexing time penalty for under-allocating, but saves RAM
|
||||
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
|
||||
}
|
||||
|
||||
this.nodesByLevel = new ArrayList<>(numLevels);
|
||||
nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes
|
||||
for (int l = 1; l < numLevels; l++) {
|
||||
nodesByLevel.add(new int[] {0});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link NeighborArray} connected to the given node.
|
||||
*
|
||||
* @param level level of the graph
|
||||
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
|
||||
*/
|
||||
public NeighborArray getNeighbors(int level, int node) {
|
||||
if (level == 0) {
|
||||
return graph.get(level).get(node);
|
||||
}
|
||||
int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node);
|
||||
assert nodeIndex >= 0;
|
||||
return graph.get(level).get(nodeIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return graph.get(0).size(); // all nodes are located on the 0th level
|
||||
}
|
||||
|
||||
/**
|
||||
* Add node on the given level
|
||||
*
|
||||
* @param level level to add a node on
|
||||
* @param node the node to add, represented as an ordinal on the level 0.
|
||||
*/
|
||||
public void addNode(int level, int node) {
|
||||
if (level > 0) {
|
||||
// if the new node introduces a new level, add more levels to the graph,
|
||||
// and make this node the graph's new entry point
|
||||
if (level >= numLevels) {
|
||||
for (int i = numLevels; i <= level; i++) {
|
||||
graph.add(new ArrayList<>());
|
||||
nodesByLevel.add(new int[] {node});
|
||||
}
|
||||
numLevels = level + 1;
|
||||
entryNode = node;
|
||||
} else {
|
||||
// Add this node id to this level's nodes
|
||||
int[] nodes = nodesByLevel.get(level);
|
||||
int idx = graph.get(level).size();
|
||||
if (idx < nodes.length) {
|
||||
nodes[idx] = node;
|
||||
} else {
|
||||
nodes = ArrayUtil.grow(nodes);
|
||||
nodes[idx] = node;
|
||||
nodesByLevel.set(level, nodes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
graph.get(level).add(new NeighborArray(maxConn + 1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void seek(int level, int targetNode) {
|
||||
cur = getNeighbors(level, targetNode);
|
||||
upto = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextNeighbor() {
|
||||
if (++upto < cur.size()) {
|
||||
return cur.node[upto];
|
||||
}
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the current number of levels in the graph
|
||||
*
|
||||
* @return the current number of levels in the graph
|
||||
*/
|
||||
@Override
|
||||
public int numLevels() {
|
||||
return numLevels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th
|
||||
* level
|
||||
*
|
||||
* @return the graph's current entry node on the top level
|
||||
*/
|
||||
@Override
|
||||
public int entryNode() {
|
||||
return entryNode;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NodesIterator getNodesOnLevel(int level) {
|
||||
if (level == 0) {
|
||||
return new NodesIterator(size());
|
||||
} else {
|
||||
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -40,7 +40,6 @@ import org.apache.lucene.document.FieldType;
|
|||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.document.SortedDocValuesField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.index.KnnGraphValues.NodesIterator;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnVectorQuery;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
|
@ -54,6 +53,8 @@ import org.apache.lucene.util.Bits;
|
|||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
@ -239,7 +240,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
((CodecReader) getOnlyLeafReader(reader)).getVectorReader();
|
||||
Lucene91HnswVectorsReader vectorReader =
|
||||
(Lucene91HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
|
||||
graph = copyGraph(vectorReader.getGraphValues(KNN_GRAPH_FIELD));
|
||||
graph = copyGraph(vectorReader.getGraph(KNN_GRAPH_FIELD));
|
||||
}
|
||||
}
|
||||
return graph;
|
||||
|
@ -259,7 +260,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
return values;
|
||||
}
|
||||
|
||||
int[][][] copyGraph(KnnGraphValues graphValues) throws IOException {
|
||||
int[][][] copyGraph(HnswGraph graphValues) throws IOException {
|
||||
int[][][] graph = new int[graphValues.numLevels()][][];
|
||||
int size = graphValues.size();
|
||||
int[] scratch = new int[maxConn];
|
||||
|
@ -439,7 +440,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
if (vectorReader == null) {
|
||||
continue;
|
||||
}
|
||||
KnnGraphValues graphValues = vectorReader.getGraphValues(vectorField);
|
||||
HnswGraph graphValues = vectorReader.getGraph(vectorField);
|
||||
VectorValues vectorValues = reader.getVectorValues(vectorField);
|
||||
if (vectorValues == null) {
|
||||
assert graphValues == null;
|
||||
|
|
|
@ -50,7 +50,6 @@ import org.apache.lucene.index.DirectoryReader;
|
|||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.KnnGraphValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
|
@ -252,8 +251,7 @@ public class KnnGraphTester {
|
|||
KnnVectorsReader vectorsReader =
|
||||
((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader())
|
||||
.getFieldReader(KNN_FIELD);
|
||||
KnnGraphValues knnValues =
|
||||
((Lucene91HnswVectorsReader) vectorsReader).getGraphValues(KNN_FIELD);
|
||||
HnswGraph knnValues = ((Lucene91HnswVectorsReader) vectorsReader).getGraph(KNN_FIELD);
|
||||
System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
|
||||
printGraphFanout(knnValues, leafReader.maxDoc());
|
||||
}
|
||||
|
@ -274,7 +272,7 @@ public class KnnGraphTester {
|
|||
}
|
||||
}
|
||||
|
||||
private void dumpGraph(HnswGraph hnsw) {
|
||||
private void dumpGraph(OnHeapHnswGraph hnsw) {
|
||||
for (int i = 0; i < hnsw.size(); i++) {
|
||||
NeighborArray neighbors = hnsw.getNeighbors(0, i);
|
||||
System.out.printf(Locale.ROOT, "%5d", i);
|
||||
|
@ -303,7 +301,7 @@ public class KnnGraphTester {
|
|||
}
|
||||
|
||||
@SuppressForbidden(reason = "Prints stuff")
|
||||
private void printGraphFanout(KnnGraphValues knnValues, int numDocs) throws IOException {
|
||||
private void printGraphFanout(HnswGraph knnValues, int numDocs) throws IOException {
|
||||
int min = Integer.MAX_VALUE, max = 0, total = 0;
|
||||
int count = 0;
|
||||
int[] leafHist = new int[numDocs];
|
||||
|
|
|
@ -37,8 +37,6 @@ import org.apache.lucene.index.DirectoryReader;
|
|||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.KnnGraphValues;
|
||||
import org.apache.lucene.index.KnnGraphValues.NodesIterator;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
|
@ -51,6 +49,7 @@ import org.apache.lucene.util.Bits;
|
|||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
|
||||
/** Tests HNSW KNN graphs */
|
||||
public class TestHnswGraph extends LuceneTestCase {
|
||||
|
@ -110,19 +109,19 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
assertEquals(indexedDoc, ctx.reader().maxDoc());
|
||||
assertEquals(indexedDoc, ctx.reader().numDocs());
|
||||
assertVectorsEqual(v3, values);
|
||||
KnnGraphValues graphValues =
|
||||
HnswGraph graphValues =
|
||||
((Lucene91HnswVectorsReader)
|
||||
((PerFieldKnnVectorsFormat.FieldsReader)
|
||||
((CodecReader) ctx.reader()).getVectorReader())
|
||||
.getFieldReader("field"))
|
||||
.getGraphValues("field");
|
||||
.getGraph("field");
|
||||
assertGraphEqual(hnsw, graphValues);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h) throws IOException {
|
||||
private void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
|
||||
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
|
||||
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
|
||||
|
||||
|
@ -159,7 +158,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
||||
HnswGraph hnsw = builder.build(vectors);
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
// run some searches
|
||||
NeighborQueue nn =
|
||||
HnswGraphSearcher.search(
|
||||
|
@ -197,7 +196,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
||||
HnswGraph hnsw = builder.build(vectors);
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
// the first 10 docs must not be deleted to ensure the expected recall
|
||||
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
|
||||
NeighborQueue nn =
|
||||
|
@ -226,7 +225,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(
|
||||
vectors, VectorSimilarityFunction.EUCLIDEAN, 16, 100, random().nextInt());
|
||||
HnswGraph hnsw = builder.build(vectors);
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
|
||||
// Skip over half of the documents that are closest to the query vector
|
||||
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
|
||||
|
@ -354,7 +353,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
|
||||
}
|
||||
|
||||
private void assertLevel0Neighbors(HnswGraph graph, int node, int... expected) {
|
||||
private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) {
|
||||
Arrays.sort(expected);
|
||||
NeighborArray nn = graph.getNeighbors(0, node);
|
||||
int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size());
|
||||
|
@ -376,7 +375,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
int topK = 5;
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(vectors, similarityFunction, maxConn, 30, random().nextLong());
|
||||
HnswGraph hnsw = builder.build(vectors);
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
|
||||
|
||||
int totalMatches = 0;
|
||||
|
@ -505,7 +504,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
return value;
|
||||
}
|
||||
|
||||
private Set<Integer> getNeighborNodes(KnnGraphValues g) throws IOException {
|
||||
private Set<Integer> getNeighborNodes(HnswGraph g) throws IOException {
|
||||
Set<Integer> neighbors = new HashSet<>();
|
||||
for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) {
|
||||
neighbors.add(n);
|
||||
|
|
Loading…
Reference in New Issue