mirror of https://github.com/apache/lucene.git
Rename KnnGraphValues -> HnswGraph (#645)
This PR proposes some renames to clarify the code structure. The top-level `KnnGraphValues` is renamed to `HnswGraph`, since it now represents a hierarchical graph. It's also moved from `org.apache.lucene.index` to the `hnsw` package. Other renames: * The old `HnswGraph` -> `OnHeapHnswGraph` * `IndexedKnnGraphValues` -> `OffHeapHnswGraph` (to match `OffHeapVectorValues`)
This commit is contained in:
parent
e7546c2427
commit
eb5bdd7d15
|
@ -30,7 +30,7 @@ import org.apache.lucene.util.hnsw.NeighborArray;
|
||||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builder for HNSW graph. See {@link Lucene90HnswGraph} for a gloss on the algorithm and the
|
* Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the
|
||||||
* meaning of the hyperparameters.
|
* meaning of the hyperparameters.
|
||||||
*
|
*
|
||||||
* <p>This class is preserved here only for tests.
|
* <p>This class is preserved here only for tests.
|
||||||
|
@ -53,7 +53,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
private final RandomAccessVectorValues vectorValues;
|
private final RandomAccessVectorValues vectorValues;
|
||||||
private final SplittableRandom random;
|
private final SplittableRandom random;
|
||||||
private final BoundsChecker bound;
|
private final BoundsChecker bound;
|
||||||
final Lucene90HnswGraph hnsw;
|
final Lucene90OnHeapHnswGraph hnsw;
|
||||||
|
|
||||||
private InfoStream infoStream = InfoStream.getDefault();
|
private InfoStream infoStream = InfoStream.getDefault();
|
||||||
|
|
||||||
|
@ -90,7 +90,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
}
|
}
|
||||||
this.maxConn = maxConn;
|
this.maxConn = maxConn;
|
||||||
this.beamWidth = beamWidth;
|
this.beamWidth = beamWidth;
|
||||||
this.hnsw = new Lucene90HnswGraph(maxConn);
|
this.hnsw = new Lucene90OnHeapHnswGraph(maxConn);
|
||||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||||
random = new SplittableRandom(seed);
|
random = new SplittableRandom(seed);
|
||||||
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
|
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
|
||||||
|
@ -104,7 +104,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
||||||
* accessor for the vectors
|
* accessor for the vectors
|
||||||
*/
|
*/
|
||||||
public Lucene90HnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||||
if (vectors == vectorValues) {
|
if (vectors == vectorValues) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||||
|
@ -143,7 +143,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
void addGraphNode(float[] value) throws IOException {
|
void addGraphNode(float[] value) throws IOException {
|
||||||
// We pass 'null' for acceptOrds because there are no deletions while building the graph
|
// We pass 'null' for acceptOrds because there are no deletions while building the graph
|
||||||
NeighborQueue candidates =
|
NeighborQueue candidates =
|
||||||
Lucene90HnswGraph.search(
|
Lucene90OnHeapHnswGraph.search(
|
||||||
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
|
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
|
||||||
|
|
||||||
int node = hnsw.addNode();
|
int node = hnsw.addNode();
|
||||||
|
|
|
@ -31,7 +31,6 @@ import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.KnnGraphValues;
|
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
@ -47,6 +46,7 @@ import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.IOUtils;
|
import org.apache.lucene.util.IOUtils;
|
||||||
import org.apache.lucene.util.RamUsageEstimator;
|
import org.apache.lucene.util.RamUsageEstimator;
|
||||||
|
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -243,7 +243,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
// use a seed that is fixed for the index so we get reproducible results for the same query
|
// use a seed that is fixed for the index so we get reproducible results for the same query
|
||||||
final SplittableRandom random = new SplittableRandom(checksumSeed);
|
final SplittableRandom random = new SplittableRandom(checksumSeed);
|
||||||
NeighborQueue results =
|
NeighborQueue results =
|
||||||
Lucene90HnswGraph.search(
|
Lucene90OnHeapHnswGraph.search(
|
||||||
target,
|
target,
|
||||||
k,
|
k,
|
||||||
k,
|
k,
|
||||||
|
@ -291,7 +291,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Get knn graph values; used for testing */
|
/** Get knn graph values; used for testing */
|
||||||
public KnnGraphValues getGraphValues(String field) throws IOException {
|
public HnswGraph getGraphValues(String field) throws IOException {
|
||||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
if (info == null) {
|
if (info == null) {
|
||||||
throw new IllegalArgumentException("No such field '" + field + "'");
|
throw new IllegalArgumentException("No such field '" + field + "'");
|
||||||
|
@ -300,14 +300,14 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
if (entry != null && entry.indexDataLength > 0) {
|
if (entry != null && entry.indexDataLength > 0) {
|
||||||
return getGraphValues(entry);
|
return getGraphValues(entry);
|
||||||
} else {
|
} else {
|
||||||
return KnnGraphValues.EMPTY;
|
return HnswGraph.EMPTY;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
|
private HnswGraph getGraphValues(FieldEntry entry) throws IOException {
|
||||||
IndexInput bytesSlice =
|
IndexInput bytesSlice =
|
||||||
vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
|
vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
|
||||||
return new IndexedKnnGraphReader(entry, bytesSlice);
|
return new OffHeapHnswGraph(entry, bytesSlice);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -465,7 +465,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Read the nearest-neighbors graph from the index input */
|
/** Read the nearest-neighbors graph from the index input */
|
||||||
private static final class IndexedKnnGraphReader extends KnnGraphValues {
|
private static final class OffHeapHnswGraph extends HnswGraph {
|
||||||
|
|
||||||
final FieldEntry entry;
|
final FieldEntry entry;
|
||||||
final IndexInput dataIn;
|
final IndexInput dataIn;
|
||||||
|
@ -474,7 +474,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
int arcUpTo;
|
int arcUpTo;
|
||||||
int arc;
|
int arc;
|
||||||
|
|
||||||
IndexedKnnGraphReader(FieldEntry entry, IndexInput dataIn) {
|
OffHeapHnswGraph(FieldEntry entry, IndexInput dataIn) {
|
||||||
this.entry = entry;
|
this.entry = entry;
|
||||||
this.dataIn = dataIn;
|
this.dataIn = dataIn;
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,42 +23,20 @@ import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.SplittableRandom;
|
import java.util.SplittableRandom;
|
||||||
import org.apache.lucene.index.KnnGraphValues;
|
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.SparseFixedBitSet;
|
import org.apache.lucene.util.SparseFixedBitSet;
|
||||||
import org.apache.lucene.util.hnsw.BoundsChecker;
|
import org.apache.lucene.util.hnsw.BoundsChecker;
|
||||||
|
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Navigable Small-world graph. Provides efficient approximate nearest neighbor search for high
|
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
|
||||||
* dimensional vectors. See <a href="https://doi.org/10.1016/j.is.2013.10.006">Approximate nearest
|
* construct the HNSW graph before it's written to the index.
|
||||||
* neighbor algorithm based on navigable small world graphs [2014]</a> and <a
|
|
||||||
* href="https://arxiv.org/abs/1603.09320">this paper [2018]</a> for details.
|
|
||||||
*
|
|
||||||
* <p>The nomenclature is a bit different here from what's used in those papers:
|
|
||||||
*
|
|
||||||
* <h2>Hyperparameters</h2>
|
|
||||||
*
|
|
||||||
* <ul>
|
|
||||||
* <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2014 paper; it controls the
|
|
||||||
* number of random entry points to sample.
|
|
||||||
* <li><code>beamWidth</code> in {@link Lucene90HnswGraphBuilder} has the same meaning as <code>
|
|
||||||
* efConst </code> in the 2018 paper. It is the number of nearest neighbor candidates to track
|
|
||||||
* while searching the graph for each newly inserted node.
|
|
||||||
* <li><code>maxConn</code> has the same meaning as <code>M</code> in the later paper; it controls
|
|
||||||
* how many of the <code>efConst</code> neighbors are connected to the new node
|
|
||||||
* </ul>
|
|
||||||
*
|
|
||||||
* <p>Note: The graph may be searched by multiple threads concurrently, but updates are not
|
|
||||||
* thread-safe. Also note: there is no notion of deletions. Document searching built on top of this
|
|
||||||
* must do its own deletion-filtering.
|
|
||||||
*
|
|
||||||
* <p>Graph building logic is preserved here only for tests.
|
|
||||||
*/
|
*/
|
||||||
public final class Lucene90HnswGraph extends KnnGraphValues {
|
public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
||||||
|
|
||||||
private final int maxConn;
|
private final int maxConn;
|
||||||
|
|
||||||
|
@ -71,7 +49,7 @@ public final class Lucene90HnswGraph extends KnnGraphValues {
|
||||||
private int upto;
|
private int upto;
|
||||||
private NeighborArray cur;
|
private NeighborArray cur;
|
||||||
|
|
||||||
Lucene90HnswGraph(int maxConn) {
|
Lucene90OnHeapHnswGraph(int maxConn) {
|
||||||
graph = new ArrayList<>();
|
graph = new ArrayList<>();
|
||||||
// Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
|
// Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
|
||||||
// about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
|
// about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
|
||||||
|
@ -100,7 +78,7 @@ public final class Lucene90HnswGraph extends KnnGraphValues {
|
||||||
int numSeed,
|
int numSeed,
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues vectors,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
KnnGraphValues graphValues,
|
HnswGraph graphValues,
|
||||||
Bits acceptOrds,
|
Bits acceptOrds,
|
||||||
SplittableRandom random)
|
SplittableRandom random)
|
||||||
throws IOException {
|
throws IOException {
|
|
@ -241,7 +241,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
beamWidth,
|
beamWidth,
|
||||||
Lucene90HnswGraphBuilder.randSeed);
|
Lucene90HnswGraphBuilder.randSeed);
|
||||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||||
Lucene90HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
Lucene90OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||||
|
|
||||||
for (int ord = 0; ord < offsets.length; ord++) {
|
for (int ord = 0; ord < offsets.length; ord++) {
|
||||||
// write graph
|
// write graph
|
||||||
|
|
|
@ -30,7 +30,6 @@ import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.KnnGraphValues;
|
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
@ -46,6 +45,7 @@ import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.IOUtils;
|
import org.apache.lucene.util.IOUtils;
|
||||||
import org.apache.lucene.util.RamUsageEstimator;
|
import org.apache.lucene.util.RamUsageEstimator;
|
||||||
|
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||||
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
|
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
|
||||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
|
|
||||||
|
@ -235,7 +235,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
k,
|
k,
|
||||||
vectorValues,
|
vectorValues,
|
||||||
fieldEntry.similarityFunction,
|
fieldEntry.similarityFunction,
|
||||||
getGraphValues(fieldEntry),
|
getGraph(fieldEntry),
|
||||||
getAcceptOrds(acceptDocs, fieldEntry));
|
getAcceptOrds(acceptDocs, fieldEntry));
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -277,23 +277,23 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Get knn graph values; used for testing */
|
/** Get knn graph values; used for testing */
|
||||||
public KnnGraphValues getGraphValues(String field) throws IOException {
|
public HnswGraph getGraph(String field) throws IOException {
|
||||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
if (info == null) {
|
if (info == null) {
|
||||||
throw new IllegalArgumentException("No such field '" + field + "'");
|
throw new IllegalArgumentException("No such field '" + field + "'");
|
||||||
}
|
}
|
||||||
FieldEntry entry = fields.get(field);
|
FieldEntry entry = fields.get(field);
|
||||||
if (entry != null && entry.vectorIndexLength > 0) {
|
if (entry != null && entry.vectorIndexLength > 0) {
|
||||||
return getGraphValues(entry);
|
return getGraph(entry);
|
||||||
} else {
|
} else {
|
||||||
return KnnGraphValues.EMPTY;
|
return HnswGraph.EMPTY;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
|
private HnswGraph getGraph(FieldEntry entry) throws IOException {
|
||||||
IndexInput bytesSlice =
|
IndexInput bytesSlice =
|
||||||
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
|
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
|
||||||
return new IndexedKnnGraphReader(entry, bytesSlice);
|
return new OffHeapHnswGraph(entry, bytesSlice);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -478,7 +478,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Read the nearest-neighbors graph from the index input */
|
/** Read the nearest-neighbors graph from the index input */
|
||||||
private static final class IndexedKnnGraphReader extends KnnGraphValues {
|
private static final class OffHeapHnswGraph extends HnswGraph {
|
||||||
|
|
||||||
final IndexInput dataIn;
|
final IndexInput dataIn;
|
||||||
final int[][] nodesByLevel;
|
final int[][] nodesByLevel;
|
||||||
|
@ -492,7 +492,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
int arcUpTo;
|
int arcUpTo;
|
||||||
int arc;
|
int arc;
|
||||||
|
|
||||||
IndexedKnnGraphReader(FieldEntry entry, IndexInput dataIn) {
|
OffHeapHnswGraph(FieldEntry entry, IndexInput dataIn) {
|
||||||
this.dataIn = dataIn;
|
this.dataIn = dataIn;
|
||||||
this.nodesByLevel = entry.nodesByLevel;
|
this.nodesByLevel = entry.nodesByLevel;
|
||||||
this.numLevels = entry.numLevels;
|
this.numLevels = entry.numLevels;
|
||||||
|
|
|
@ -26,7 +26,6 @@ import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.KnnGraphValues.NodesIterator;
|
|
||||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
@ -36,9 +35,10 @@ import org.apache.lucene.store.IndexOutput;
|
||||||
import org.apache.lucene.util.ArrayUtil;
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.IOUtils;
|
import org.apache.lucene.util.IOUtils;
|
||||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||||
|
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Writes vector values and knn graphs to index segments.
|
* Writes vector values and knn graphs to index segments.
|
||||||
|
@ -141,7 +141,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
Lucene91HnswVectorsReader.OffHeapVectorValues offHeapVectors =
|
Lucene91HnswVectorsReader.OffHeapVectorValues offHeapVectors =
|
||||||
new Lucene91HnswVectorsReader.OffHeapVectorValues(
|
new Lucene91HnswVectorsReader.OffHeapVectorValues(
|
||||||
vectors.dimension(), docIds, vectorDataInput);
|
vectors.dimension(), docIds, vectorDataInput);
|
||||||
HnswGraph graph =
|
OnHeapHnswGraph graph =
|
||||||
offHeapVectors.size() == 0
|
offHeapVectors.size() == 0
|
||||||
? null
|
? null
|
||||||
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
|
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
|
||||||
|
@ -197,7 +197,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
long vectorIndexOffset,
|
long vectorIndexOffset,
|
||||||
long vectorIndexLength,
|
long vectorIndexLength,
|
||||||
int[] docIds,
|
int[] docIds,
|
||||||
HnswGraph graph)
|
OnHeapHnswGraph graph)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
meta.writeInt(field.number);
|
meta.writeInt(field.number);
|
||||||
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
||||||
|
@ -232,7 +232,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private HnswGraph writeGraph(
|
private OnHeapHnswGraph writeGraph(
|
||||||
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
|
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
||||||
|
@ -241,7 +241,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
new HnswGraphBuilder(
|
new HnswGraphBuilder(
|
||||||
vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
|
vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
|
||||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||||
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||||
|
|
||||||
// write vectors' neighbours on each level into the vectorIndex file
|
// write vectors' neighbours on each level into the vectorIndex file
|
||||||
int countOnLevel0 = graph.size();
|
int countOnLevel0 = graph.size();
|
||||||
|
|
|
@ -1,151 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.lucene.index;
|
|
||||||
|
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.NoSuchElementException;
|
|
||||||
import java.util.PrimitiveIterator;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Access to per-document neighbor lists in a (hierarchical) knn search graph.
|
|
||||||
*
|
|
||||||
* @lucene.experimental
|
|
||||||
*/
|
|
||||||
public abstract class KnnGraphValues {
|
|
||||||
|
|
||||||
/** Sole constructor */
|
|
||||||
protected KnnGraphValues() {}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Move the pointer to exactly the given {@code level}'s {@code target}. After this method
|
|
||||||
* returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
|
|
||||||
*
|
|
||||||
* @param level level of the graph
|
|
||||||
* @param target ordinal of a node in the graph, must be ≥ 0 and < {@link
|
|
||||||
* VectorValues#size()}.
|
|
||||||
*/
|
|
||||||
public abstract void seek(int level, int target) throws IOException;
|
|
||||||
|
|
||||||
/** Returns the number of nodes in the graph */
|
|
||||||
public abstract int size();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Iterates over the neighbor list. It is illegal to call this method after it returns
|
|
||||||
* NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
|
|
||||||
*
|
|
||||||
* @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
|
|
||||||
*/
|
|
||||||
public abstract int nextNeighbor() throws IOException;
|
|
||||||
|
|
||||||
/** Returns the number of levels of the graph */
|
|
||||||
public abstract int numLevels() throws IOException;
|
|
||||||
|
|
||||||
/** Returns graph's entry point on the top level * */
|
|
||||||
public abstract int entryNode() throws IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get all nodes on a given level as node 0th ordinals
|
|
||||||
*
|
|
||||||
* @param level level for which to get all nodes
|
|
||||||
* @return an iterator over nodes where {@code nextInt} returns a next node on the level
|
|
||||||
*/
|
|
||||||
public abstract NodesIterator getNodesOnLevel(int level) throws IOException;
|
|
||||||
|
|
||||||
/** Empty graph value */
|
|
||||||
public static KnnGraphValues EMPTY =
|
|
||||||
new KnnGraphValues() {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int nextNeighbor() {
|
|
||||||
return NO_MORE_DOCS;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void seek(int level, int target) {}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int size() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numLevels() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int entryNode() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public NodesIterator getNodesOnLevel(int level) {
|
|
||||||
return NodesIterator.EMPTY;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Iterator over the graph nodes on a certain level, Iterator also provides the size – the total
|
|
||||||
* number of nodes to be iterated over.
|
|
||||||
*/
|
|
||||||
public static final class NodesIterator implements PrimitiveIterator.OfInt {
|
|
||||||
static NodesIterator EMPTY = new NodesIterator(0);
|
|
||||||
|
|
||||||
private final int[] nodes;
|
|
||||||
private final int size;
|
|
||||||
int cur = 0;
|
|
||||||
|
|
||||||
/** Constructor for iterator based on the nodes array up to the size */
|
|
||||||
public NodesIterator(int[] nodes, int size) {
|
|
||||||
assert nodes != null;
|
|
||||||
assert size <= nodes.length;
|
|
||||||
this.nodes = nodes;
|
|
||||||
this.size = size;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Constructor for iterator based on the size */
|
|
||||||
public NodesIterator(int size) {
|
|
||||||
this.nodes = null;
|
|
||||||
this.size = size;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int nextInt() {
|
|
||||||
if (hasNext() == false) {
|
|
||||||
throw new NoSuchElementException();
|
|
||||||
}
|
|
||||||
if (nodes == null) {
|
|
||||||
return cur++;
|
|
||||||
} else {
|
|
||||||
return nodes[cur++];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean hasNext() {
|
|
||||||
return cur < size;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** The number of elements in this iterator * */
|
|
||||||
public int size() {
|
|
||||||
return size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -19,11 +19,10 @@ package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.NoSuchElementException;
|
||||||
import java.util.List;
|
import java.util.PrimitiveIterator;
|
||||||
import org.apache.lucene.index.KnnGraphValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.util.ArrayUtil;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Hierarchical Navigable Small World graph. Provides efficient approximate nearest neighbor search
|
* Hierarchical Navigable Small World graph. Provides efficient approximate nearest neighbor search
|
||||||
|
@ -47,142 +46,124 @@ import org.apache.lucene.util.ArrayUtil;
|
||||||
* thread-safe. The search method optionally takes a set of "accepted nodes", which can be used to
|
* thread-safe. The search method optionally takes a set of "accepted nodes", which can be used to
|
||||||
* exclude deleted documents.
|
* exclude deleted documents.
|
||||||
*/
|
*/
|
||||||
public final class HnswGraph extends KnnGraphValues {
|
public abstract class HnswGraph {
|
||||||
|
|
||||||
private final int maxConn;
|
/** Sole constructor */
|
||||||
private int numLevels; // the current number of levels in the graph
|
protected HnswGraph() {}
|
||||||
private int entryNode; // the current graph entry node on the top level
|
|
||||||
|
|
||||||
// Nodes by level expressed as the level 0's nodes' ordinals.
|
|
||||||
// As level 0 contains all nodes, nodesByLevel.get(0) is null.
|
|
||||||
private final List<int[]> nodesByLevel;
|
|
||||||
|
|
||||||
// graph is a list of graph levels.
|
|
||||||
// Each level is represented as List<NeighborArray> – nodes' connections on this level.
|
|
||||||
// Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
|
|
||||||
// added to HnswBuilder, and the node values are the ordinals of those vectors.
|
|
||||||
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
|
|
||||||
private final List<List<NeighborArray>> graph;
|
|
||||||
|
|
||||||
// KnnGraphValues iterator members
|
|
||||||
private int upto;
|
|
||||||
private NeighborArray cur;
|
|
||||||
|
|
||||||
HnswGraph(int maxConn, int levelOfFirstNode) {
|
|
||||||
this.maxConn = maxConn;
|
|
||||||
this.numLevels = levelOfFirstNode + 1;
|
|
||||||
this.graph = new ArrayList<>(numLevels);
|
|
||||||
this.entryNode = 0;
|
|
||||||
for (int i = 0; i < numLevels; i++) {
|
|
||||||
graph.add(new ArrayList<>());
|
|
||||||
// Typically with diversity criteria we see nodes not fully occupied;
|
|
||||||
// average fanout seems to be about 1/2 maxConn.
|
|
||||||
// There is some indexing time penalty for under-allocating, but saves RAM
|
|
||||||
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
|
|
||||||
}
|
|
||||||
|
|
||||||
this.nodesByLevel = new ArrayList<>(numLevels);
|
|
||||||
nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes
|
|
||||||
for (int l = 1; l < numLevels; l++) {
|
|
||||||
nodesByLevel.add(new int[] {0});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the {@link NeighborQueue} connected to the given node.
|
* Move the pointer to exactly the given {@code level}'s {@code target}. After this method
|
||||||
|
* returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
|
||||||
*
|
*
|
||||||
* @param level level of the graph
|
* @param level level of the graph
|
||||||
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
|
* @param target ordinal of a node in the graph, must be ≥ 0 and < {@link
|
||||||
|
* VectorValues#size()}.
|
||||||
*/
|
*/
|
||||||
public NeighborArray getNeighbors(int level, int node) {
|
public abstract void seek(int level, int target) throws IOException;
|
||||||
if (level == 0) {
|
|
||||||
return graph.get(level).get(node);
|
|
||||||
}
|
|
||||||
int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node);
|
|
||||||
assert nodeIndex >= 0;
|
|
||||||
return graph.get(level).get(nodeIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
/** Returns the number of nodes in the graph */
|
||||||
public int size() {
|
public abstract int size();
|
||||||
return graph.get(0).size(); // all nodes are located on the 0th level
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add node on the given level
|
* Iterates over the neighbor list. It is illegal to call this method after it returns
|
||||||
|
* NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
|
||||||
*
|
*
|
||||||
* @param level level to add a node on
|
* @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
|
||||||
* @param node the node to add, represented as an ordinal on the level 0.
|
|
||||||
*/
|
*/
|
||||||
public void addNode(int level, int node) {
|
public abstract int nextNeighbor() throws IOException;
|
||||||
if (level > 0) {
|
|
||||||
// if the new node introduces a new level, add more levels to the graph,
|
/** Returns the number of levels of the graph */
|
||||||
// and make this node the graph's new entry point
|
public abstract int numLevels() throws IOException;
|
||||||
if (level >= numLevels) {
|
|
||||||
for (int i = numLevels; i <= level; i++) {
|
/** Returns graph's entry point on the top level * */
|
||||||
graph.add(new ArrayList<>());
|
public abstract int entryNode() throws IOException;
|
||||||
nodesByLevel.add(new int[] {node});
|
|
||||||
|
/**
|
||||||
|
* Get all nodes on a given level as node 0th ordinals
|
||||||
|
*
|
||||||
|
* @param level level for which to get all nodes
|
||||||
|
* @return an iterator over nodes where {@code nextInt} returns a next node on the level
|
||||||
|
*/
|
||||||
|
public abstract NodesIterator getNodesOnLevel(int level) throws IOException;
|
||||||
|
|
||||||
|
/** Empty graph value */
|
||||||
|
public static HnswGraph EMPTY =
|
||||||
|
new HnswGraph() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextNeighbor() {
|
||||||
|
return NO_MORE_DOCS;
|
||||||
}
|
}
|
||||||
numLevels = level + 1;
|
|
||||||
entryNode = node;
|
@Override
|
||||||
|
public void seek(int level, int target) {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int numLevels() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int entryNode() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NodesIterator getNodesOnLevel(int level) {
|
||||||
|
return NodesIterator.EMPTY;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Iterator over the graph nodes on a certain level, Iterator also provides the size – the total
|
||||||
|
* number of nodes to be iterated over.
|
||||||
|
*/
|
||||||
|
public static final class NodesIterator implements PrimitiveIterator.OfInt {
|
||||||
|
static NodesIterator EMPTY = new NodesIterator(0);
|
||||||
|
|
||||||
|
private final int[] nodes;
|
||||||
|
private final int size;
|
||||||
|
int cur = 0;
|
||||||
|
|
||||||
|
/** Constructor for iterator based on the nodes array up to the size */
|
||||||
|
public NodesIterator(int[] nodes, int size) {
|
||||||
|
assert nodes != null;
|
||||||
|
assert size <= nodes.length;
|
||||||
|
this.nodes = nodes;
|
||||||
|
this.size = size;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Constructor for iterator based on the size */
|
||||||
|
public NodesIterator(int size) {
|
||||||
|
this.nodes = null;
|
||||||
|
this.size = size;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextInt() {
|
||||||
|
if (hasNext() == false) {
|
||||||
|
throw new NoSuchElementException();
|
||||||
|
}
|
||||||
|
if (nodes == null) {
|
||||||
|
return cur++;
|
||||||
} else {
|
} else {
|
||||||
// Add this node id to this level's nodes
|
return nodes[cur++];
|
||||||
int[] nodes = nodesByLevel.get(level);
|
|
||||||
int idx = graph.get(level).size();
|
|
||||||
if (idx < nodes.length) {
|
|
||||||
nodes[idx] = node;
|
|
||||||
} else {
|
|
||||||
nodes = ArrayUtil.grow(nodes);
|
|
||||||
nodes[idx] = node;
|
|
||||||
nodesByLevel.set(level, nodes);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
graph.get(level).add(new NeighborArray(maxConn + 1));
|
@Override
|
||||||
}
|
public boolean hasNext() {
|
||||||
|
return cur < size;
|
||||||
@Override
|
|
||||||
public void seek(int level, int targetNode) {
|
|
||||||
cur = getNeighbors(level, targetNode);
|
|
||||||
upto = -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int nextNeighbor() {
|
|
||||||
if (++upto < cur.size()) {
|
|
||||||
return cur.node[upto];
|
|
||||||
}
|
}
|
||||||
return NO_MORE_DOCS;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/** The number of elements in this iterator * */
|
||||||
* Returns the current number of levels in the graph
|
public int size() {
|
||||||
*
|
return size;
|
||||||
* @return the current number of levels in the graph
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public int numLevels() {
|
|
||||||
return numLevels;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th
|
|
||||||
* level
|
|
||||||
*
|
|
||||||
* @return the graph's current entry node on the top level
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public int entryNode() {
|
|
||||||
return entryNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public NodesIterator getNodesOnLevel(int level) {
|
|
||||||
if (level == 0) {
|
|
||||||
return new NodesIterator(size());
|
|
||||||
} else {
|
|
||||||
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,7 +54,7 @@ public final class HnswGraphBuilder {
|
||||||
private final BoundsChecker bound;
|
private final BoundsChecker bound;
|
||||||
private final HnswGraphSearcher graphSearcher;
|
private final HnswGraphSearcher graphSearcher;
|
||||||
|
|
||||||
final HnswGraph hnsw;
|
final OnHeapHnswGraph hnsw;
|
||||||
|
|
||||||
private InfoStream infoStream = InfoStream.getDefault();
|
private InfoStream infoStream = InfoStream.getDefault();
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ public final class HnswGraphBuilder {
|
||||||
this.ml = 1 / Math.log(1.0 * maxConn);
|
this.ml = 1 / Math.log(1.0 * maxConn);
|
||||||
this.random = new SplittableRandom(seed);
|
this.random = new SplittableRandom(seed);
|
||||||
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
||||||
this.hnsw = new HnswGraph(maxConn, levelOfFirstNode);
|
this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode);
|
||||||
this.graphSearcher =
|
this.graphSearcher =
|
||||||
new HnswGraphSearcher(
|
new HnswGraphSearcher(
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
|
@ -113,7 +113,7 @@ public final class HnswGraphBuilder {
|
||||||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
||||||
* accessor for the vectors
|
* accessor for the vectors
|
||||||
*/
|
*/
|
||||||
public HnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||||
if (vectors == vectorValues) {
|
if (vectors == vectorValues) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||||
|
|
|
@ -20,7 +20,6 @@ package org.apache.lucene.util.hnsw;
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.lucene.index.KnnGraphValues;
|
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.util.BitSet;
|
import org.apache.lucene.util.BitSet;
|
||||||
|
@ -62,8 +61,8 @@ public final class HnswGraphSearcher {
|
||||||
* @param topK the number of nodes to be returned
|
* @param topK the number of nodes to be returned
|
||||||
* @param vectors the vector values
|
* @param vectors the vector values
|
||||||
* @param similarityFunction the similarity function to compare vectors
|
* @param similarityFunction the similarity function to compare vectors
|
||||||
* @param graphValues the graph values. May represent the entire graph, or a level in a
|
* @param graph the graph values. May represent the entire graph, or a level in a hierarchical
|
||||||
* hierarchical graph.
|
* graph.
|
||||||
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
|
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
|
||||||
* {@code null} if they are all allowed to match.
|
* {@code null} if they are all allowed to match.
|
||||||
* @return a priority queue holding the closest neighbors found
|
* @return a priority queue holding the closest neighbors found
|
||||||
|
@ -73,7 +72,7 @@ public final class HnswGraphSearcher {
|
||||||
int topK,
|
int topK,
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues vectors,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
KnnGraphValues graphValues,
|
HnswGraph graph,
|
||||||
Bits acceptOrds)
|
Bits acceptOrds)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
HnswGraphSearcher graphSearcher =
|
HnswGraphSearcher graphSearcher =
|
||||||
|
@ -82,12 +81,12 @@ public final class HnswGraphSearcher {
|
||||||
new NeighborQueue(topK, similarityFunction.reversed == false),
|
new NeighborQueue(topK, similarityFunction.reversed == false),
|
||||||
new SparseFixedBitSet(vectors.size()));
|
new SparseFixedBitSet(vectors.size()));
|
||||||
NeighborQueue results;
|
NeighborQueue results;
|
||||||
int[] eps = new int[] {graphValues.entryNode()};
|
int[] eps = new int[] {graph.entryNode()};
|
||||||
for (int level = graphValues.numLevels() - 1; level >= 1; level--) {
|
for (int level = graph.numLevels() - 1; level >= 1; level--) {
|
||||||
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graphValues, null);
|
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null);
|
||||||
eps[0] = results.pop();
|
eps[0] = results.pop();
|
||||||
}
|
}
|
||||||
results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graphValues, acceptOrds);
|
results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds);
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,7 +98,7 @@ public final class HnswGraphSearcher {
|
||||||
* @param level level to search
|
* @param level level to search
|
||||||
* @param eps the entry points for search at this level expressed as level 0th ordinals
|
* @param eps the entry points for search at this level expressed as level 0th ordinals
|
||||||
* @param vectors vector values
|
* @param vectors vector values
|
||||||
* @param graphValues the graph values
|
* @param graph the graph values
|
||||||
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
|
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
|
||||||
* {@code null} if they are all allowed to match.
|
* {@code null} if they are all allowed to match.
|
||||||
* @return a priority queue holding the closest neighbors found
|
* @return a priority queue holding the closest neighbors found
|
||||||
|
@ -110,10 +109,10 @@ public final class HnswGraphSearcher {
|
||||||
int level,
|
int level,
|
||||||
final int[] eps,
|
final int[] eps,
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues vectors,
|
||||||
KnnGraphValues graphValues,
|
HnswGraph graph,
|
||||||
Bits acceptOrds)
|
Bits acceptOrds)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
int size = graphValues.size();
|
int size = graph.size();
|
||||||
NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
|
NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
|
||||||
clearScratchState();
|
clearScratchState();
|
||||||
|
|
||||||
|
@ -140,9 +139,9 @@ public final class HnswGraphSearcher {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
int topCandidateNode = candidates.pop();
|
int topCandidateNode = candidates.pop();
|
||||||
graphValues.seek(level, topCandidateNode);
|
graph.seek(level, topCandidateNode);
|
||||||
int friendOrd;
|
int friendOrd;
|
||||||
while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
|
while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
|
||||||
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
|
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
|
||||||
if (visited.getAndSet(friendOrd)) {
|
if (visited.getAndSet(friendOrd)) {
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -0,0 +1,169 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
|
||||||
|
* construct the HNSW graph before it's written to the index.
|
||||||
|
*/
|
||||||
|
public final class OnHeapHnswGraph extends HnswGraph {
|
||||||
|
|
||||||
|
private final int maxConn;
|
||||||
|
private int numLevels; // the current number of levels in the graph
|
||||||
|
private int entryNode; // the current graph entry node on the top level
|
||||||
|
|
||||||
|
// Nodes by level expressed as the level 0's nodes' ordinals.
|
||||||
|
// As level 0 contains all nodes, nodesByLevel.get(0) is null.
|
||||||
|
private final List<int[]> nodesByLevel;
|
||||||
|
|
||||||
|
// graph is a list of graph levels.
|
||||||
|
// Each level is represented as List<NeighborArray> – nodes' connections on this level.
|
||||||
|
// Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
|
||||||
|
// added to HnswBuilder, and the node values are the ordinals of those vectors.
|
||||||
|
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
|
||||||
|
private final List<List<NeighborArray>> graph;
|
||||||
|
|
||||||
|
// KnnGraphValues iterator members
|
||||||
|
private int upto;
|
||||||
|
private NeighborArray cur;
|
||||||
|
|
||||||
|
OnHeapHnswGraph(int maxConn, int levelOfFirstNode) {
|
||||||
|
this.maxConn = maxConn;
|
||||||
|
this.numLevels = levelOfFirstNode + 1;
|
||||||
|
this.graph = new ArrayList<>(numLevels);
|
||||||
|
this.entryNode = 0;
|
||||||
|
for (int i = 0; i < numLevels; i++) {
|
||||||
|
graph.add(new ArrayList<>());
|
||||||
|
// Typically with diversity criteria we see nodes not fully occupied;
|
||||||
|
// average fanout seems to be about 1/2 maxConn.
|
||||||
|
// There is some indexing time penalty for under-allocating, but saves RAM
|
||||||
|
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
|
||||||
|
}
|
||||||
|
|
||||||
|
this.nodesByLevel = new ArrayList<>(numLevels);
|
||||||
|
nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes
|
||||||
|
for (int l = 1; l < numLevels; l++) {
|
||||||
|
nodesByLevel.add(new int[] {0});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link NeighborQueue} connected to the given node.
|
||||||
|
*
|
||||||
|
* @param level level of the graph
|
||||||
|
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
|
||||||
|
*/
|
||||||
|
public NeighborArray getNeighbors(int level, int node) {
|
||||||
|
if (level == 0) {
|
||||||
|
return graph.get(level).get(node);
|
||||||
|
}
|
||||||
|
int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node);
|
||||||
|
assert nodeIndex >= 0;
|
||||||
|
return graph.get(level).get(nodeIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return graph.get(0).size(); // all nodes are located on the 0th level
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add node on the given level
|
||||||
|
*
|
||||||
|
* @param level level to add a node on
|
||||||
|
* @param node the node to add, represented as an ordinal on the level 0.
|
||||||
|
*/
|
||||||
|
public void addNode(int level, int node) {
|
||||||
|
if (level > 0) {
|
||||||
|
// if the new node introduces a new level, add more levels to the graph,
|
||||||
|
// and make this node the graph's new entry point
|
||||||
|
if (level >= numLevels) {
|
||||||
|
for (int i = numLevels; i <= level; i++) {
|
||||||
|
graph.add(new ArrayList<>());
|
||||||
|
nodesByLevel.add(new int[] {node});
|
||||||
|
}
|
||||||
|
numLevels = level + 1;
|
||||||
|
entryNode = node;
|
||||||
|
} else {
|
||||||
|
// Add this node id to this level's nodes
|
||||||
|
int[] nodes = nodesByLevel.get(level);
|
||||||
|
int idx = graph.get(level).size();
|
||||||
|
if (idx < nodes.length) {
|
||||||
|
nodes[idx] = node;
|
||||||
|
} else {
|
||||||
|
nodes = ArrayUtil.grow(nodes);
|
||||||
|
nodes[idx] = node;
|
||||||
|
nodesByLevel.set(level, nodes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
graph.get(level).add(new NeighborArray(maxConn + 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void seek(int level, int targetNode) {
|
||||||
|
cur = getNeighbors(level, targetNode);
|
||||||
|
upto = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextNeighbor() {
|
||||||
|
if (++upto < cur.size()) {
|
||||||
|
return cur.node[upto];
|
||||||
|
}
|
||||||
|
return NO_MORE_DOCS;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the current number of levels in the graph
|
||||||
|
*
|
||||||
|
* @return the current number of levels in the graph
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public int numLevels() {
|
||||||
|
return numLevels;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th
|
||||||
|
* level
|
||||||
|
*
|
||||||
|
* @return the graph's current entry node on the top level
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public int entryNode() {
|
||||||
|
return entryNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NodesIterator getNodesOnLevel(int level) {
|
||||||
|
if (level == 0) {
|
||||||
|
return new NodesIterator(size());
|
||||||
|
} else {
|
||||||
|
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -40,7 +40,6 @@ import org.apache.lucene.document.FieldType;
|
||||||
import org.apache.lucene.document.KnnVectorField;
|
import org.apache.lucene.document.KnnVectorField;
|
||||||
import org.apache.lucene.document.SortedDocValuesField;
|
import org.apache.lucene.document.SortedDocValuesField;
|
||||||
import org.apache.lucene.document.StringField;
|
import org.apache.lucene.document.StringField;
|
||||||
import org.apache.lucene.index.KnnGraphValues.NodesIterator;
|
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.KnnVectorQuery;
|
import org.apache.lucene.search.KnnVectorQuery;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
@ -54,6 +53,8 @@ import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.IOUtils;
|
import org.apache.lucene.util.IOUtils;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||||
|
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
@ -239,7 +240,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
||||||
((CodecReader) getOnlyLeafReader(reader)).getVectorReader();
|
((CodecReader) getOnlyLeafReader(reader)).getVectorReader();
|
||||||
Lucene91HnswVectorsReader vectorReader =
|
Lucene91HnswVectorsReader vectorReader =
|
||||||
(Lucene91HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
|
(Lucene91HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
|
||||||
graph = copyGraph(vectorReader.getGraphValues(KNN_GRAPH_FIELD));
|
graph = copyGraph(vectorReader.getGraph(KNN_GRAPH_FIELD));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return graph;
|
return graph;
|
||||||
|
@ -259,7 +260,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
int[][][] copyGraph(KnnGraphValues graphValues) throws IOException {
|
int[][][] copyGraph(HnswGraph graphValues) throws IOException {
|
||||||
int[][][] graph = new int[graphValues.numLevels()][][];
|
int[][][] graph = new int[graphValues.numLevels()][][];
|
||||||
int size = graphValues.size();
|
int size = graphValues.size();
|
||||||
int[] scratch = new int[maxConn];
|
int[] scratch = new int[maxConn];
|
||||||
|
@ -439,7 +440,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
||||||
if (vectorReader == null) {
|
if (vectorReader == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
KnnGraphValues graphValues = vectorReader.getGraphValues(vectorField);
|
HnswGraph graphValues = vectorReader.getGraph(vectorField);
|
||||||
VectorValues vectorValues = reader.getVectorValues(vectorField);
|
VectorValues vectorValues = reader.getVectorValues(vectorField);
|
||||||
if (vectorValues == null) {
|
if (vectorValues == null) {
|
||||||
assert graphValues == null;
|
assert graphValues == null;
|
||||||
|
|
|
@ -50,7 +50,6 @@ import org.apache.lucene.index.DirectoryReader;
|
||||||
import org.apache.lucene.index.IndexReader;
|
import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.IndexWriter;
|
import org.apache.lucene.index.IndexWriter;
|
||||||
import org.apache.lucene.index.IndexWriterConfig;
|
import org.apache.lucene.index.IndexWriterConfig;
|
||||||
import org.apache.lucene.index.KnnGraphValues;
|
|
||||||
import org.apache.lucene.index.LeafReader;
|
import org.apache.lucene.index.LeafReader;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
|
@ -252,8 +251,7 @@ public class KnnGraphTester {
|
||||||
KnnVectorsReader vectorsReader =
|
KnnVectorsReader vectorsReader =
|
||||||
((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader())
|
((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader())
|
||||||
.getFieldReader(KNN_FIELD);
|
.getFieldReader(KNN_FIELD);
|
||||||
KnnGraphValues knnValues =
|
HnswGraph knnValues = ((Lucene91HnswVectorsReader) vectorsReader).getGraph(KNN_FIELD);
|
||||||
((Lucene91HnswVectorsReader) vectorsReader).getGraphValues(KNN_FIELD);
|
|
||||||
System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
|
System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
|
||||||
printGraphFanout(knnValues, leafReader.maxDoc());
|
printGraphFanout(knnValues, leafReader.maxDoc());
|
||||||
}
|
}
|
||||||
|
@ -274,7 +272,7 @@ public class KnnGraphTester {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void dumpGraph(HnswGraph hnsw) {
|
private void dumpGraph(OnHeapHnswGraph hnsw) {
|
||||||
for (int i = 0; i < hnsw.size(); i++) {
|
for (int i = 0; i < hnsw.size(); i++) {
|
||||||
NeighborArray neighbors = hnsw.getNeighbors(0, i);
|
NeighborArray neighbors = hnsw.getNeighbors(0, i);
|
||||||
System.out.printf(Locale.ROOT, "%5d", i);
|
System.out.printf(Locale.ROOT, "%5d", i);
|
||||||
|
@ -303,7 +301,7 @@ public class KnnGraphTester {
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressForbidden(reason = "Prints stuff")
|
@SuppressForbidden(reason = "Prints stuff")
|
||||||
private void printGraphFanout(KnnGraphValues knnValues, int numDocs) throws IOException {
|
private void printGraphFanout(HnswGraph knnValues, int numDocs) throws IOException {
|
||||||
int min = Integer.MAX_VALUE, max = 0, total = 0;
|
int min = Integer.MAX_VALUE, max = 0, total = 0;
|
||||||
int count = 0;
|
int count = 0;
|
||||||
int[] leafHist = new int[numDocs];
|
int[] leafHist = new int[numDocs];
|
||||||
|
|
|
@ -37,8 +37,6 @@ import org.apache.lucene.index.DirectoryReader;
|
||||||
import org.apache.lucene.index.IndexReader;
|
import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.IndexWriter;
|
import org.apache.lucene.index.IndexWriter;
|
||||||
import org.apache.lucene.index.IndexWriterConfig;
|
import org.apache.lucene.index.IndexWriterConfig;
|
||||||
import org.apache.lucene.index.KnnGraphValues;
|
|
||||||
import org.apache.lucene.index.KnnGraphValues.NodesIterator;
|
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
|
@ -51,6 +49,7 @@ import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.FixedBitSet;
|
import org.apache.lucene.util.FixedBitSet;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||||
|
|
||||||
/** Tests HNSW KNN graphs */
|
/** Tests HNSW KNN graphs */
|
||||||
public class TestHnswGraph extends LuceneTestCase {
|
public class TestHnswGraph extends LuceneTestCase {
|
||||||
|
@ -110,19 +109,19 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
assertEquals(indexedDoc, ctx.reader().maxDoc());
|
assertEquals(indexedDoc, ctx.reader().maxDoc());
|
||||||
assertEquals(indexedDoc, ctx.reader().numDocs());
|
assertEquals(indexedDoc, ctx.reader().numDocs());
|
||||||
assertVectorsEqual(v3, values);
|
assertVectorsEqual(v3, values);
|
||||||
KnnGraphValues graphValues =
|
HnswGraph graphValues =
|
||||||
((Lucene91HnswVectorsReader)
|
((Lucene91HnswVectorsReader)
|
||||||
((PerFieldKnnVectorsFormat.FieldsReader)
|
((PerFieldKnnVectorsFormat.FieldsReader)
|
||||||
((CodecReader) ctx.reader()).getVectorReader())
|
((CodecReader) ctx.reader()).getVectorReader())
|
||||||
.getFieldReader("field"))
|
.getFieldReader("field"))
|
||||||
.getGraphValues("field");
|
.getGraph("field");
|
||||||
assertGraphEqual(hnsw, graphValues);
|
assertGraphEqual(hnsw, graphValues);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h) throws IOException {
|
private void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
|
||||||
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
|
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
|
||||||
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
|
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
|
||||||
|
|
||||||
|
@ -159,7 +158,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder builder =
|
||||||
new HnswGraphBuilder(
|
new HnswGraphBuilder(
|
||||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
||||||
HnswGraph hnsw = builder.build(vectors);
|
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||||
// run some searches
|
// run some searches
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
HnswGraphSearcher.search(
|
HnswGraphSearcher.search(
|
||||||
|
@ -197,7 +196,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder builder =
|
||||||
new HnswGraphBuilder(
|
new HnswGraphBuilder(
|
||||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
||||||
HnswGraph hnsw = builder.build(vectors);
|
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||||
// the first 10 docs must not be deleted to ensure the expected recall
|
// the first 10 docs must not be deleted to ensure the expected recall
|
||||||
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
|
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
|
@ -226,7 +225,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder builder =
|
||||||
new HnswGraphBuilder(
|
new HnswGraphBuilder(
|
||||||
vectors, VectorSimilarityFunction.EUCLIDEAN, 16, 100, random().nextInt());
|
vectors, VectorSimilarityFunction.EUCLIDEAN, 16, 100, random().nextInt());
|
||||||
HnswGraph hnsw = builder.build(vectors);
|
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||||
|
|
||||||
// Skip over half of the documents that are closest to the query vector
|
// Skip over half of the documents that are closest to the query vector
|
||||||
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
|
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
|
||||||
|
@ -354,7 +353,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
|
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void assertLevel0Neighbors(HnswGraph graph, int node, int... expected) {
|
private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) {
|
||||||
Arrays.sort(expected);
|
Arrays.sort(expected);
|
||||||
NeighborArray nn = graph.getNeighbors(0, node);
|
NeighborArray nn = graph.getNeighbors(0, node);
|
||||||
int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size());
|
int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size());
|
||||||
|
@ -376,7 +375,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
int topK = 5;
|
int topK = 5;
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder builder =
|
||||||
new HnswGraphBuilder(vectors, similarityFunction, maxConn, 30, random().nextLong());
|
new HnswGraphBuilder(vectors, similarityFunction, maxConn, 30, random().nextLong());
|
||||||
HnswGraph hnsw = builder.build(vectors);
|
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||||
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
|
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
|
||||||
|
|
||||||
int totalMatches = 0;
|
int totalMatches = 0;
|
||||||
|
@ -505,7 +504,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Set<Integer> getNeighborNodes(KnnGraphValues g) throws IOException {
|
private Set<Integer> getNeighborNodes(HnswGraph g) throws IOException {
|
||||||
Set<Integer> neighbors = new HashSet<>();
|
Set<Integer> neighbors = new HashSet<>();
|
||||||
for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) {
|
for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) {
|
||||||
neighbors.add(n);
|
neighbors.add(n);
|
||||||
|
|
Loading…
Reference in New Issue