diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 90e2382ece9..7195b9a2697 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -115,6 +115,8 @@ Improvements * LUCENE-9848: Correctly sort HNSW graph neighbors when applying diversity criterion (Mayya Sharipova, Michael Sokolov) +* LUCENE-10527: Use 2*maxConn for the last layer in HNSW (Mayya Sharipova) + Optimizations --------------------- diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java new file mode 100644 index 00000000000..002497d2d2a --- /dev/null +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.backward_codecs.lucene91; + +import static java.lang.Math.log; + +import java.io.IOException; +import java.util.Locale; +import java.util.Objects; +import java.util.SplittableRandom; +import org.apache.lucene.index.RandomAccessVectorValues; +import org.apache.lucene.index.RandomAccessVectorValuesProducer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.hnsw.BoundsChecker; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.HnswGraphSearcher; +import org.apache.lucene.util.hnsw.NeighborQueue; + +/** + * Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the + * hyperparameters. + */ +public final class Lucene91HnswGraphBuilder { + + /** Default random seed for level generation * */ + private static final long DEFAULT_RAND_SEED = 42; + /** A name for the HNSW component for the info-stream * */ + public static final String HNSW_COMPONENT = "HNSW"; + + /** Random seed for level generation; public to expose for testing * */ + public static long randSeed = DEFAULT_RAND_SEED; + + private final int maxConn; + private final int beamWidth; + private final double ml; + private final Lucene91NeighborArray scratch; + + private final VectorSimilarityFunction similarityFunction; + private final RandomAccessVectorValues vectorValues; + private final SplittableRandom random; + private final BoundsChecker bound; + private final HnswGraphSearcher graphSearcher; + + final Lucene91OnHeapHnswGraph hnsw; + + private InfoStream infoStream = InfoStream.getDefault(); + + // we need two sources of vectors in order to perform diversity check comparisons without + // colliding + private RandomAccessVectorValues buildVectors; + + /** + * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense + * ordinals, using the given 
hyperparameter settings, and returns the resulting graph. + * + * @param vectors the vectors whose relations are represented by the graph - must provide a + * different view over those vectors than the one used to add via addGraphNode. + * @param maxConn the number of connections to make when adding a new graph node; roughly speaking + * the graph fanout. + * @param beamWidth the size of the beam search to use when finding nearest neighbors. + * @param seed the seed for a random number generator used during graph construction. Provide this + * to ensure repeatable construction. + */ + public Lucene91HnswGraphBuilder( + RandomAccessVectorValuesProducer vectors, + VectorSimilarityFunction similarityFunction, + int maxConn, + int beamWidth, + long seed) + throws IOException { + vectorValues = vectors.randomAccess(); + buildVectors = vectors.randomAccess(); + this.similarityFunction = Objects.requireNonNull(similarityFunction); + if (maxConn <= 0) { + throw new IllegalArgumentException("maxConn must be positive"); + } + if (beamWidth <= 0) { + throw new IllegalArgumentException("beamWidth must be positive"); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + // normalization factor for level generation; currently not configurable + this.ml = 1 / Math.log(1.0 * maxConn); + this.random = new SplittableRandom(seed); + int levelOfFirstNode = getRandomGraphLevel(ml, random); + this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode); + this.graphSearcher = + new HnswGraphSearcher( + similarityFunction, + new NeighborQueue(beamWidth, similarityFunction.reversed == false), + new FixedBitSet(vectorValues.size())); + bound = BoundsChecker.create(similarityFunction.reversed); + scratch = new Lucene91NeighborArray(Math.max(beamWidth, maxConn + 1)); + } + + /** + * Reads all the vectors from two copies of a random access VectorValues. 
Providing two copies + * enables efficient retrieval without extra data copying, while avoiding collision of the + * returned values. + * + * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet + * accessor for the vectors + */ + public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException { + if (vectors == vectorValues) { + throw new IllegalArgumentException( + "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); + } + if (infoStream.isEnabled(HNSW_COMPONENT)) { + infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors"); + } + long start = System.nanoTime(), t = start; + // start at node 1! node 0 is added implicitly, in the constructor + for (int node = 1; node < vectors.size(); node++) { + addGraphNode(node, vectors.vectorValue(node)); + if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) { + t = printGraphBuildStatus(node, start, t); + } + } + return hnsw; + } + + /** Set info-stream to output debugging information * */ + public void setInfoStream(InfoStream infoStream) { + this.infoStream = infoStream; + } + + /** Inserts a doc with vector value to the graph */ + void addGraphNode(int node, float[] value) throws IOException { + NeighborQueue candidates; + final int nodeLevel = getRandomGraphLevel(ml, random); + int curMaxLevel = hnsw.numLevels() - 1; + int[] eps = new int[] {hnsw.entryNode()}; + + // if a node introduces new levels to the graph, add this new node on new levels + for (int level = nodeLevel; level > curMaxLevel; level--) { + hnsw.addNode(level, node); + } + + // for levels > nodeLevel search with topk = 1 + for (int level = curMaxLevel; level > nodeLevel; level--) { + candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw); + eps = new int[] {candidates.pop()}; + } + // for levels <= nodeLevel search with topk = beamWidth, and add connections + for (int level = 
Math.min(nodeLevel, curMaxLevel); level >= 0; level--) { + candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw); + eps = candidates.nodes(); + hnsw.addNode(level, node); + addDiverseNeighbors(level, node, candidates); + } + } + + private long printGraphBuildStatus(int node, long start, long t) { + long now = System.nanoTime(); + infoStream.message( + HNSW_COMPONENT, + String.format( + Locale.ROOT, + "built %d in %d/%d ms", + node, + ((now - t) / 1_000_000), + ((now - start) / 1_000_000))); + return now; + } + + /* TODO: we are not maintaining nodes in strict score order; the forward links + * are added in sorted order, but the reverse implicit ones are not. Diversity heuristic should + * work better if we keep the neighbor arrays sorted. Possibly we should switch back to a heap? + * But first we should just see if sorting makes a significant difference. + */ + private void addDiverseNeighbors(int level, int node, NeighborQueue candidates) + throws IOException { + /* For each of the beamWidth nearest candidates (going from best to worst), select it only if it + * is closer to target than it is to any of the already-selected neighbors (ie selected in this method, + * since the node is new and has no prior neighbors). 
+ */ + Lucene91NeighborArray neighbors = hnsw.getNeighbors(level, node); + assert neighbors.size() == 0; // new node + popToScratch(candidates); + selectDiverse(neighbors, scratch); + + // Link the selected nodes to the new node, and the new node to the selected nodes (again + // applying diversity heuristic) + int size = neighbors.size(); + for (int i = 0; i < size; i++) { + int nbr = neighbors.node[i]; + Lucene91NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr); + nbrNbr.add(node, neighbors.score[i]); + if (nbrNbr.size() > maxConn) { + diversityUpdate(nbrNbr); + } + } + } + + private void selectDiverse(Lucene91NeighborArray neighbors, Lucene91NeighborArray candidates) + throws IOException { + // Select the best maxConn neighbors of the new node, applying the diversity heuristic + for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) { + // compare each neighbor (in distance order) against the closer neighbors selected so far, + // only adding it if it is closer to the target than to any of the other selected neighbors + int cNode = candidates.node[i]; + float cScore = candidates.score[i]; + assert cNode < hnsw.size(); + if (diversityCheck(vectorValues.vectorValue(cNode), cScore, neighbors, buildVectors)) { + neighbors.add(cNode, cScore); + } + } + } + + private void popToScratch(NeighborQueue candidates) { + scratch.clear(); + int candidateCount = candidates.size(); + // extract all the Neighbors from the queue into an array; these will now be + // sorted from worst to best + for (int i = 0; i < candidateCount; i++) { + float score = candidates.topScore(); + scratch.add(candidates.pop(), score); + } + } + + /** + * @param candidate the vector of a new candidate neighbor of a node n + * @param score the score of the new candidate and node n, to be compared with scores of the + * candidate and n's neighbors + * @param neighbors the neighbors selected so far + * @param vectorValues source of values used for making comparisons between 
candidate and existing + * neighbors + * @return whether the candidate is diverse given the existing neighbors + */ + private boolean diversityCheck( + float[] candidate, + float score, + Lucene91NeighborArray neighbors, + RandomAccessVectorValues vectorValues) + throws IOException { + bound.set(score); + for (int i = 0; i < neighbors.size(); i++) { + float diversityCheck = + similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i])); + if (bound.check(diversityCheck) == false) { + return false; + } + } + return true; + } + + private void diversityUpdate(Lucene91NeighborArray neighbors) throws IOException { + assert neighbors.size() == maxConn + 1; + int replacePoint = findNonDiverse(neighbors); + if (replacePoint == -1) { + // none found; check score against worst existing neighbor + bound.set(neighbors.score[0]); + if (bound.check(neighbors.score[maxConn])) { + // drop the new neighbor; it is not competitive and there were no diversity failures + neighbors.removeLast(); + return; + } else { + replacePoint = 0; + } + } + neighbors.node[replacePoint] = neighbors.node[maxConn]; + neighbors.score[replacePoint] = neighbors.score[maxConn]; + neighbors.removeLast(); + } + + // scan neighbors looking for diversity violations + private int findNonDiverse(Lucene91NeighborArray neighbors) throws IOException { + for (int i = neighbors.size() - 1; i >= 0; i--) { + // check each neighbor against its better-scoring neighbors. 
If it fails diversity check with + // them, drop it + int nbrNode = neighbors.node[i]; + bound.set(neighbors.score[i]); + float[] nbrVector = vectorValues.vectorValue(nbrNode); + for (int j = maxConn; j > i; j--) { + float diversityCheck = + similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j])); + if (bound.check(diversityCheck) == false) { + // node j is too similar to node i given its score relative to the base node + // replace it with the new node, which is at [maxConn] + return i; + } + } + } + return -1; + } + + private static int getRandomGraphLevel(double ml, SplittableRandom random) { + double randDouble; + do { + randDouble = random.nextDouble(); // avoid 0 value, as log(0) is undefined + } while (randDouble == 0.0); + return ((int) (-log(randDouble) * ml)); + } +} diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91NeighborArray.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91NeighborArray.java new file mode 100644 index 00000000000..fcb097162f1 --- /dev/null +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91NeighborArray.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.backward_codecs.lucene91; + +import org.apache.lucene.util.ArrayUtil; + +/** + * NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair + * of growable arrays. + * + * @lucene.internal + */ +public class Lucene91NeighborArray { + + private int size; + + float[] score; + int[] node; + + /** Create a neighbour array with the given initial size */ + public Lucene91NeighborArray(int maxSize) { + node = new int[maxSize]; + score = new float[maxSize]; + } + + /** Add a new node with a score */ + public void add(int newNode, float newScore) { + if (size == node.length - 1) { + node = ArrayUtil.grow(node, (size + 1) * 3 / 2); + score = ArrayUtil.growExact(score, node.length); + } + node[size] = newNode; + score[size] = newScore; + ++size; + } + + /** Get the size, the number of nodes added so far */ + public int size() { + return size; + } + + /** + * Direct access to the internal list of node ids; provided for efficient writing of the graph + * + * @lucene.internal + */ + public int[] node() { + return node; + } + + /** + * Direct access to the internal list of scores + * + * @lucene.internal + */ + public float[] score() { + return score; + } + + /** Clear all the nodes in the array */ + public void clear() { + size = 0; + } + + /** Remove the last nodes from the array */ + public void removeLast() { + size--; + } + + @Override + public String toString() { + return "NeighborArray[" + size + "]"; + } +} diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java new file mode 100644 index 00000000000..2d3ef582b47 --- /dev/null +++ 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.backward_codecs.lucene91; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.NeighborQueue; + +/** + * An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to + * construct the HNSW graph before it's written to the index. + */ +public final class Lucene91OnHeapHnswGraph extends HnswGraph { + + private final int maxConn; + private int numLevels; // the current number of levels in the graph + private int entryNode; // the current graph entry node on the top level + + // Nodes by level expressed as the level 0's nodes' ordinals. + // As level 0 contains all nodes, nodesByLevel.get(0) is null. + private final List nodesByLevel; + + // graph is a list of graph levels. + // Each level is represented as List – nodes' connections on this level. 
+ // Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors + // added to HnswBuilder, and the node values are the ordinals of those vectors. + // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals. + private final List> graph; + + // KnnGraphValues iterator members + private int upto; + private Lucene91NeighborArray cur; + + Lucene91OnHeapHnswGraph(int maxConn, int levelOfFirstNode) { + this.maxConn = maxConn; + this.numLevels = levelOfFirstNode + 1; + this.graph = new ArrayList<>(numLevels); + this.entryNode = 0; + for (int i = 0; i < numLevels; i++) { + graph.add(new ArrayList<>()); + // Typically with diversity criteria we see nodes not fully occupied; + // average fanout seems to be about 1/2 maxConn. + // There is some indexing time penalty for under-allocating, but saves RAM + graph.get(i).add(new Lucene91NeighborArray(Math.max(32, maxConn / 4))); + } + + this.nodesByLevel = new ArrayList<>(numLevels); + nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes + for (int l = 1; l < numLevels; l++) { + nodesByLevel.add(new int[] {0}); + } + } + + /** + * Returns the {@link NeighborQueue} connected to the given node. + * + * @param level level of the graph + * @param node the node whose neighbors are returned, represented as an ordinal on the level 0. + */ + public Lucene91NeighborArray getNeighbors(int level, int node) { + if (level == 0) { + return graph.get(level).get(node); + } + int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node); + assert nodeIndex >= 0; + return graph.get(level).get(nodeIndex); + } + + @Override + public int size() { + return graph.get(0).size(); // all nodes are located on the 0th level + } + + /** + * Add node on the given level + * + * @param level level to add a node on + * @param node the node to add, represented as an ordinal on the level 0. 
+ */ + public void addNode(int level, int node) { + if (level > 0) { + // if the new node introduces a new level, add more levels to the graph, + // and make this node the graph's new entry point + if (level >= numLevels) { + for (int i = numLevels; i <= level; i++) { + graph.add(new ArrayList<>()); + nodesByLevel.add(new int[] {node}); + } + numLevels = level + 1; + entryNode = node; + } else { + // Add this node id to this level's nodes + int[] nodes = nodesByLevel.get(level); + int idx = graph.get(level).size(); + if (idx < nodes.length) { + nodes[idx] = node; + } else { + nodes = ArrayUtil.grow(nodes); + nodes[idx] = node; + nodesByLevel.set(level, nodes); + } + } + } + + graph.get(level).add(new Lucene91NeighborArray(maxConn + 1)); + } + + @Override + public void seek(int level, int targetNode) { + cur = getNeighbors(level, targetNode); + upto = -1; + } + + @Override + public int nextNeighbor() { + if (++upto < cur.size()) { + return cur.node[upto]; + } + return NO_MORE_DOCS; + } + + /** + * Returns the current number of levels in the graph + * + * @return the current number of levels in the graph + */ + @Override + public int numLevels() { + return numLevels; + } + + /** + * Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th + * level + * + * @return the graph's current entry node on the top level + */ + @Override + public int entryNode() { + return entryNode; + } + + @Override + public NodesIterator getNodesOnLevel(int level) { + if (level == 0) { + return new NodesIterator(size()); + } else { + return new NodesIterator(nodesByLevel.get(level), graph.get(level).size()); + } + } +} diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java index 6e1527541b5..0542163057e 100644 --- 
a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -37,9 +37,6 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; -import org.apache.lucene.util.hnsw.HnswGraphBuilder; -import org.apache.lucene.util.hnsw.NeighborArray; -import org.apache.lucene.util.hnsw.OnHeapHnswGraph; /** * Writes vector values and knn graphs to index segments. @@ -145,7 +142,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter { Lucene91HnswVectorsReader.OffHeapVectorValues offHeapVectors = new Lucene91HnswVectorsReader.OffHeapVectorValues( vectors.dimension(), docsWithField.cardinality(), null, vectorDataInput); - OnHeapHnswGraph graph = + Lucene91OnHeapHnswGraph graph = offHeapVectors.size() == 0 ? 
null : writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction()); @@ -194,7 +191,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter { long vectorIndexOffset, long vectorIndexLength, DocsWithFieldSet docsWithField, - OnHeapHnswGraph graph) + Lucene91OnHeapHnswGraph graph) throws IOException { meta.writeInt(field.number); meta.writeInt(field.getVectorSimilarityFunction().ordinal()); @@ -236,16 +233,20 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter { } } - private OnHeapHnswGraph writeGraph( + private Lucene91OnHeapHnswGraph writeGraph( RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction) throws IOException { // build graph - HnswGraphBuilder hnswGraphBuilder = - new HnswGraphBuilder( - vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed); + Lucene91HnswGraphBuilder hnswGraphBuilder = + new Lucene91HnswGraphBuilder( + vectorValues, + similarityFunction, + maxConn, + beamWidth, + Lucene91HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); - OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess()); + Lucene91OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess()); // write vectors' neighbours on each level into the vectorIndex file int countOnLevel0 = graph.size(); @@ -253,7 +254,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter { NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); while (nodesOnLevel.hasNext()) { int node = nodesOnLevel.nextInt(); - NeighborArray neighbors = graph.getNeighbors(level, node); + Lucene91NeighborArray neighbors = graph.getNeighbors(level, node); int size = neighbors.size(); vectorIndex.writeInt(size); // Destructively modify; it's ok we are discarding it after this diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsFormat.java 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsFormat.java index d762033e96d..3b28b706890 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsFormat.java @@ -58,9 +58,9 @@ import org.apache.lucene.util.hnsw.HnswGraph; *
    *
  • [int32] the number of neighbor nodes *
  • array[int32] the neighbor ordinals - *
  • array[int32] padding from empty integers if the number of neighbors less - * than the maximum number of connections (maxConn). Padding is equal to - * ((maxConn-the number of neighbours) * 4) bytes. + *
  • array[int32] padding if the number of the node's neighbors is less than + * the maximum number of connections allowed on this level. Padding is equal to + * ((maxConnOnLevel - the number of neighbors) * 4) bytes. *
* * diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsReader.java index a03b0b9b32b..1b1a120a6b2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsReader.java @@ -282,7 +282,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader { final long vectorDataLength; final long vectorIndexOffset; final long vectorIndexLength; - final int maxConn; + final int M; final int numLevels; final int dimension; final int size; @@ -336,7 +336,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader { } // read nodes by level - maxConn = input.readInt(); + M = input.readInt(); numLevels = input.readInt(); nodesByLevel = new int[numLevels][]; for (int level = 0; level < numLevels; level++) { @@ -359,10 +359,13 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader { for (int level = 0; level < numLevels; level++) { if (level == 0) { graphOffsetsByLevel[level] = 0; + } else if (level == 1) { + int numNodesOnLevel0 = size; + graphOffsetsByLevel[level] = (1 + (M * 2)) * Integer.BYTES * numNodesOnLevel0; } else { - int numNodesOnPrevLevel = level == 1 ? size : nodesByLevel[level - 1].length; + int numNodesOnPrevLevel = nodesByLevel[level - 1].length; graphOffsetsByLevel[level] = - graphOffsetsByLevel[level - 1] + (1 + maxConn) * Integer.BYTES * numNodesOnPrevLevel; + graphOffsetsByLevel[level - 1] + (1 + M) * Integer.BYTES * numNodesOnPrevLevel; } } } @@ -382,6 +385,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader { final int entryNode; final int size; final long bytesForConns; + final long bytesForConns0; int arcCount; int arcUpTo; @@ -394,7 +398,8 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader { this.entryNode = numLevels > 1 ? 
nodesByLevel[numLevels - 1][0] : 0; this.size = entry.size(); this.graphOffsetsByLevel = entry.graphOffsetsByLevel; - this.bytesForConns = ((long) entry.maxConn + 1) * Integer.BYTES; + this.bytesForConns = ((long) entry.M + 1) * Integer.BYTES; + this.bytesForConns0 = ((long) (entry.M * 2) + 1) * Integer.BYTES; } @Override @@ -404,7 +409,8 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader { ? targetOrd : Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd); assert targetIndex >= 0; - long graphDataOffset = graphOffsetsByLevel[level] + targetIndex * bytesForConns; + long graphDataOffset = + graphOffsetsByLevel[level] + targetIndex * (level == 0 ? bytesForConns0 : bytesForConns); // unsafe; no bounds checking dataIn.seek(graphDataOffset); arcCount = dataIn.readInt(); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsWriter.java index c2fc9b35975..e63ab2ce277 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -55,13 +55,12 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter { private final IndexOutput meta, vectorData, vectorIndex; private final int maxDoc; - private final int maxConn; + private final int M; private final int beamWidth; private boolean finished; - Lucene92HnswVectorsWriter(SegmentWriteState state, int maxConn, int beamWidth) - throws IOException { - this.maxConn = maxConn; + Lucene92HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException { + this.M = M; this.beamWidth = beamWidth; assert state.fieldInfos.hasVectorValues(); @@ -248,7 +247,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter { meta.writeLong(vectorData.getFilePointer() - start); } - 
meta.writeInt(maxConn); + meta.writeInt(M); // write graph nodes on each level if (graph == null) { meta.writeInt(0); @@ -274,13 +273,14 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter { // build graph HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder( - vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed); + vectorValues, similarityFunction, M, beamWidth, HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess()); // write vectors' neighbours on each level into the vectorIndex file int countOnLevel0 = graph.size(); for (int level = 0; level < graph.numLevels(); level++) { + int maxConnOnLevel = level == 0 ? (M * 2) : M; NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); while (nodesOnLevel.hasNext()) { int node = nodesOnLevel.nextInt(); @@ -297,7 +297,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter { } // if number of connections < maxConn, add bogus values up to maxConn to have predictable // offsets - for (int i = size; i < maxConn; i++) { + for (int i = size; i < maxConnOnLevel; i++) { vectorIndex.writeInt(0); } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index b7935497f14..b611d082c96 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -43,7 +43,7 @@ public final class HnswGraphBuilder { /** Random seed for level generation; public to expose for testing * */ public static long randSeed = DEFAULT_RAND_SEED; - private final int maxConn; + private final int M; // max number of connections on upper layers private final int beamWidth; private final double ml; private final NeighborArray scratch; @@ -68,8 +68,8 @@ public final class 
HnswGraphBuilder { * * @param vectors the vectors whose relations are represented by the graph - must provide a * different view over those vectors than the one used to add via addGraphNode. - * @param maxConn the number of connections to make when adding a new graph node; roughly speaking - * the graph fanout. + * @param M – graph fanout parameter used to calculate the maximum number of connections a node + * can have – M on upper layers, and M * 2 on the lowest level. * @param beamWidth the size of the beam search to use when finding nearest neighbors. * @param seed the seed for a random number generator used during graph construction. Provide this * to ensure repeatable construction. @@ -77,26 +77,26 @@ public final class HnswGraphBuilder { public HnswGraphBuilder( RandomAccessVectorValuesProducer vectors, VectorSimilarityFunction similarityFunction, - int maxConn, + int M, int beamWidth, long seed) throws IOException { vectorValues = vectors.randomAccess(); buildVectors = vectors.randomAccess(); this.similarityFunction = Objects.requireNonNull(similarityFunction); - if (maxConn <= 0) { + if (M <= 0) { throw new IllegalArgumentException("maxConn must be positive"); } if (beamWidth <= 0) { throw new IllegalArgumentException("beamWidth must be positive"); } - this.maxConn = maxConn; + this.M = M; this.beamWidth = beamWidth; // normalization factor for level generation; currently not configurable - this.ml = 1 / Math.log(1.0 * maxConn); + this.ml = 1 / Math.log(1.0 * M); this.random = new SplittableRandom(seed); int levelOfFirstNode = getRandomGraphLevel(ml, random); - this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode, similarityFunction.reversed); + this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode, similarityFunction.reversed); this.graphSearcher = new HnswGraphSearcher( similarityFunction, @@ -104,7 +104,7 @@ public final class HnswGraphBuilder { new FixedBitSet(vectorValues.size())); bound = BoundsChecker.create(similarityFunction.reversed); // in 
scratch we store candidates in reverse order: worse candidates are first - scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1), similarityFunction.reversed); + scratch = new NeighborArray(Math.max(beamWidth, M + 1), similarityFunction.reversed); } /** @@ -187,7 +187,8 @@ public final class HnswGraphBuilder { NeighborArray neighbors = hnsw.getNeighbors(level, node); assert neighbors.size() == 0; // new node popToScratch(candidates); - selectAndLinkDiverse(neighbors, scratch); + int maxConnOnLevel = level == 0 ? M * 2 : M; + selectAndLinkDiverse(neighbors, scratch, maxConnOnLevel); // Link the selected nodes to the new node, and the new node to the selected nodes (again // applying diversity heuristic) @@ -196,17 +197,17 @@ public final class HnswGraphBuilder { int nbr = neighbors.node[i]; NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr); nbrNbr.insertSorted(node, neighbors.score[i]); - if (nbrNbr.size() > maxConn) { + if (nbrNbr.size() > maxConnOnLevel) { int indexToRemove = findWorstNonDiverse(nbrNbr); nbrNbr.removeIndex(indexToRemove); } } } - private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates) - throws IOException { - // Select the best maxConn neighbors of the new node, applying the diversity heuristic - for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) { + private void selectAndLinkDiverse( + NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException { + // Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic + for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) { // compare each neighbor (in distance order) against the closer neighbors selected so far, // only adding it if it is closer to the target than to any of the other selected neighbors int cNode = candidates.node[i]; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java 
b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index e02fbcd9bda..b1a2436166f 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -47,7 +47,7 @@ public final class HnswGraphSearcher { * @param candidates max heap that will track the candidate nodes to explore * @param visited bit set that will track nodes that have already been visited */ - HnswGraphSearcher( + public HnswGraphSearcher( VectorSimilarityFunction similarityFunction, NeighborQueue candidates, BitSet visited) { this.similarityFunction = similarityFunction; this.candidates = candidates; @@ -112,7 +112,7 @@ public final class HnswGraphSearcher { * @param graph the graph values * @return a priority queue holding the closest neighbors found */ - NeighborQueue searchLevel( + public NeighborQueue searchLevel( float[] query, int topK, int level, diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java index cb58c608f61..a2c7253261b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java @@ -98,7 +98,7 @@ public class NeighborQueue { return (int) order.apply(heap.pop()); } - int[] nodes() { + public int[] nodes() { int size = size(); int[] nodes = new int[size]; for (int i = 0; i < size; i++) { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index 08cecd1f8ac..1dc0845ccd5 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -30,7 +30,6 @@ import org.apache.lucene.util.ArrayUtil; */ public final class OnHeapHnswGraph extends HnswGraph { - private final int maxConn; 
private final boolean similarityReversed; private int numLevels; // the current number of levels in the graph private int entryNode; // the current graph entry node on the top level @@ -41,27 +40,30 @@ public final class OnHeapHnswGraph extends HnswGraph { // graph is a list of graph levels. // Each level is represented as List<NeighborArray> – nodes' connections on this level. - // Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors + // Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond + // to vectors // added to HnswBuilder, and the node values are the ordinals of those vectors. // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals. private final List<List<NeighborArray>> graph; + private final int nsize; + private final int nsize0; // KnnGraphValues iterator members private int upto; private NeighborArray cur; - OnHeapHnswGraph(int maxConn, int levelOfFirstNode, boolean similarityReversed) { - this.maxConn = maxConn; + OnHeapHnswGraph(int M, int levelOfFirstNode, boolean similarityReversed) { this.similarityReversed = similarityReversed; this.numLevels = levelOfFirstNode + 1; this.graph = new ArrayList<>(numLevels); this.entryNode = 0; - for (int i = 0; i < numLevels; i++) { + // Neighbours' size on upper levels (nsize) and level 0 (nsize0) + // We allocate extra space for neighbours, but then prune them to keep the allowed maximum + this.nsize = M + 1; + this.nsize0 = (M * 2 + 1); + for (int l = 0; l < numLevels; l++) { graph.add(new ArrayList<>()); - // Typically with diversity criteria we see nodes not fully occupied; - // average fanout seems to be about 1/2 maxConn. - // There is some indexing time penalty for under-allocating, but saves RAM - graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4), similarityReversed == false)); + graph.get(l).add(new NeighborArray(l == 0 ? 
nsize0 : nsize, similarityReversed == false)); } this.nodesByLevel = new ArrayList<>(numLevels); @@ -121,8 +123,9 @@ public final class OnHeapHnswGraph extends HnswGraph { } } } - - graph.get(level).add(new NeighborArray(maxConn + 1, similarityReversed == false)); + graph + .get(level) + .add(new NeighborArray(level == 0 ? nsize0 : nsize, similarityReversed == false)); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index b2bf9cab30c..3a54c407070 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -64,7 +64,7 @@ public class TestKnnGraph extends LuceneTestCase { private static final String KNN_GRAPH_FIELD = "vector"; - private static int maxConn = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN; + private static int M = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN; private Codec codec; private VectorSimilarityFunction similarityFunction; @@ -73,15 +73,14 @@ public class TestKnnGraph extends LuceneTestCase { public void setup() { randSeed = random().nextLong(); if (random().nextBoolean()) { - maxConn = random().nextInt(256) + 3; + M = random().nextInt(256) + 3; } codec = new Lucene92Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene92HnswVectorsFormat( - maxConn, Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + return new Lucene92HnswVectorsFormat(M, Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH); } }; @@ -91,7 +90,7 @@ public class TestKnnGraph extends LuceneTestCase { @After public void cleanup() { - maxConn = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN; + M = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN; } /** Basic test of creating documents in a graph */ @@ -263,7 +262,7 @@ public class TestKnnGraph extends LuceneTestCase { int[][][] copyGraph(HnswGraph graphValues) throws IOException { int[][][] graph = new 
int[graphValues.numLevels()][][]; int size = graphValues.size(); - int[] scratch = new int[maxConn]; + int[] scratch = new int[M * 2]; for (int level = 0; level < graphValues.numLevels(); level++) { NodesIterator nodesItr = graphValues.getNodesOnLevel(level); @@ -483,10 +482,13 @@ public class TestKnnGraph extends LuceneTestCase { // For each level of the graph assert that: // 1. There are no orphan nodes without any friends // 2. If orphans are found, than the level must contain only 0 or a single node - // 3. If the number of nodes on the level doesn't exceed maxConn, assert that the graph is + // 3. If the number of nodes on the level doesn't exceed maxConnOnLevel, assert that the + // graph is // fully connected, i.e. any node is reachable from any other node. - // 4. If the number of nodes on the level exceeds maxConn, assert that maxConn is respected. + // 4. If the number of nodes on the level exceeds maxConnOnLevel, assert that maxConnOnLevel + // is respected. for (int level = 0; level < graphValues.numLevels(); level++) { + int maxConnOnLevel = level == 0 ? M * 2 : M; int[][] graphOnLevel = new int[graphValues.size()][]; int countOnLevel = 0; boolean foundOrphan = false; @@ -508,7 +510,6 @@ public class TestKnnGraph extends LuceneTestCase { } countOnLevel++; } - // System.out.println("Level[" + level + "] has [" + nodesCount + "] nodes."); assertEquals(nodesItr.size(), countOnLevel); assertFalse("No nodes on level [" + level + "]", countOnLevel == 0); if (countOnLevel == 1) { @@ -517,13 +518,13 @@ public class TestKnnGraph extends LuceneTestCase { } else { assertFalse( "Graph has orphan nodes with no friends on level [" + level + "]", foundOrphan); - if (maxConn > countOnLevel) { + if (maxConnOnLevel > countOnLevel) { // assert that the graph is fully connected, // i.e. 
any node can be reached from any other node assertConnected(graphOnLevel); } else { // assert that max-connections was respected - assertMaxConn(graphOnLevel, maxConn); + assertMaxConn(graphOnLevel, maxConnOnLevel); } } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java index f0068f03077..43c470513b4 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java @@ -62,14 +62,14 @@ public class TestHnswGraph extends LuceneTestCase { RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random()); RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy(); - int maxConn = random().nextInt(10) + 5; + int M = random().nextInt(10) + 5; int beamWidth = random().nextInt(10) + 5; long seed = random().nextLong(); VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[ random().nextInt(VectorSimilarityFunction.values().length - 1) + 1]; HnswGraphBuilder builder = - new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, seed); + new HnswGraphBuilder(vectors, similarityFunction, M, beamWidth, seed); HnswGraph hnsw = builder.build(vectors); // Recreate the graph while indexing with the same random seed and write it out @@ -84,7 +84,7 @@ public class TestHnswGraph extends LuceneTestCase { new Lucene92Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene92HnswVectorsFormat(maxConn, beamWidth); + return new Lucene92HnswVectorsFormat(M, beamWidth); } }); try (IndexWriter iw = new IndexWriter(dir, iwc)) { @@ -153,12 +153,11 @@ public class TestHnswGraph extends LuceneTestCase { // ensuring that we have all the distance functions, comparators, priority queues and so on // oriented in the right directions public void testAknnDiverse() throws IOException { - int maxConn = 10; int nDoc = 100; 
CircularVectorValues vectors = new CircularVectorValues(nDoc); HnswGraphBuilder builder = new HnswGraphBuilder( - vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt()); + vectors, VectorSimilarityFunction.DOT_PRODUCT, 10, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors); // run some searches NeighborQueue nn = @@ -193,11 +192,10 @@ public class TestHnswGraph extends LuceneTestCase { public void testSearchWithAcceptOrds() throws IOException { int nDoc = 100; - int maxConn = 16; CircularVectorValues vectors = new CircularVectorValues(nDoc); HnswGraphBuilder builder = new HnswGraphBuilder( - vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt()); + vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors); // the first 10 docs must not be deleted to ensure the expected recall Bits acceptOrds = createRandomAcceptOrds(10, vectors.size); @@ -224,11 +222,10 @@ public class TestHnswGraph extends LuceneTestCase { public void testSearchWithSelectiveAcceptOrds() throws IOException { int nDoc = 100; - int maxConn = 16; CircularVectorValues vectors = new CircularVectorValues(nDoc); HnswGraphBuilder builder = new HnswGraphBuilder( - vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt()); + vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors); // Only mark a few vectors as accepted BitSet acceptOrds = new FixedBitSet(vectors.size); @@ -290,11 +287,10 @@ public class TestHnswGraph extends LuceneTestCase { public void testVisitedLimit() throws IOException { int nDoc = 500; - int maxConn = 16; CircularVectorValues vectors = new CircularVectorValues(nDoc); HnswGraphBuilder builder = new HnswGraphBuilder( - vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt()); + vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, 
random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors); int topK = 50; @@ -396,9 +392,7 @@ public class TestHnswGraph extends LuceneTestCase { builder.addGraphNode(4, vectors.vectorValue(4)); // 4 is the same distance from 0 that 2 is; we leave the existing node in place assertLevel0Neighbors(builder.hnsw, 0, 1, 2); - // 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so - // replace it - assertLevel0Neighbors(builder.hnsw, 1, 0, 4); + assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4); assertLevel0Neighbors(builder.hnsw, 2, 0); // 1 survives the diversity check assertLevel0Neighbors(builder.hnsw, 3, 1, 4); @@ -406,11 +400,11 @@ public class TestHnswGraph extends LuceneTestCase { builder.addGraphNode(5, vectors.vectorValue(5)); assertLevel0Neighbors(builder.hnsw, 0, 1, 2); - assertLevel0Neighbors(builder.hnsw, 1, 0, 5); + assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5); assertLevel0Neighbors(builder.hnsw, 2, 0); // even though 5 is closer, 3 is not a neighbor of 5, so no update to *its* neighbors occurs assertLevel0Neighbors(builder.hnsw, 3, 1, 4); - assertLevel0Neighbors(builder.hnsw, 4, 3, 5); + assertLevel0Neighbors(builder.hnsw, 4, 1, 3, 5); assertLevel0Neighbors(builder.hnsw, 5, 1, 4); } @@ -428,14 +422,13 @@ public class TestHnswGraph extends LuceneTestCase { public void testRandom() throws IOException { int size = atLeast(100); int dim = atLeast(10); - int maxConn = 10; RandomVectorValues vectors = new RandomVectorValues(size, dim, random()); VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[ random().nextInt(VectorSimilarityFunction.values().length - 1) + 1]; int topK = 5; HnswGraphBuilder builder = - new HnswGraphBuilder(vectors, similarityFunction, maxConn, 30, random().nextLong()); + new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong()); OnHeapHnswGraph hnsw = builder.build(vectors); Bits acceptOrds = random().nextBoolean() ? 
null : createRandomAcceptOrds(0, size);