mirror of https://github.com/apache/lucene.git
LUCENE-10527 Use 2*maxConn for last layer in HNSW (#872)
The original HNSW paper (https://arxiv.org/pdf/1603.09320.pdf) suggests to use a different maxConn for the upper layers vs. the bottom one (which contains the full neighborhood graph). Specifically, they suggest using maxConn=M for upper layers and maxConn=2*M for the bottom. This patch ensures that we follow this recommendation and use maxConn=2*M for the bottom layer.
This commit is contained in:
parent
8f89db8048
commit
ea5c40686f
|
@ -115,6 +115,8 @@ Improvements
|
|||
* LUCENE-9848: Correctly sort HNSW graph neighbors when applying diversity criterion (Mayya
|
||||
Sharipova, Michael Sokolov)
|
||||
|
||||
* LUCENE-10527: Use 2*maxConn for the last layer in HNSW (Mayya Sharipova)
|
||||
|
||||
Optimizations
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -0,0 +1,312 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.backward_codecs.lucene91;
|
||||
|
||||
import static java.lang.Math.log;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.hnsw.BoundsChecker;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
||||
/**
|
||||
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
|
||||
* hyperparameters.
|
||||
*/
|
||||
public final class Lucene91HnswGraphBuilder {
|
||||
|
||||
/** Default random seed for level generation * */
|
||||
private static final long DEFAULT_RAND_SEED = 42;
|
||||
/** A name for the HNSW component for the info-stream * */
|
||||
public static final String HNSW_COMPONENT = "HNSW";
|
||||
|
||||
/** Random seed for level generation; public to expose for testing * */
|
||||
public static long randSeed = DEFAULT_RAND_SEED;
|
||||
|
||||
private final int maxConn;
|
||||
private final int beamWidth;
|
||||
private final double ml;
|
||||
private final Lucene91NeighborArray scratch;
|
||||
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final RandomAccessVectorValues vectorValues;
|
||||
private final SplittableRandom random;
|
||||
private final BoundsChecker bound;
|
||||
private final HnswGraphSearcher graphSearcher;
|
||||
|
||||
final Lucene91OnHeapHnswGraph hnsw;
|
||||
|
||||
private InfoStream infoStream = InfoStream.getDefault();
|
||||
|
||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||
// colliding
|
||||
private RandomAccessVectorValues buildVectors;
|
||||
|
||||
/**
|
||||
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
|
||||
* ordinals, using the given hyperparameter settings, and returns the resulting graph.
|
||||
*
|
||||
* @param vectors the vectors whose relations are represented by the graph - must provide a
|
||||
* different view over those vectors than the one used to add via addGraphNode.
|
||||
* @param maxConn the number of connections to make when adding a new graph node; roughly speaking
|
||||
* the graph fanout.
|
||||
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
|
||||
* @param seed the seed for a random number generator used during graph construction. Provide this
|
||||
* to ensure repeatable construction.
|
||||
*/
|
||||
public Lucene91HnswGraphBuilder(
|
||||
RandomAccessVectorValuesProducer vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int maxConn,
|
||||
int beamWidth,
|
||||
long seed)
|
||||
throws IOException {
|
||||
vectorValues = vectors.randomAccess();
|
||||
buildVectors = vectors.randomAccess();
|
||||
this.similarityFunction = Objects.requireNonNull(similarityFunction);
|
||||
if (maxConn <= 0) {
|
||||
throw new IllegalArgumentException("maxConn must be positive");
|
||||
}
|
||||
if (beamWidth <= 0) {
|
||||
throw new IllegalArgumentException("beamWidth must be positive");
|
||||
}
|
||||
this.maxConn = maxConn;
|
||||
this.beamWidth = beamWidth;
|
||||
// normalization factor for level generation; currently not configurable
|
||||
this.ml = 1 / Math.log(1.0 * maxConn);
|
||||
this.random = new SplittableRandom(seed);
|
||||
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
||||
this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode);
|
||||
this.graphSearcher =
|
||||
new HnswGraphSearcher(
|
||||
similarityFunction,
|
||||
new NeighborQueue(beamWidth, similarityFunction.reversed == false),
|
||||
new FixedBitSet(vectorValues.size()));
|
||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||
scratch = new Lucene91NeighborArray(Math.max(beamWidth, maxConn + 1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads all the vectors from two copies of a random access VectorValues. Providing two copies
|
||||
* enables efficient retrieval without extra data copying, while avoiding collision of the
|
||||
* returned values.
|
||||
*
|
||||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
||||
* accessor for the vectors
|
||||
*/
|
||||
public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||
if (vectors == vectorValues) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||
}
|
||||
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||
infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
|
||||
}
|
||||
long start = System.nanoTime(), t = start;
|
||||
// start at node 1! node 0 is added implicitly, in the constructor
|
||||
for (int node = 1; node < vectors.size(); node++) {
|
||||
addGraphNode(node, vectors.vectorValue(node));
|
||||
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||
t = printGraphBuildStatus(node, start, t);
|
||||
}
|
||||
}
|
||||
return hnsw;
|
||||
}
|
||||
|
||||
/** Set info-stream to output debugging information * */
|
||||
public void setInfoStream(InfoStream infoStream) {
|
||||
this.infoStream = infoStream;
|
||||
}
|
||||
|
||||
/** Inserts a doc with vector value to the graph */
|
||||
void addGraphNode(int node, float[] value) throws IOException {
|
||||
NeighborQueue candidates;
|
||||
final int nodeLevel = getRandomGraphLevel(ml, random);
|
||||
int curMaxLevel = hnsw.numLevels() - 1;
|
||||
int[] eps = new int[] {hnsw.entryNode()};
|
||||
|
||||
// if a node introduces new levels to the graph, add this new node on new levels
|
||||
for (int level = nodeLevel; level > curMaxLevel; level--) {
|
||||
hnsw.addNode(level, node);
|
||||
}
|
||||
|
||||
// for levels > nodeLevel search with topk = 1
|
||||
for (int level = curMaxLevel; level > nodeLevel; level--) {
|
||||
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw);
|
||||
eps = new int[] {candidates.pop()};
|
||||
}
|
||||
// for levels <= nodeLevel search with topk = beamWidth, and add connections
|
||||
for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
|
||||
candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw);
|
||||
eps = candidates.nodes();
|
||||
hnsw.addNode(level, node);
|
||||
addDiverseNeighbors(level, node, candidates);
|
||||
}
|
||||
}
|
||||
|
||||
private long printGraphBuildStatus(int node, long start, long t) {
|
||||
long now = System.nanoTime();
|
||||
infoStream.message(
|
||||
HNSW_COMPONENT,
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"built %d in %d/%d ms",
|
||||
node,
|
||||
((now - t) / 1_000_000),
|
||||
((now - start) / 1_000_000)));
|
||||
return now;
|
||||
}
|
||||
|
||||
/* TODO: we are not maintaining nodes in strict score order; the forward links
|
||||
* are added in sorted order, but the reverse implicit ones are not. Diversity heuristic should
|
||||
* work better if we keep the neighbor arrays sorted. Possibly we should switch back to a heap?
|
||||
* But first we should just see if sorting makes a significant difference.
|
||||
*/
|
||||
private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
|
||||
throws IOException {
|
||||
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
|
||||
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
|
||||
* since the node is new and has no prior neighbors).
|
||||
*/
|
||||
Lucene91NeighborArray neighbors = hnsw.getNeighbors(level, node);
|
||||
assert neighbors.size() == 0; // new node
|
||||
popToScratch(candidates);
|
||||
selectDiverse(neighbors, scratch);
|
||||
|
||||
// Link the selected nodes to the new node, and the new node to the selected nodes (again
|
||||
// applying diversity heuristic)
|
||||
int size = neighbors.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
int nbr = neighbors.node[i];
|
||||
Lucene91NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
|
||||
nbrNbr.add(node, neighbors.score[i]);
|
||||
if (nbrNbr.size() > maxConn) {
|
||||
diversityUpdate(nbrNbr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void selectDiverse(Lucene91NeighborArray neighbors, Lucene91NeighborArray candidates)
|
||||
throws IOException {
|
||||
// Select the best maxConn neighbors of the new node, applying the diversity heuristic
|
||||
for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
|
||||
// compare each neighbor (in distance order) against the closer neighbors selected so far,
|
||||
// only adding it if it is closer to the target than to any of the other selected neighbors
|
||||
int cNode = candidates.node[i];
|
||||
float cScore = candidates.score[i];
|
||||
assert cNode < hnsw.size();
|
||||
if (diversityCheck(vectorValues.vectorValue(cNode), cScore, neighbors, buildVectors)) {
|
||||
neighbors.add(cNode, cScore);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void popToScratch(NeighborQueue candidates) {
|
||||
scratch.clear();
|
||||
int candidateCount = candidates.size();
|
||||
// extract all the Neighbors from the queue into an array; these will now be
|
||||
// sorted from worst to best
|
||||
for (int i = 0; i < candidateCount; i++) {
|
||||
float score = candidates.topScore();
|
||||
scratch.add(candidates.pop(), score);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param candidate the vector of a new candidate neighbor of a node n
|
||||
* @param score the score of the new candidate and node n, to be compared with scores of the
|
||||
* candidate and n's neighbors
|
||||
* @param neighbors the neighbors selected so far
|
||||
* @param vectorValues source of values used for making comparisons between candidate and existing
|
||||
* neighbors
|
||||
* @return whether the candidate is diverse given the existing neighbors
|
||||
*/
|
||||
private boolean diversityCheck(
|
||||
float[] candidate,
|
||||
float score,
|
||||
Lucene91NeighborArray neighbors,
|
||||
RandomAccessVectorValues vectorValues)
|
||||
throws IOException {
|
||||
bound.set(score);
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
float diversityCheck =
|
||||
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
|
||||
if (bound.check(diversityCheck) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private void diversityUpdate(Lucene91NeighborArray neighbors) throws IOException {
|
||||
assert neighbors.size() == maxConn + 1;
|
||||
int replacePoint = findNonDiverse(neighbors);
|
||||
if (replacePoint == -1) {
|
||||
// none found; check score against worst existing neighbor
|
||||
bound.set(neighbors.score[0]);
|
||||
if (bound.check(neighbors.score[maxConn])) {
|
||||
// drop the new neighbor; it is not competitive and there were no diversity failures
|
||||
neighbors.removeLast();
|
||||
return;
|
||||
} else {
|
||||
replacePoint = 0;
|
||||
}
|
||||
}
|
||||
neighbors.node[replacePoint] = neighbors.node[maxConn];
|
||||
neighbors.score[replacePoint] = neighbors.score[maxConn];
|
||||
neighbors.removeLast();
|
||||
}
|
||||
|
||||
// scan neighbors looking for diversity violations
|
||||
private int findNonDiverse(Lucene91NeighborArray neighbors) throws IOException {
|
||||
for (int i = neighbors.size() - 1; i >= 0; i--) {
|
||||
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
|
||||
// them, drop it
|
||||
int nbrNode = neighbors.node[i];
|
||||
bound.set(neighbors.score[i]);
|
||||
float[] nbrVector = vectorValues.vectorValue(nbrNode);
|
||||
for (int j = maxConn; j > i; j--) {
|
||||
float diversityCheck =
|
||||
similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
|
||||
if (bound.check(diversityCheck) == false) {
|
||||
// node j is too similar to node i given its score relative to the base node
|
||||
// replace it with the new node, which is at [maxConn]
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
private static int getRandomGraphLevel(double ml, SplittableRandom random) {
|
||||
double randDouble;
|
||||
do {
|
||||
randDouble = random.nextDouble(); // avoid 0 value, as log(0) is undefined
|
||||
} while (randDouble == 0.0);
|
||||
return ((int) (-log(randDouble) * ml));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.backward_codecs.lucene91;
|
||||
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
|
||||
/**
|
||||
* NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair
|
||||
* of growable arrays.
|
||||
*
|
||||
* @lucene.internal
|
||||
*/
|
||||
public class Lucene91NeighborArray {
|
||||
|
||||
private int size;
|
||||
|
||||
float[] score;
|
||||
int[] node;
|
||||
|
||||
/** Create a neighbour array with the given initial size */
|
||||
public Lucene91NeighborArray(int maxSize) {
|
||||
node = new int[maxSize];
|
||||
score = new float[maxSize];
|
||||
}
|
||||
|
||||
/** Add a new node with a score */
|
||||
public void add(int newNode, float newScore) {
|
||||
if (size == node.length - 1) {
|
||||
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
|
||||
score = ArrayUtil.growExact(score, node.length);
|
||||
}
|
||||
node[size] = newNode;
|
||||
score[size] = newScore;
|
||||
++size;
|
||||
}
|
||||
|
||||
/** Get the size, the number of nodes added so far */
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct access to the internal list of node ids; provided for efficient writing of the graph
|
||||
*
|
||||
* @lucene.internal
|
||||
*/
|
||||
public int[] node() {
|
||||
return node;
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct access to the internal list of scores
|
||||
*
|
||||
* @lucene.internal
|
||||
*/
|
||||
public float[] score() {
|
||||
return score;
|
||||
}
|
||||
|
||||
/** Clear all the nodes in the array */
|
||||
public void clear() {
|
||||
size = 0;
|
||||
}
|
||||
|
||||
/** Remove the last nodes from the array */
|
||||
public void removeLast() {
|
||||
size--;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "NeighborArray[" + size + "]";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.backward_codecs.lucene91;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
||||
/**
|
||||
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
|
||||
* construct the HNSW graph before it's written to the index.
|
||||
*/
|
||||
public final class Lucene91OnHeapHnswGraph extends HnswGraph {
|
||||
|
||||
private final int maxConn;
|
||||
private int numLevels; // the current number of levels in the graph
|
||||
private int entryNode; // the current graph entry node on the top level
|
||||
|
||||
// Nodes by level expressed as the level 0's nodes' ordinals.
|
||||
// As level 0 contains all nodes, nodesByLevel.get(0) is null.
|
||||
private final List<int[]> nodesByLevel;
|
||||
|
||||
// graph is a list of graph levels.
|
||||
// Each level is represented as List<NeighborArray> – nodes' connections on this level.
|
||||
// Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
|
||||
// added to HnswBuilder, and the node values are the ordinals of those vectors.
|
||||
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
|
||||
private final List<List<Lucene91NeighborArray>> graph;
|
||||
|
||||
// KnnGraphValues iterator members
|
||||
private int upto;
|
||||
private Lucene91NeighborArray cur;
|
||||
|
||||
Lucene91OnHeapHnswGraph(int maxConn, int levelOfFirstNode) {
|
||||
this.maxConn = maxConn;
|
||||
this.numLevels = levelOfFirstNode + 1;
|
||||
this.graph = new ArrayList<>(numLevels);
|
||||
this.entryNode = 0;
|
||||
for (int i = 0; i < numLevels; i++) {
|
||||
graph.add(new ArrayList<>());
|
||||
// Typically with diversity criteria we see nodes not fully occupied;
|
||||
// average fanout seems to be about 1/2 maxConn.
|
||||
// There is some indexing time penalty for under-allocating, but saves RAM
|
||||
graph.get(i).add(new Lucene91NeighborArray(Math.max(32, maxConn / 4)));
|
||||
}
|
||||
|
||||
this.nodesByLevel = new ArrayList<>(numLevels);
|
||||
nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes
|
||||
for (int l = 1; l < numLevels; l++) {
|
||||
nodesByLevel.add(new int[] {0});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link NeighborQueue} connected to the given node.
|
||||
*
|
||||
* @param level level of the graph
|
||||
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
|
||||
*/
|
||||
public Lucene91NeighborArray getNeighbors(int level, int node) {
|
||||
if (level == 0) {
|
||||
return graph.get(level).get(node);
|
||||
}
|
||||
int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node);
|
||||
assert nodeIndex >= 0;
|
||||
return graph.get(level).get(nodeIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return graph.get(0).size(); // all nodes are located on the 0th level
|
||||
}
|
||||
|
||||
/**
|
||||
* Add node on the given level
|
||||
*
|
||||
* @param level level to add a node on
|
||||
* @param node the node to add, represented as an ordinal on the level 0.
|
||||
*/
|
||||
public void addNode(int level, int node) {
|
||||
if (level > 0) {
|
||||
// if the new node introduces a new level, add more levels to the graph,
|
||||
// and make this node the graph's new entry point
|
||||
if (level >= numLevels) {
|
||||
for (int i = numLevels; i <= level; i++) {
|
||||
graph.add(new ArrayList<>());
|
||||
nodesByLevel.add(new int[] {node});
|
||||
}
|
||||
numLevels = level + 1;
|
||||
entryNode = node;
|
||||
} else {
|
||||
// Add this node id to this level's nodes
|
||||
int[] nodes = nodesByLevel.get(level);
|
||||
int idx = graph.get(level).size();
|
||||
if (idx < nodes.length) {
|
||||
nodes[idx] = node;
|
||||
} else {
|
||||
nodes = ArrayUtil.grow(nodes);
|
||||
nodes[idx] = node;
|
||||
nodesByLevel.set(level, nodes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
graph.get(level).add(new Lucene91NeighborArray(maxConn + 1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void seek(int level, int targetNode) {
|
||||
cur = getNeighbors(level, targetNode);
|
||||
upto = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextNeighbor() {
|
||||
if (++upto < cur.size()) {
|
||||
return cur.node[upto];
|
||||
}
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the current number of levels in the graph
|
||||
*
|
||||
* @return the current number of levels in the graph
|
||||
*/
|
||||
@Override
|
||||
public int numLevels() {
|
||||
return numLevels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th
|
||||
* level
|
||||
*
|
||||
* @return the graph's current entry node on the top level
|
||||
*/
|
||||
@Override
|
||||
public int entryNode() {
|
||||
return entryNode;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NodesIterator getNodesOnLevel(int level) {
|
||||
if (level == 0) {
|
||||
return new NodesIterator(size());
|
||||
} else {
|
||||
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -37,9 +37,6 @@ import org.apache.lucene.store.IndexOutput;
|
|||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||
|
||||
/**
|
||||
* Writes vector values and knn graphs to index segments.
|
||||
|
@ -145,7 +142,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
|||
Lucene91HnswVectorsReader.OffHeapVectorValues offHeapVectors =
|
||||
new Lucene91HnswVectorsReader.OffHeapVectorValues(
|
||||
vectors.dimension(), docsWithField.cardinality(), null, vectorDataInput);
|
||||
OnHeapHnswGraph graph =
|
||||
Lucene91OnHeapHnswGraph graph =
|
||||
offHeapVectors.size() == 0
|
||||
? null
|
||||
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
|
||||
|
@ -194,7 +191,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
|||
long vectorIndexOffset,
|
||||
long vectorIndexLength,
|
||||
DocsWithFieldSet docsWithField,
|
||||
OnHeapHnswGraph graph)
|
||||
Lucene91OnHeapHnswGraph graph)
|
||||
throws IOException {
|
||||
meta.writeInt(field.number);
|
||||
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
||||
|
@ -236,16 +233,20 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
private OnHeapHnswGraph writeGraph(
|
||||
private Lucene91OnHeapHnswGraph writeGraph(
|
||||
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
|
||||
// build graph
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
new HnswGraphBuilder(
|
||||
vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
|
||||
Lucene91HnswGraphBuilder hnswGraphBuilder =
|
||||
new Lucene91HnswGraphBuilder(
|
||||
vectorValues,
|
||||
similarityFunction,
|
||||
maxConn,
|
||||
beamWidth,
|
||||
Lucene91HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
Lucene91OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
|
||||
// write vectors' neighbours on each level into the vectorIndex file
|
||||
int countOnLevel0 = graph.size();
|
||||
|
@ -253,7 +254,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
|
|||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
Lucene91NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
vectorIndex.writeInt(size);
|
||||
// Destructively modify; it's ok we are discarding it after this
|
||||
|
|
|
@ -58,9 +58,9 @@ import org.apache.lucene.util.hnsw.HnswGraph;
|
|||
* <ul>
|
||||
* <li><b>[int32]</b> the number of neighbor nodes
|
||||
* <li><b>array[int32]</b> the neighbor ordinals
|
||||
* <li><b>array[int32]</b> padding from empty integers if the number of neighbors less
|
||||
* than the maximum number of connections (maxConn). Padding is equal to
|
||||
* ((maxConn-the number of neighbours) * 4) bytes.
|
||||
* <li><b>array[int32]</b> padding if the number of the node's neighbors is less than
|
||||
* the maximum number of connections allowed on this level. Padding is equal to
|
||||
* ((maxConnOnLevel – the number of neighbours) * 4) bytes.
|
||||
* </ul>
|
||||
* </ul>
|
||||
* </ul>
|
||||
|
|
|
@ -282,7 +282,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
final long vectorDataLength;
|
||||
final long vectorIndexOffset;
|
||||
final long vectorIndexLength;
|
||||
final int maxConn;
|
||||
final int M;
|
||||
final int numLevels;
|
||||
final int dimension;
|
||||
final int size;
|
||||
|
@ -336,7 +336,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
// read nodes by level
|
||||
maxConn = input.readInt();
|
||||
M = input.readInt();
|
||||
numLevels = input.readInt();
|
||||
nodesByLevel = new int[numLevels][];
|
||||
for (int level = 0; level < numLevels; level++) {
|
||||
|
@ -359,10 +359,13 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
for (int level = 0; level < numLevels; level++) {
|
||||
if (level == 0) {
|
||||
graphOffsetsByLevel[level] = 0;
|
||||
} else if (level == 1) {
|
||||
int numNodesOnLevel0 = size;
|
||||
graphOffsetsByLevel[level] = (1 + (M * 2)) * Integer.BYTES * numNodesOnLevel0;
|
||||
} else {
|
||||
int numNodesOnPrevLevel = level == 1 ? size : nodesByLevel[level - 1].length;
|
||||
int numNodesOnPrevLevel = nodesByLevel[level - 1].length;
|
||||
graphOffsetsByLevel[level] =
|
||||
graphOffsetsByLevel[level - 1] + (1 + maxConn) * Integer.BYTES * numNodesOnPrevLevel;
|
||||
graphOffsetsByLevel[level - 1] + (1 + M) * Integer.BYTES * numNodesOnPrevLevel;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -382,6 +385,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
final int entryNode;
|
||||
final int size;
|
||||
final long bytesForConns;
|
||||
final long bytesForConns0;
|
||||
|
||||
int arcCount;
|
||||
int arcUpTo;
|
||||
|
@ -394,7 +398,8 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
this.entryNode = numLevels > 1 ? nodesByLevel[numLevels - 1][0] : 0;
|
||||
this.size = entry.size();
|
||||
this.graphOffsetsByLevel = entry.graphOffsetsByLevel;
|
||||
this.bytesForConns = ((long) entry.maxConn + 1) * Integer.BYTES;
|
||||
this.bytesForConns = ((long) entry.M + 1) * Integer.BYTES;
|
||||
this.bytesForConns0 = ((long) (entry.M * 2) + 1) * Integer.BYTES;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -404,7 +409,8 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
? targetOrd
|
||||
: Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
|
||||
assert targetIndex >= 0;
|
||||
long graphDataOffset = graphOffsetsByLevel[level] + targetIndex * bytesForConns;
|
||||
long graphDataOffset =
|
||||
graphOffsetsByLevel[level] + targetIndex * (level == 0 ? bytesForConns0 : bytesForConns);
|
||||
// unsafe; no bounds checking
|
||||
dataIn.seek(graphDataOffset);
|
||||
arcCount = dataIn.readInt();
|
||||
|
|
|
@ -55,13 +55,12 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
|
|||
private final IndexOutput meta, vectorData, vectorIndex;
|
||||
private final int maxDoc;
|
||||
|
||||
private final int maxConn;
|
||||
private final int M;
|
||||
private final int beamWidth;
|
||||
private boolean finished;
|
||||
|
||||
Lucene92HnswVectorsWriter(SegmentWriteState state, int maxConn, int beamWidth)
|
||||
throws IOException {
|
||||
this.maxConn = maxConn;
|
||||
Lucene92HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException {
|
||||
this.M = M;
|
||||
this.beamWidth = beamWidth;
|
||||
|
||||
assert state.fieldInfos.hasVectorValues();
|
||||
|
@ -248,7 +247,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
|
|||
meta.writeLong(vectorData.getFilePointer() - start);
|
||||
}
|
||||
|
||||
meta.writeInt(maxConn);
|
||||
meta.writeInt(M);
|
||||
// write graph nodes on each level
|
||||
if (graph == null) {
|
||||
meta.writeInt(0);
|
||||
|
@ -274,13 +273,14 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
|
|||
// build graph
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
new HnswGraphBuilder(
|
||||
vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
|
||||
vectorValues, similarityFunction, M, beamWidth, HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
|
||||
// write vectors' neighbours on each level into the vectorIndex file
|
||||
int countOnLevel0 = graph.size();
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int maxConnOnLevel = level == 0 ? (M * 2) : M;
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
while (nodesOnLevel.hasNext()) {
|
||||
int node = nodesOnLevel.nextInt();
|
||||
|
@ -297,7 +297,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
// if number of connections < maxConn, add bogus values up to maxConn to have predictable
|
||||
// offsets
|
||||
for (int i = size; i < maxConn; i++) {
|
||||
for (int i = size; i < maxConnOnLevel; i++) {
|
||||
vectorIndex.writeInt(0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@ public final class HnswGraphBuilder {
|
|||
/** Random seed for level generation; public to expose for testing * */
|
||||
public static long randSeed = DEFAULT_RAND_SEED;
|
||||
|
||||
private final int maxConn;
|
||||
private final int M; // max number of connections on upper layers
|
||||
private final int beamWidth;
|
||||
private final double ml;
|
||||
private final NeighborArray scratch;
|
||||
|
@ -68,8 +68,8 @@ public final class HnswGraphBuilder {
|
|||
*
|
||||
* @param vectors the vectors whose relations are represented by the graph - must provide a
|
||||
* different view over those vectors than the one used to add via addGraphNode.
|
||||
* @param maxConn the number of connections to make when adding a new graph node; roughly speaking
|
||||
* the graph fanout.
|
||||
* @param M – graph fanout parameter used to calculate the maximum number of connections a node
|
||||
* can have – M on upper layers, and M * 2 on the lowest level.
|
||||
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
|
||||
* @param seed the seed for a random number generator used during graph construction. Provide this
|
||||
* to ensure repeatable construction.
|
||||
|
@ -77,26 +77,26 @@ public final class HnswGraphBuilder {
|
|||
public HnswGraphBuilder(
|
||||
RandomAccessVectorValuesProducer vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int maxConn,
|
||||
int M,
|
||||
int beamWidth,
|
||||
long seed)
|
||||
throws IOException {
|
||||
vectorValues = vectors.randomAccess();
|
||||
buildVectors = vectors.randomAccess();
|
||||
this.similarityFunction = Objects.requireNonNull(similarityFunction);
|
||||
if (maxConn <= 0) {
|
||||
if (M <= 0) {
|
||||
throw new IllegalArgumentException("maxConn must be positive");
|
||||
}
|
||||
if (beamWidth <= 0) {
|
||||
throw new IllegalArgumentException("beamWidth must be positive");
|
||||
}
|
||||
this.maxConn = maxConn;
|
||||
this.M = M;
|
||||
this.beamWidth = beamWidth;
|
||||
// normalization factor for level generation; currently not configurable
|
||||
this.ml = 1 / Math.log(1.0 * maxConn);
|
||||
this.ml = 1 / Math.log(1.0 * M);
|
||||
this.random = new SplittableRandom(seed);
|
||||
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
||||
this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode, similarityFunction.reversed);
|
||||
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode, similarityFunction.reversed);
|
||||
this.graphSearcher =
|
||||
new HnswGraphSearcher(
|
||||
similarityFunction,
|
||||
|
@ -104,7 +104,7 @@ public final class HnswGraphBuilder {
|
|||
new FixedBitSet(vectorValues.size()));
|
||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||
// in scratch we store candidates in reverse order: worse candidates are first
|
||||
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1), similarityFunction.reversed);
|
||||
scratch = new NeighborArray(Math.max(beamWidth, M + 1), similarityFunction.reversed);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -187,7 +187,8 @@ public final class HnswGraphBuilder {
|
|||
NeighborArray neighbors = hnsw.getNeighbors(level, node);
|
||||
assert neighbors.size() == 0; // new node
|
||||
popToScratch(candidates);
|
||||
selectAndLinkDiverse(neighbors, scratch);
|
||||
int maxConnOnLevel = level == 0 ? M * 2 : M;
|
||||
selectAndLinkDiverse(neighbors, scratch, maxConnOnLevel);
|
||||
|
||||
// Link the selected nodes to the new node, and the new node to the selected nodes (again
|
||||
// applying diversity heuristic)
|
||||
|
@ -196,17 +197,17 @@ public final class HnswGraphBuilder {
|
|||
int nbr = neighbors.node[i];
|
||||
NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
|
||||
nbrNbr.insertSorted(node, neighbors.score[i]);
|
||||
if (nbrNbr.size() > maxConn) {
|
||||
if (nbrNbr.size() > maxConnOnLevel) {
|
||||
int indexToRemove = findWorstNonDiverse(nbrNbr);
|
||||
nbrNbr.removeIndex(indexToRemove);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates)
|
||||
throws IOException {
|
||||
// Select the best maxConn neighbors of the new node, applying the diversity heuristic
|
||||
for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
|
||||
private void selectAndLinkDiverse(
|
||||
NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException {
|
||||
// Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic
|
||||
for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
|
||||
// compare each neighbor (in distance order) against the closer neighbors selected so far,
|
||||
// only adding it if it is closer to the target than to any of the other selected neighbors
|
||||
int cNode = candidates.node[i];
|
||||
|
|
|
@ -47,7 +47,7 @@ public final class HnswGraphSearcher {
|
|||
* @param candidates max heap that will track the candidate nodes to explore
|
||||
* @param visited bit set that will track nodes that have already been visited
|
||||
*/
|
||||
HnswGraphSearcher(
|
||||
public HnswGraphSearcher(
|
||||
VectorSimilarityFunction similarityFunction, NeighborQueue candidates, BitSet visited) {
|
||||
this.similarityFunction = similarityFunction;
|
||||
this.candidates = candidates;
|
||||
|
@ -112,7 +112,7 @@ public final class HnswGraphSearcher {
|
|||
* @param graph the graph values
|
||||
* @return a priority queue holding the closest neighbors found
|
||||
*/
|
||||
NeighborQueue searchLevel(
|
||||
public NeighborQueue searchLevel(
|
||||
float[] query,
|
||||
int topK,
|
||||
int level,
|
||||
|
|
|
@ -98,7 +98,7 @@ public class NeighborQueue {
|
|||
return (int) order.apply(heap.pop());
|
||||
}
|
||||
|
||||
int[] nodes() {
|
||||
public int[] nodes() {
|
||||
int size = size();
|
||||
int[] nodes = new int[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
|
|
|
@ -30,7 +30,6 @@ import org.apache.lucene.util.ArrayUtil;
|
|||
*/
|
||||
public final class OnHeapHnswGraph extends HnswGraph {
|
||||
|
||||
private final int maxConn;
|
||||
private final boolean similarityReversed;
|
||||
private int numLevels; // the current number of levels in the graph
|
||||
private int entryNode; // the current graph entry node on the top level
|
||||
|
@ -41,27 +40,30 @@ public final class OnHeapHnswGraph extends HnswGraph {
|
|||
|
||||
// graph is a list of graph levels.
|
||||
// Each level is represented as List<NeighborArray> – nodes' connections on this level.
|
||||
// Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
|
||||
// Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond
|
||||
// to vectors
|
||||
// added to HnswBuilder, and the node values are the ordinals of those vectors.
|
||||
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
|
||||
private final List<List<NeighborArray>> graph;
|
||||
private final int nsize;
|
||||
private final int nsize0;
|
||||
|
||||
// KnnGraphValues iterator members
|
||||
private int upto;
|
||||
private NeighborArray cur;
|
||||
|
||||
OnHeapHnswGraph(int maxConn, int levelOfFirstNode, boolean similarityReversed) {
|
||||
this.maxConn = maxConn;
|
||||
OnHeapHnswGraph(int M, int levelOfFirstNode, boolean similarityReversed) {
|
||||
this.similarityReversed = similarityReversed;
|
||||
this.numLevels = levelOfFirstNode + 1;
|
||||
this.graph = new ArrayList<>(numLevels);
|
||||
this.entryNode = 0;
|
||||
for (int i = 0; i < numLevels; i++) {
|
||||
// Neighbours' size on upper levels (nsize) and level 0 (nsize0)
|
||||
// We allocate extra space for neighbours, but then prune them to keep allowed maximum
|
||||
this.nsize = M + 1;
|
||||
this.nsize0 = (M * 2 + 1);
|
||||
for (int l = 0; l < numLevels; l++) {
|
||||
graph.add(new ArrayList<>());
|
||||
// Typically with diversity criteria we see nodes not fully occupied;
|
||||
// average fanout seems to be about 1/2 maxConn.
|
||||
// There is some indexing time penalty for under-allocating, but saves RAM
|
||||
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4), similarityReversed == false));
|
||||
graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, similarityReversed == false));
|
||||
}
|
||||
|
||||
this.nodesByLevel = new ArrayList<>(numLevels);
|
||||
|
@ -121,8 +123,9 @@ public final class OnHeapHnswGraph extends HnswGraph {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
graph.get(level).add(new NeighborArray(maxConn + 1, similarityReversed == false));
|
||||
graph
|
||||
.get(level)
|
||||
.add(new NeighborArray(level == 0 ? nsize0 : nsize, similarityReversed == false));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -64,7 +64,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
|
||||
private static final String KNN_GRAPH_FIELD = "vector";
|
||||
|
||||
private static int maxConn = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
|
||||
private static int M = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
|
||||
|
||||
private Codec codec;
|
||||
private VectorSimilarityFunction similarityFunction;
|
||||
|
@ -73,15 +73,14 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
public void setup() {
|
||||
randSeed = random().nextLong();
|
||||
if (random().nextBoolean()) {
|
||||
maxConn = random().nextInt(256) + 3;
|
||||
M = random().nextInt(256) + 3;
|
||||
}
|
||||
|
||||
codec =
|
||||
new Lucene92Codec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene92HnswVectorsFormat(
|
||||
maxConn, Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
|
||||
return new Lucene92HnswVectorsFormat(M, Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -91,7 +90,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
|
||||
@After
|
||||
public void cleanup() {
|
||||
maxConn = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
|
||||
M = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
|
||||
}
|
||||
|
||||
/** Basic test of creating documents in a graph */
|
||||
|
@ -263,7 +262,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
int[][][] copyGraph(HnswGraph graphValues) throws IOException {
|
||||
int[][][] graph = new int[graphValues.numLevels()][][];
|
||||
int size = graphValues.size();
|
||||
int[] scratch = new int[maxConn];
|
||||
int[] scratch = new int[M * 2];
|
||||
|
||||
for (int level = 0; level < graphValues.numLevels(); level++) {
|
||||
NodesIterator nodesItr = graphValues.getNodesOnLevel(level);
|
||||
|
@ -483,10 +482,13 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
// For each level of the graph assert that:
|
||||
// 1. There are no orphan nodes without any friends
|
||||
// 2. If orphans are found, than the level must contain only 0 or a single node
|
||||
// 3. If the number of nodes on the level doesn't exceed maxConn, assert that the graph is
|
||||
// 3. If the number of nodes on the level doesn't exceed maxConnOnLevel, assert that the
|
||||
// graph is
|
||||
// fully connected, i.e. any node is reachable from any other node.
|
||||
// 4. If the number of nodes on the level exceeds maxConn, assert that maxConn is respected.
|
||||
// 4. If the number of nodes on the level exceeds maxConnOnLevel, assert that maxConnOnLevel
|
||||
// is respected.
|
||||
for (int level = 0; level < graphValues.numLevels(); level++) {
|
||||
int maxConnOnLevel = level == 0 ? M * 2 : M;
|
||||
int[][] graphOnLevel = new int[graphValues.size()][];
|
||||
int countOnLevel = 0;
|
||||
boolean foundOrphan = false;
|
||||
|
@ -508,7 +510,6 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
}
|
||||
countOnLevel++;
|
||||
}
|
||||
// System.out.println("Level[" + level + "] has [" + nodesCount + "] nodes.");
|
||||
assertEquals(nodesItr.size(), countOnLevel);
|
||||
assertFalse("No nodes on level [" + level + "]", countOnLevel == 0);
|
||||
if (countOnLevel == 1) {
|
||||
|
@ -517,13 +518,13 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
} else {
|
||||
assertFalse(
|
||||
"Graph has orphan nodes with no friends on level [" + level + "]", foundOrphan);
|
||||
if (maxConn > countOnLevel) {
|
||||
if (maxConnOnLevel > countOnLevel) {
|
||||
// assert that the graph is fully connected,
|
||||
// i.e. any node can be reached from any other node
|
||||
assertConnected(graphOnLevel);
|
||||
} else {
|
||||
// assert that max-connections was respected
|
||||
assertMaxConn(graphOnLevel, maxConn);
|
||||
assertMaxConn(graphOnLevel, maxConnOnLevel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,14 +62,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
|
||||
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
|
||||
|
||||
int maxConn = random().nextInt(10) + 5;
|
||||
int M = random().nextInt(10) + 5;
|
||||
int beamWidth = random().nextInt(10) + 5;
|
||||
long seed = random().nextLong();
|
||||
VectorSimilarityFunction similarityFunction =
|
||||
VectorSimilarityFunction.values()[
|
||||
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, seed);
|
||||
new HnswGraphBuilder(vectors, similarityFunction, M, beamWidth, seed);
|
||||
HnswGraph hnsw = builder.build(vectors);
|
||||
|
||||
// Recreate the graph while indexing with the same random seed and write it out
|
||||
|
@ -84,7 +84,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
new Lucene92Codec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene92HnswVectorsFormat(maxConn, beamWidth);
|
||||
return new Lucene92HnswVectorsFormat(M, beamWidth);
|
||||
}
|
||||
});
|
||||
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
|
||||
|
@ -153,12 +153,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
// ensuring that we have all the distance functions, comparators, priority queues and so on
|
||||
// oriented in the right directions
|
||||
public void testAknnDiverse() throws IOException {
|
||||
int maxConn = 10;
|
||||
int nDoc = 100;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, 10, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
// run some searches
|
||||
NeighborQueue nn =
|
||||
|
@ -193,11 +192,10 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
|
||||
public void testSearchWithAcceptOrds() throws IOException {
|
||||
int nDoc = 100;
|
||||
int maxConn = 16;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
// the first 10 docs must not be deleted to ensure the expected recall
|
||||
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
|
||||
|
@ -224,11 +222,10 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
|
||||
public void testSearchWithSelectiveAcceptOrds() throws IOException {
|
||||
int nDoc = 100;
|
||||
int maxConn = 16;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
// Only mark a few vectors as accepted
|
||||
BitSet acceptOrds = new FixedBitSet(vectors.size);
|
||||
|
@ -290,11 +287,10 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
|
||||
public void testVisitedLimit() throws IOException {
|
||||
int nDoc = 500;
|
||||
int maxConn = 16;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
|
||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
|
||||
int topK = 50;
|
||||
|
@ -396,9 +392,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
builder.addGraphNode(4, vectors.vectorValue(4));
|
||||
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||
// 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so
|
||||
// replace it
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 4);
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4);
|
||||
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
||||
// 1 survives the diversity check
|
||||
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
|
||||
|
@ -406,11 +400,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
|
||||
builder.addGraphNode(5, vectors.vectorValue(5));
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 5);
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5);
|
||||
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
||||
// even though 5 is closer, 3 is not a neighbor of 5, so no update to *its* neighbors occurs
|
||||
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
|
||||
assertLevel0Neighbors(builder.hnsw, 4, 3, 5);
|
||||
assertLevel0Neighbors(builder.hnsw, 4, 1, 3, 5);
|
||||
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
|
||||
}
|
||||
|
||||
|
@ -428,14 +422,13 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
public void testRandom() throws IOException {
|
||||
int size = atLeast(100);
|
||||
int dim = atLeast(10);
|
||||
int maxConn = 10;
|
||||
RandomVectorValues vectors = new RandomVectorValues(size, dim, random());
|
||||
VectorSimilarityFunction similarityFunction =
|
||||
VectorSimilarityFunction.values()[
|
||||
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
|
||||
int topK = 5;
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(vectors, similarityFunction, maxConn, 30, random().nextLong());
|
||||
new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
|
||||
|
||||
|
|
Loading…
Reference in New Issue