mirror of https://github.com/apache/lucene.git
LUCENE-9848 Sort HNSW graph neighbors for construction (#862)
* LUCENE-9848 Sort HNSW graph neighbors for construction. Sort HNSW graph neighbors when applying the diversity criterion. During HNSW graph construction, when a node already has more connections than the maximum allowed (maxConn), we need to prune its connections using a diversity criterion to limit the number of connections to maxConn. Currently, when we add reverse connections to already existing nodes, we don't keep them sorted. Thus later, when we apply the diversity criterion, we may prune nodes other than the worst (most distant) non-diverse ones. This patch makes sure that neighbour connections are always sorted from best (closest) to worst (most distant), and that the diversity criterion processes nodes from worst to best. This patch does the following: - enhance NeighborArray to always keep neighbour nodes sorted according to their scores (in desc or asc order), and make NeighborArray aware of the order in which the nodes should be sorted - make OnHeapHnswGraph aware of the order of the similarity function - make HnswGraphBuilder apply the diversity criterion from worst to best nodes - create Lucene90NeighborArray to keep the previous logic of NeighborArray for Lucene90Codec
This commit is contained in:
parent
c3d47507e9
commit
dc6a7f9468
|
@ -26,7 +26,6 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.util.InfoStream;
|
import org.apache.lucene.util.InfoStream;
|
||||||
import org.apache.lucene.util.hnsw.BoundsChecker;
|
import org.apache.lucene.util.hnsw.BoundsChecker;
|
||||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
|
||||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -47,7 +46,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
|
|
||||||
private final int maxConn;
|
private final int maxConn;
|
||||||
private final int beamWidth;
|
private final int beamWidth;
|
||||||
private final NeighborArray scratch;
|
private final Lucene90NeighborArray scratch;
|
||||||
|
|
||||||
private final VectorSimilarityFunction similarityFunction;
|
private final VectorSimilarityFunction similarityFunction;
|
||||||
private final RandomAccessVectorValues vectorValues;
|
private final RandomAccessVectorValues vectorValues;
|
||||||
|
@ -93,7 +92,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
this.hnsw = new Lucene90OnHeapHnswGraph(maxConn);
|
this.hnsw = new Lucene90OnHeapHnswGraph(maxConn);
|
||||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||||
random = new SplittableRandom(seed);
|
random = new SplittableRandom(seed);
|
||||||
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
|
scratch = new Lucene90NeighborArray(Math.max(beamWidth, maxConn + 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -173,7 +172,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
|
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
|
||||||
* since the node is new and has no prior neighbors).
|
* since the node is new and has no prior neighbors).
|
||||||
*/
|
*/
|
||||||
NeighborArray neighbors = hnsw.getNeighbors(node);
|
Lucene90NeighborArray neighbors = hnsw.getNeighbors(node);
|
||||||
assert neighbors.size() == 0; // new node
|
assert neighbors.size() == 0; // new node
|
||||||
popToScratch(candidates);
|
popToScratch(candidates);
|
||||||
selectDiverse(neighbors, scratch);
|
selectDiverse(neighbors, scratch);
|
||||||
|
@ -183,7 +182,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
int size = neighbors.size();
|
int size = neighbors.size();
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
int nbr = neighbors.node()[i];
|
int nbr = neighbors.node()[i];
|
||||||
NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
|
Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
|
||||||
nbrNbr.add(node, neighbors.score()[i]);
|
nbrNbr.add(node, neighbors.score()[i]);
|
||||||
if (nbrNbr.size() > maxConn) {
|
if (nbrNbr.size() > maxConn) {
|
||||||
diversityUpdate(nbrNbr);
|
diversityUpdate(nbrNbr);
|
||||||
|
@ -191,7 +190,8 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException {
|
private void selectDiverse(Lucene90NeighborArray neighbors, Lucene90NeighborArray candidates)
|
||||||
|
throws IOException {
|
||||||
// Select the best maxConn neighbors of the new node, applying the diversity heuristic
|
// Select the best maxConn neighbors of the new node, applying the diversity heuristic
|
||||||
for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
|
for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
|
||||||
// compare each neighbor (in distance order) against the closer neighbors selected so far,
|
// compare each neighbor (in distance order) against the closer neighbors selected so far,
|
||||||
|
@ -228,7 +228,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
private boolean diversityCheck(
|
private boolean diversityCheck(
|
||||||
float[] candidate,
|
float[] candidate,
|
||||||
float score,
|
float score,
|
||||||
NeighborArray neighbors,
|
Lucene90NeighborArray neighbors,
|
||||||
RandomAccessVectorValues vectorValues)
|
RandomAccessVectorValues vectorValues)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
bound.set(score);
|
bound.set(score);
|
||||||
|
@ -242,7 +242,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void diversityUpdate(NeighborArray neighbors) throws IOException {
|
private void diversityUpdate(Lucene90NeighborArray neighbors) throws IOException {
|
||||||
assert neighbors.size() == maxConn + 1;
|
assert neighbors.size() == maxConn + 1;
|
||||||
int replacePoint = findNonDiverse(neighbors);
|
int replacePoint = findNonDiverse(neighbors);
|
||||||
if (replacePoint == -1) {
|
if (replacePoint == -1) {
|
||||||
|
@ -262,7 +262,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// scan neighbors looking for diversity violations
|
// scan neighbors looking for diversity violations
|
||||||
private int findNonDiverse(NeighborArray neighbors) throws IOException {
|
private int findNonDiverse(Lucene90NeighborArray neighbors) throws IOException {
|
||||||
for (int i = neighbors.size() - 1; i >= 0; i--) {
|
for (int i = neighbors.size() - 1; i >= 0; i--) {
|
||||||
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
|
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
|
||||||
// them, drop it
|
// them, drop it
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.backward_codecs.lucene90;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair
|
||||||
|
* of growable arrays.
|
||||||
|
*
|
||||||
|
* @lucene.internal
|
||||||
|
*/
|
||||||
|
public class Lucene90NeighborArray {
|
||||||
|
|
||||||
|
private int size;
|
||||||
|
|
||||||
|
float[] score;
|
||||||
|
int[] node;
|
||||||
|
|
||||||
|
/** Create a neighbour array with the given initial size */
|
||||||
|
public Lucene90NeighborArray(int maxSize) {
|
||||||
|
node = new int[maxSize];
|
||||||
|
score = new float[maxSize];
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Add a new node with a score */
|
||||||
|
public void add(int newNode, float newScore) {
|
||||||
|
if (size == node.length - 1) {
|
||||||
|
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
|
||||||
|
score = ArrayUtil.growExact(score, node.length);
|
||||||
|
}
|
||||||
|
node[size] = newNode;
|
||||||
|
score[size] = newScore;
|
||||||
|
++size;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Get the size, the number of nodes added so far */
|
||||||
|
public int size() {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Direct access to the internal list of node ids; provided for efficient writing of the graph
|
||||||
|
*
|
||||||
|
* @lucene.internal
|
||||||
|
*/
|
||||||
|
public int[] node() {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Direct access to the internal list of scores
|
||||||
|
*
|
||||||
|
* @lucene.internal
|
||||||
|
*/
|
||||||
|
public float[] score() {
|
||||||
|
return score;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Clear all the nodes in the array */
|
||||||
|
public void clear() {
|
||||||
|
size = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Remove the last nodes from the array */
|
||||||
|
public void removeLast() {
|
||||||
|
size--;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "NeighborArray[" + size + "]";
|
||||||
|
}
|
||||||
|
}
|
|
@ -29,7 +29,6 @@ import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.SparseFixedBitSet;
|
import org.apache.lucene.util.SparseFixedBitSet;
|
||||||
import org.apache.lucene.util.hnsw.BoundsChecker;
|
import org.apache.lucene.util.hnsw.BoundsChecker;
|
||||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
|
||||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -43,17 +42,17 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
||||||
// Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to
|
// Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to
|
||||||
// HnswBuilder, and the
|
// HnswBuilder, and the
|
||||||
// node values are the ordinals of those vectors.
|
// node values are the ordinals of those vectors.
|
||||||
private final List<NeighborArray> graph;
|
private final List<Lucene90NeighborArray> graph;
|
||||||
|
|
||||||
// KnnGraphValues iterator members
|
// KnnGraphValues iterator members
|
||||||
private int upto;
|
private int upto;
|
||||||
private NeighborArray cur;
|
private Lucene90NeighborArray cur;
|
||||||
|
|
||||||
Lucene90OnHeapHnswGraph(int maxConn) {
|
Lucene90OnHeapHnswGraph(int maxConn) {
|
||||||
graph = new ArrayList<>();
|
graph = new ArrayList<>();
|
||||||
// Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
|
// Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
|
||||||
// about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
|
// about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
|
||||||
graph.add(new NeighborArray(Math.max(32, maxConn / 4)));
|
graph.add(new Lucene90NeighborArray(Math.max(32, maxConn / 4)));
|
||||||
this.maxConn = maxConn;
|
this.maxConn = maxConn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,7 +161,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
||||||
*
|
*
|
||||||
* @param node the node whose neighbors are returned
|
* @param node the node whose neighbors are returned
|
||||||
*/
|
*/
|
||||||
public NeighborArray getNeighbors(int node) {
|
public Lucene90NeighborArray getNeighbors(int node) {
|
||||||
return graph.get(node);
|
return graph.get(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,7 +171,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
int addNode() {
|
int addNode() {
|
||||||
graph.add(new NeighborArray(maxConn + 1));
|
graph.add(new Lucene90NeighborArray(maxConn + 1));
|
||||||
return graph.size() - 1;
|
return graph.size() - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,6 @@ import org.apache.lucene.store.IndexOutput;
|
||||||
import org.apache.lucene.util.ArrayUtil;
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.IOUtils;
|
import org.apache.lucene.util.IOUtils;
|
||||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Writes vector values and knn graphs to index segments.
|
* Writes vector values and knn graphs to index segments.
|
||||||
|
@ -247,7 +246,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
// write graph
|
// write graph
|
||||||
offsets[ord] = graphData.getFilePointer() - graphDataOffset;
|
offsets[ord] = graphData.getFilePointer() - graphDataOffset;
|
||||||
|
|
||||||
NeighborArray neighbors = graph.getNeighbors(ord);
|
Lucene90NeighborArray neighbors = graph.getNeighbors(ord);
|
||||||
int size = neighbors.size();
|
int size = neighbors.size();
|
||||||
|
|
||||||
// Destructively modify; it's ok we are discarding it after this
|
// Destructively modify; it's ok we are discarding it after this
|
||||||
|
|
|
@ -95,14 +95,15 @@ public final class HnswGraphBuilder {
|
||||||
this.ml = 1 / Math.log(1.0 * maxConn);
|
this.ml = 1 / Math.log(1.0 * maxConn);
|
||||||
this.random = new SplittableRandom(seed);
|
this.random = new SplittableRandom(seed);
|
||||||
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
||||||
this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode);
|
this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode, similarityFunction.reversed);
|
||||||
this.graphSearcher =
|
this.graphSearcher =
|
||||||
new HnswGraphSearcher(
|
new HnswGraphSearcher(
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
new NeighborQueue(beamWidth, similarityFunction.reversed == false),
|
new NeighborQueue(beamWidth, similarityFunction.reversed == false),
|
||||||
new FixedBitSet(vectorValues.size()));
|
new FixedBitSet(vectorValues.size()));
|
||||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||||
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
|
// in scratch we store candidates in reverse order: worse candidates are first
|
||||||
|
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1), similarityFunction.reversed);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -176,11 +177,6 @@ public final class HnswGraphBuilder {
|
||||||
return now;
|
return now;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* TODO: we are not maintaining nodes in strict score order; the forward links
|
|
||||||
* are added in sorted order, but the reverse implicit ones are not. Diversity heuristic should
|
|
||||||
* work better if we keep the neighbor arrays sorted. Possibly we should switch back to a heap?
|
|
||||||
* But first we should just see if sorting makes a significant difference.
|
|
||||||
*/
|
|
||||||
private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
|
private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
|
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
|
||||||
|
@ -190,7 +186,7 @@ public final class HnswGraphBuilder {
|
||||||
NeighborArray neighbors = hnsw.getNeighbors(level, node);
|
NeighborArray neighbors = hnsw.getNeighbors(level, node);
|
||||||
assert neighbors.size() == 0; // new node
|
assert neighbors.size() == 0; // new node
|
||||||
popToScratch(candidates);
|
popToScratch(candidates);
|
||||||
selectDiverse(neighbors, scratch);
|
selectAndLinkDiverse(neighbors, scratch);
|
||||||
|
|
||||||
// Link the selected nodes to the new node, and the new node to the selected nodes (again
|
// Link the selected nodes to the new node, and the new node to the selected nodes (again
|
||||||
// applying diversity heuristic)
|
// applying diversity heuristic)
|
||||||
|
@ -198,14 +194,16 @@ public final class HnswGraphBuilder {
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
int nbr = neighbors.node[i];
|
int nbr = neighbors.node[i];
|
||||||
NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
|
NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
|
||||||
nbrNbr.add(node, neighbors.score[i]);
|
nbrNbr.insertSorted(node, neighbors.score[i]);
|
||||||
if (nbrNbr.size() > maxConn) {
|
if (nbrNbr.size() > maxConn) {
|
||||||
diversityUpdate(nbrNbr);
|
int indexToRemove = findWorstNonDiverse(nbrNbr);
|
||||||
|
nbrNbr.removeIndex(indexToRemove);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException {
|
private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates)
|
||||||
|
throws IOException {
|
||||||
// Select the best maxConn neighbors of the new node, applying the diversity heuristic
|
// Select the best maxConn neighbors of the new node, applying the diversity heuristic
|
||||||
for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
|
for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
|
||||||
// compare each neighbor (in distance order) against the closer neighbors selected so far,
|
// compare each neighbor (in distance order) against the closer neighbors selected so far,
|
||||||
|
@ -256,44 +254,26 @@ public final class HnswGraphBuilder {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void diversityUpdate(NeighborArray neighbors) throws IOException {
|
/**
|
||||||
assert neighbors.size() == maxConn + 1;
|
* Find first non-diverse neighbour among the list of neighbors starting from the most distant
|
||||||
int replacePoint = findNonDiverse(neighbors);
|
* neighbours
|
||||||
if (replacePoint == -1) {
|
*/
|
||||||
// none found; check score against worst existing neighbor
|
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
|
||||||
bound.set(neighbors.score[0]);
|
for (int i = neighbors.size() - 1; i > 0; i--) {
|
||||||
if (bound.check(neighbors.score[maxConn])) {
|
int cNode = neighbors.node[i];
|
||||||
// drop the new neighbor; it is not competitive and there were no diversity failures
|
float[] cVector = vectorValues.vectorValue(cNode);
|
||||||
neighbors.removeLast();
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
replacePoint = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
neighbors.node[replacePoint] = neighbors.node[maxConn];
|
|
||||||
neighbors.score[replacePoint] = neighbors.score[maxConn];
|
|
||||||
neighbors.removeLast();
|
|
||||||
}
|
|
||||||
|
|
||||||
// scan neighbors looking for diversity violations
|
|
||||||
private int findNonDiverse(NeighborArray neighbors) throws IOException {
|
|
||||||
for (int i = neighbors.size() - 1; i >= 0; i--) {
|
|
||||||
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
|
|
||||||
// them, drop it
|
|
||||||
int nbrNode = neighbors.node[i];
|
|
||||||
bound.set(neighbors.score[i]);
|
bound.set(neighbors.score[i]);
|
||||||
float[] nbrVector = vectorValues.vectorValue(nbrNode);
|
// check the candidate against its better-scoring neighbors
|
||||||
for (int j = maxConn; j > i; j--) {
|
for (int j = i - 1; j >= 0; j--) {
|
||||||
float diversityCheck =
|
float diversityCheck =
|
||||||
similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
|
similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
|
||||||
|
// node i is too similar to node j given its score relative to the base node
|
||||||
if (bound.check(diversityCheck) == false) {
|
if (bound.check(diversityCheck) == false) {
|
||||||
// node j is too similar to node i given its score relative to the base node
|
|
||||||
// replace it with the new node, which is at [maxConn]
|
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return -1;
|
return neighbors.size() - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static int getRandomGraphLevel(double ml, SplittableRandom random) {
|
private static int getRandomGraphLevel(double ml, SplittableRandom random) {
|
||||||
|
|
|
@ -17,36 +17,67 @@
|
||||||
|
|
||||||
package org.apache.lucene.util.hnsw;
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import org.apache.lucene.util.ArrayUtil;
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair
|
* NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair
|
||||||
* of growable arrays.
|
* of growable arrays. Nodes are arranged in the sorted order of their scores in descending order
|
||||||
|
* (if scoresDescOrder is true), or in the ascending order of their scores (if scoresDescOrder is
|
||||||
|
* false)
|
||||||
*
|
*
|
||||||
* @lucene.internal
|
* @lucene.internal
|
||||||
*/
|
*/
|
||||||
public class NeighborArray {
|
public class NeighborArray {
|
||||||
|
private final boolean scoresDescOrder;
|
||||||
private int size;
|
private int size;
|
||||||
|
|
||||||
float[] score;
|
float[] score;
|
||||||
int[] node;
|
int[] node;
|
||||||
|
|
||||||
public NeighborArray(int maxSize) {
|
public NeighborArray(int maxSize, boolean descOrder) {
|
||||||
node = new int[maxSize];
|
node = new int[maxSize];
|
||||||
score = new float[maxSize];
|
score = new float[maxSize];
|
||||||
|
this.scoresDescOrder = descOrder;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a new node to the NeighborArray. The new node must be worse than all previously stored
|
||||||
|
* nodes.
|
||||||
|
*/
|
||||||
public void add(int newNode, float newScore) {
|
public void add(int newNode, float newScore) {
|
||||||
if (size == node.length - 1) {
|
if (size == node.length - 1) {
|
||||||
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
|
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
|
||||||
score = ArrayUtil.growExact(score, node.length);
|
score = ArrayUtil.growExact(score, node.length);
|
||||||
}
|
}
|
||||||
|
if (size > 0) {
|
||||||
|
float previousScore = score[size - 1];
|
||||||
|
assert ((scoresDescOrder && (previousScore >= newScore))
|
||||||
|
|| (scoresDescOrder == false && (previousScore <= newScore)))
|
||||||
|
: "Nodes are added in the incorrect order!";
|
||||||
|
}
|
||||||
node[size] = newNode;
|
node[size] = newNode;
|
||||||
score[size] = newScore;
|
score[size] = newScore;
|
||||||
++size;
|
++size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Add a new node to the NeighborArray into a correct sort position according to its score. */
|
||||||
|
public void insertSorted(int newNode, float newScore) {
|
||||||
|
if (size == node.length - 1) {
|
||||||
|
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
|
||||||
|
score = ArrayUtil.growExact(score, node.length);
|
||||||
|
}
|
||||||
|
int insertionPoint =
|
||||||
|
scoresDescOrder
|
||||||
|
? descSortFindRightMostInsertionPoint(newScore)
|
||||||
|
: ascSortFindRightMostInsertionPoint(newScore);
|
||||||
|
System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint);
|
||||||
|
System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint);
|
||||||
|
node[insertionPoint] = newNode;
|
||||||
|
score[insertionPoint] = newScore;
|
||||||
|
++size;
|
||||||
|
}
|
||||||
|
|
||||||
public int size() {
|
public int size() {
|
||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
@ -72,8 +103,39 @@ public class NeighborArray {
|
||||||
size--;
|
size--;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void removeIndex(int idx) {
|
||||||
|
System.arraycopy(node, idx + 1, node, idx, size - idx);
|
||||||
|
System.arraycopy(score, idx + 1, score, idx, size - idx);
|
||||||
|
size--;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return "NeighborArray[" + size + "]";
|
return "NeighborArray[" + size + "]";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private int ascSortFindRightMostInsertionPoint(float newScore) {
|
||||||
|
int insertionPoint = Arrays.binarySearch(score, 0, size, newScore);
|
||||||
|
if (insertionPoint >= 0) {
|
||||||
|
// find the right most position with the same score
|
||||||
|
while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
|
||||||
|
insertionPoint++;
|
||||||
|
}
|
||||||
|
insertionPoint++;
|
||||||
|
} else {
|
||||||
|
insertionPoint = -insertionPoint - 1;
|
||||||
|
}
|
||||||
|
return insertionPoint;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int descSortFindRightMostInsertionPoint(float newScore) {
|
||||||
|
int start = 0;
|
||||||
|
int end = size - 1;
|
||||||
|
while (start <= end) {
|
||||||
|
int mid = (start + end) / 2;
|
||||||
|
if (score[mid] < newScore) end = mid - 1;
|
||||||
|
else start = mid + 1;
|
||||||
|
}
|
||||||
|
return start;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.lucene.util.NumericUtils;
|
||||||
*/
|
*/
|
||||||
public class NeighborQueue {
|
public class NeighborQueue {
|
||||||
|
|
||||||
private static enum Order {
|
private enum Order {
|
||||||
NATURAL {
|
NATURAL {
|
||||||
@Override
|
@Override
|
||||||
long apply(long v) {
|
long apply(long v) {
|
||||||
|
|
|
@ -31,6 +31,7 @@ import org.apache.lucene.util.ArrayUtil;
|
||||||
public final class OnHeapHnswGraph extends HnswGraph {
|
public final class OnHeapHnswGraph extends HnswGraph {
|
||||||
|
|
||||||
private final int maxConn;
|
private final int maxConn;
|
||||||
|
private final boolean similarityReversed;
|
||||||
private int numLevels; // the current number of levels in the graph
|
private int numLevels; // the current number of levels in the graph
|
||||||
private int entryNode; // the current graph entry node on the top level
|
private int entryNode; // the current graph entry node on the top level
|
||||||
|
|
||||||
|
@ -49,8 +50,9 @@ public final class OnHeapHnswGraph extends HnswGraph {
|
||||||
private int upto;
|
private int upto;
|
||||||
private NeighborArray cur;
|
private NeighborArray cur;
|
||||||
|
|
||||||
OnHeapHnswGraph(int maxConn, int levelOfFirstNode) {
|
OnHeapHnswGraph(int maxConn, int levelOfFirstNode, boolean similarityReversed) {
|
||||||
this.maxConn = maxConn;
|
this.maxConn = maxConn;
|
||||||
|
this.similarityReversed = similarityReversed;
|
||||||
this.numLevels = levelOfFirstNode + 1;
|
this.numLevels = levelOfFirstNode + 1;
|
||||||
this.graph = new ArrayList<>(numLevels);
|
this.graph = new ArrayList<>(numLevels);
|
||||||
this.entryNode = 0;
|
this.entryNode = 0;
|
||||||
|
@ -59,7 +61,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
|
||||||
// Typically with diversity criteria we see nodes not fully occupied;
|
// Typically with diversity criteria we see nodes not fully occupied;
|
||||||
// average fanout seems to be about 1/2 maxConn.
|
// average fanout seems to be about 1/2 maxConn.
|
||||||
// There is some indexing time penalty for under-allocating, but saves RAM
|
// There is some indexing time penalty for under-allocating, but saves RAM
|
||||||
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
|
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4), similarityReversed == false));
|
||||||
}
|
}
|
||||||
|
|
||||||
this.nodesByLevel = new ArrayList<>(numLevels);
|
this.nodesByLevel = new ArrayList<>(numLevels);
|
||||||
|
@ -120,7 +122,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
graph.get(level).add(new NeighborArray(maxConn + 1));
|
graph.get(level).add(new NeighborArray(maxConn + 1, similarityReversed == false));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -276,7 +276,8 @@ public class KnnGraphTester {
|
||||||
for (int i = 0; i < hnsw.size(); i++) {
|
for (int i = 0; i < hnsw.size(); i++) {
|
||||||
NeighborArray neighbors = hnsw.getNeighbors(0, i);
|
NeighborArray neighbors = hnsw.getNeighbors(0, i);
|
||||||
System.out.printf(Locale.ROOT, "%5d", i);
|
System.out.printf(Locale.ROOT, "%5d", i);
|
||||||
NeighborArray sorted = new NeighborArray(neighbors.size());
|
NeighborArray sorted =
|
||||||
|
new NeighborArray(neighbors.size(), similarityFunction.reversed == false);
|
||||||
for (int j = 0; j < neighbors.size(); j++) {
|
for (int j = 0; j < neighbors.size(); j++) {
|
||||||
int node = neighbors.node[j];
|
int node = neighbors.node[j];
|
||||||
float score = neighbors.score[j];
|
float score = neighbors.score[j];
|
||||||
|
|
|
@ -0,0 +1,133 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
|
|
||||||
|
public class TestNeighborArray extends LuceneTestCase {
|
||||||
|
|
||||||
|
public void testScoresDescOrder() {
|
||||||
|
NeighborArray neighbors = new NeighborArray(10, true);
|
||||||
|
neighbors.add(0, 1);
|
||||||
|
neighbors.add(1, 0.8f);
|
||||||
|
|
||||||
|
AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.9f));
|
||||||
|
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
|
||||||
|
|
||||||
|
neighbors.insertSorted(3, 0.9f);
|
||||||
|
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 3, 1}, neighbors);
|
||||||
|
|
||||||
|
neighbors.insertSorted(4, 1f);
|
||||||
|
assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
|
||||||
|
|
||||||
|
neighbors.insertSorted(5, 1.1f);
|
||||||
|
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
|
||||||
|
|
||||||
|
neighbors.insertSorted(6, 0.8f);
|
||||||
|
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
|
||||||
|
|
||||||
|
neighbors.insertSorted(7, 0.8f);
|
||||||
|
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
|
||||||
|
|
||||||
|
neighbors.removeIndex(2);
|
||||||
|
assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
|
||||||
|
|
||||||
|
neighbors.removeIndex(0);
|
||||||
|
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
|
||||||
|
|
||||||
|
neighbors.removeIndex(4);
|
||||||
|
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
|
||||||
|
|
||||||
|
neighbors.removeLast();
|
||||||
|
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 3, 1}, neighbors);
|
||||||
|
|
||||||
|
neighbors.insertSorted(8, 0.9f);
|
||||||
|
assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testScoresAscOrder() {
|
||||||
|
NeighborArray neighbors = new NeighborArray(10, false);
|
||||||
|
neighbors.add(0, 0.1f);
|
||||||
|
neighbors.add(1, 0.3f);
|
||||||
|
|
||||||
|
AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.15f));
|
||||||
|
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
|
||||||
|
|
||||||
|
neighbors.insertSorted(3, 0.3f);
|
||||||
|
assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 1, 3}, neighbors);
|
||||||
|
|
||||||
|
neighbors.insertSorted(4, 0.2f);
|
||||||
|
assertScoresEqual(new float[] {0.1f, 0.2f, 0.3f, 0.3f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 4, 1, 3}, neighbors);
|
||||||
|
|
||||||
|
neighbors.insertSorted(5, 0.05f);
|
||||||
|
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.3f, 0.3f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors);
|
||||||
|
|
||||||
|
neighbors.insertSorted(6, 0.2f);
|
||||||
|
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors);
|
||||||
|
|
||||||
|
neighbors.insertSorted(7, 0.2f);
|
||||||
|
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors);
|
||||||
|
|
||||||
|
neighbors.removeIndex(2);
|
||||||
|
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors);
|
||||||
|
|
||||||
|
neighbors.removeIndex(0);
|
||||||
|
assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors);
|
||||||
|
|
||||||
|
neighbors.removeIndex(4);
|
||||||
|
assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 6, 7, 1}, neighbors);
|
||||||
|
|
||||||
|
neighbors.removeLast();
|
||||||
|
assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {0, 6, 7}, neighbors);
|
||||||
|
|
||||||
|
neighbors.insertSorted(8, 0.01f);
|
||||||
|
assertScoresEqual(new float[] {0.01f, 0.1f, 0.2f, 0.2f}, neighbors);
|
||||||
|
asserNodesEqual(new int[] {8, 0, 6, 7}, neighbors);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertScoresEqual(float[] scores, NeighborArray neighbors) {
|
||||||
|
for (int i = 0; i < scores.length; i++) {
|
||||||
|
assertEquals(scores[i], neighbors.score[i], 0.01f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void asserNodesEqual(int[] nodes, NeighborArray neighbors) {
|
||||||
|
for (int i = 0; i < nodes.length; i++) {
|
||||||
|
assertEquals(nodes[i], neighbors.node[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue