LUCENE-9848 Sort HNSW graph neighbors for construction (#862)

* LUCENE-9848 Sort HNSW graph neighbors for construction

Sort HNSW graph neighbors when applying diversity criterion

During HNSW graph construction, when a node has already a number of
connections larger than maximum allowed (maxConn), we need to prune
its connections using a diversity criteria to limit the number of
connections to maxConn.

Currently when we add reverse connections to already existing nodes,
we don't keep them sorted. Thus later, when we apply diversity criteria
we may prune not the worst most distant non-diverse nodes.

This patch makes sure that neighbours connections are always sorted
from best (closest) to worst (distant), and during the application
of diversity criteria processes nodes from worst to best.

This path does the following:
- enhance NeighborArray to always keep neighbour nodes sorted according
  to their scores (in desc or asc order). Make NeighborArray aware in
  which order the nodes should be sorted.
- make OnHeapHnswGraph aware of the order of similarity function
- make HnswGraphBuilder apply diversity criteria from worst to
  best nodes
- create Lucene90NeighborArray to keep the previous logic of
  NeighborArray for Lucene90Codec
This commit is contained in:
Mayya Sharipova 2022-05-04 14:15:14 -04:00 committed by GitHub
parent c3d47507e9
commit dc6a7f9468
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 332 additions and 67 deletions

View File

@ -26,7 +26,6 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.BoundsChecker;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.NeighborQueue;
/**
@ -47,7 +46,7 @@ public final class Lucene90HnswGraphBuilder {
private final int maxConn;
private final int beamWidth;
private final NeighborArray scratch;
private final Lucene90NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
@ -93,7 +92,7 @@ public final class Lucene90HnswGraphBuilder {
this.hnsw = new Lucene90OnHeapHnswGraph(maxConn);
bound = BoundsChecker.create(similarityFunction.reversed);
random = new SplittableRandom(seed);
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
scratch = new Lucene90NeighborArray(Math.max(beamWidth, maxConn + 1));
}
/**
@ -173,7 +172,7 @@ public final class Lucene90HnswGraphBuilder {
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
* since the node is new and has no prior neighbors).
*/
NeighborArray neighbors = hnsw.getNeighbors(node);
Lucene90NeighborArray neighbors = hnsw.getNeighbors(node);
assert neighbors.size() == 0; // new node
popToScratch(candidates);
selectDiverse(neighbors, scratch);
@ -183,7 +182,7 @@ public final class Lucene90HnswGraphBuilder {
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node()[i];
NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
nbrNbr.add(node, neighbors.score()[i]);
if (nbrNbr.size() > maxConn) {
diversityUpdate(nbrNbr);
@ -191,7 +190,8 @@ public final class Lucene90HnswGraphBuilder {
}
}
private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException {
private void selectDiverse(Lucene90NeighborArray neighbors, Lucene90NeighborArray candidates)
throws IOException {
// Select the best maxConn neighbors of the new node, applying the diversity heuristic
for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
// compare each neighbor (in distance order) against the closer neighbors selected so far,
@ -228,7 +228,7 @@ public final class Lucene90HnswGraphBuilder {
private boolean diversityCheck(
float[] candidate,
float score,
NeighborArray neighbors,
Lucene90NeighborArray neighbors,
RandomAccessVectorValues vectorValues)
throws IOException {
bound.set(score);
@ -242,7 +242,7 @@ public final class Lucene90HnswGraphBuilder {
return true;
}
private void diversityUpdate(NeighborArray neighbors) throws IOException {
private void diversityUpdate(Lucene90NeighborArray neighbors) throws IOException {
assert neighbors.size() == maxConn + 1;
int replacePoint = findNonDiverse(neighbors);
if (replacePoint == -1) {
@ -262,7 +262,7 @@ public final class Lucene90HnswGraphBuilder {
}
// scan neighbors looking for diversity violations
private int findNonDiverse(NeighborArray neighbors) throws IOException {
private int findNonDiverse(Lucene90NeighborArray neighbors) throws IOException {
for (int i = neighbors.size() - 1; i >= 0; i--) {
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
// them, drop it

View File

@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene90;
import org.apache.lucene.util.ArrayUtil;
/**
* NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair
* of growable arrays.
*
* @lucene.internal
*/
public class Lucene90NeighborArray {
private int size;
float[] score;
int[] node;
/** Create a neighbour array with the given initial size */
public Lucene90NeighborArray(int maxSize) {
node = new int[maxSize];
score = new float[maxSize];
}
/** Add a new node with a score */
public void add(int newNode, float newScore) {
if (size == node.length - 1) {
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
score = ArrayUtil.growExact(score, node.length);
}
node[size] = newNode;
score[size] = newScore;
++size;
}
/** Get the size, the number of nodes added so far */
public int size() {
return size;
}
/**
* Direct access to the internal list of node ids; provided for efficient writing of the graph
*
* @lucene.internal
*/
public int[] node() {
return node;
}
/**
* Direct access to the internal list of scores
*
* @lucene.internal
*/
public float[] score() {
return score;
}
/** Clear all the nodes in the array */
public void clear() {
size = 0;
}
/** Remove the last nodes from the array */
public void removeLast() {
size--;
}
@Override
public String toString() {
return "NeighborArray[" + size + "]";
}
}

View File

@ -29,7 +29,6 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.lucene.util.hnsw.BoundsChecker;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.NeighborQueue;
/**
@ -43,17 +42,17 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
// Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to
// HnswBuilder, and the
// node values are the ordinals of those vectors.
private final List<NeighborArray> graph;
private final List<Lucene90NeighborArray> graph;
// KnnGraphValues iterator members
private int upto;
private NeighborArray cur;
private Lucene90NeighborArray cur;
Lucene90OnHeapHnswGraph(int maxConn) {
graph = new ArrayList<>();
// Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
// about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
graph.add(new NeighborArray(Math.max(32, maxConn / 4)));
graph.add(new Lucene90NeighborArray(Math.max(32, maxConn / 4)));
this.maxConn = maxConn;
}
@ -162,7 +161,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
*
* @param node the node whose neighbors are returned
*/
public NeighborArray getNeighbors(int node) {
public Lucene90NeighborArray getNeighbors(int node) {
return graph.get(node);
}
@ -172,7 +171,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
}
int addNode() {
graph.add(new NeighborArray(maxConn + 1));
graph.add(new Lucene90NeighborArray(maxConn + 1));
return graph.size() - 1;
}

View File

@ -35,7 +35,6 @@ import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.NeighborArray;
/**
* Writes vector values and knn graphs to index segments.
@ -247,7 +246,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
// write graph
offsets[ord] = graphData.getFilePointer() - graphDataOffset;
NeighborArray neighbors = graph.getNeighbors(ord);
Lucene90NeighborArray neighbors = graph.getNeighbors(ord);
int size = neighbors.size();
// Destructively modify; it's ok we are discarding it after this

View File

@ -95,14 +95,15 @@ public final class HnswGraphBuilder {
this.ml = 1 / Math.log(1.0 * maxConn);
this.random = new SplittableRandom(seed);
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode);
this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode, similarityFunction.reversed);
this.graphSearcher =
new HnswGraphSearcher(
similarityFunction,
new NeighborQueue(beamWidth, similarityFunction.reversed == false),
new FixedBitSet(vectorValues.size()));
bound = BoundsChecker.create(similarityFunction.reversed);
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
// in scratch we store candidates in reverse order: worse candidates are first
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1), similarityFunction.reversed);
}
/**
@ -176,11 +177,6 @@ public final class HnswGraphBuilder {
return now;
}
/* TODO: we are not maintaining nodes in strict score order; the forward links
* are added in sorted order, but the reverse implicit ones are not. Diversity heuristic should
* work better if we keep the neighbor arrays sorted. Possibly we should switch back to a heap?
* But first we should just see if sorting makes a significant difference.
*/
private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
throws IOException {
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
@ -190,7 +186,7 @@ public final class HnswGraphBuilder {
NeighborArray neighbors = hnsw.getNeighbors(level, node);
assert neighbors.size() == 0; // new node
popToScratch(candidates);
selectDiverse(neighbors, scratch);
selectAndLinkDiverse(neighbors, scratch);
// Link the selected nodes to the new node, and the new node to the selected nodes (again
// applying diversity heuristic)
@ -198,14 +194,16 @@ public final class HnswGraphBuilder {
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
nbrNbr.add(node, neighbors.score[i]);
nbrNbr.insertSorted(node, neighbors.score[i]);
if (nbrNbr.size() > maxConn) {
diversityUpdate(nbrNbr);
int indexToRemove = findWorstNonDiverse(nbrNbr);
nbrNbr.removeIndex(indexToRemove);
}
}
}
private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException {
private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates)
throws IOException {
// Select the best maxConn neighbors of the new node, applying the diversity heuristic
for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
// compare each neighbor (in distance order) against the closer neighbors selected so far,
@ -256,44 +254,26 @@ public final class HnswGraphBuilder {
return true;
}
private void diversityUpdate(NeighborArray neighbors) throws IOException {
assert neighbors.size() == maxConn + 1;
int replacePoint = findNonDiverse(neighbors);
if (replacePoint == -1) {
// none found; check score against worst existing neighbor
bound.set(neighbors.score[0]);
if (bound.check(neighbors.score[maxConn])) {
// drop the new neighbor; it is not competitive and there were no diversity failures
neighbors.removeLast();
return;
} else {
replacePoint = 0;
}
}
neighbors.node[replacePoint] = neighbors.node[maxConn];
neighbors.score[replacePoint] = neighbors.score[maxConn];
neighbors.removeLast();
}
// scan neighbors looking for diversity violations
private int findNonDiverse(NeighborArray neighbors) throws IOException {
for (int i = neighbors.size() - 1; i >= 0; i--) {
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
// them, drop it
int nbrNode = neighbors.node[i];
/**
* Find first non-diverse neighbour among the list of neighbors starting from the most distant
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
for (int i = neighbors.size() - 1; i > 0; i--) {
int cNode = neighbors.node[i];
float[] cVector = vectorValues.vectorValue(cNode);
bound.set(neighbors.score[i]);
float[] nbrVector = vectorValues.vectorValue(nbrNode);
for (int j = maxConn; j > i; j--) {
// check the candidate against its better-scoring neighbors
for (int j = i - 1; j >= 0; j--) {
float diversityCheck =
similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
// node i is too similar to node j given its score relative to the base node
if (bound.check(diversityCheck) == false) {
// node j is too similar to node i given its score relative to the base node
// replace it with the new node, which is at [maxConn]
return i;
}
}
}
return -1;
return neighbors.size() - 1;
}
private static int getRandomGraphLevel(double ml, SplittableRandom random) {

View File

@ -17,36 +17,67 @@
package org.apache.lucene.util.hnsw;
import java.util.Arrays;
import org.apache.lucene.util.ArrayUtil;
/**
* NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair
* of growable arrays.
* of growable arrays. Nodes are arranged in the sorted order of their scores in descending order
* (if scoresDescOrder is true), or in the ascending order of their scores (if scoresDescOrder is
* false)
*
* @lucene.internal
*/
public class NeighborArray {
private final boolean scoresDescOrder;
private int size;
float[] score;
int[] node;
public NeighborArray(int maxSize) {
public NeighborArray(int maxSize, boolean descOrder) {
node = new int[maxSize];
score = new float[maxSize];
this.scoresDescOrder = descOrder;
}
/**
* Add a new node to the NeighborArray. The new node must be worse than all previously stored
* nodes.
*/
public void add(int newNode, float newScore) {
if (size == node.length - 1) {
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
score = ArrayUtil.growExact(score, node.length);
}
if (size > 0) {
float previousScore = score[size - 1];
assert ((scoresDescOrder && (previousScore >= newScore))
|| (scoresDescOrder == false && (previousScore <= newScore)))
: "Nodes are added in the incorrect order!";
}
node[size] = newNode;
score[size] = newScore;
++size;
}
/** Add a new node to the NeighborArray into a correct sort position according to its score. */
public void insertSorted(int newNode, float newScore) {
if (size == node.length - 1) {
node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
score = ArrayUtil.growExact(score, node.length);
}
int insertionPoint =
scoresDescOrder
? descSortFindRightMostInsertionPoint(newScore)
: ascSortFindRightMostInsertionPoint(newScore);
System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint);
System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint);
node[insertionPoint] = newNode;
score[insertionPoint] = newScore;
++size;
}
public int size() {
return size;
}
@ -72,8 +103,39 @@ public class NeighborArray {
size--;
}
public void removeIndex(int idx) {
System.arraycopy(node, idx + 1, node, idx, size - idx);
System.arraycopy(score, idx + 1, score, idx, size - idx);
size--;
}
@Override
public String toString() {
return "NeighborArray[" + size + "]";
}
private int ascSortFindRightMostInsertionPoint(float newScore) {
int insertionPoint = Arrays.binarySearch(score, 0, size, newScore);
if (insertionPoint >= 0) {
// find the right most position with the same score
while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
insertionPoint++;
}
insertionPoint++;
} else {
insertionPoint = -insertionPoint - 1;
}
return insertionPoint;
}
private int descSortFindRightMostInsertionPoint(float newScore) {
int start = 0;
int end = size - 1;
while (start <= end) {
int mid = (start + end) / 2;
if (score[mid] < newScore) end = mid - 1;
else start = mid + 1;
}
return start;
}
}

View File

@ -29,7 +29,7 @@ import org.apache.lucene.util.NumericUtils;
*/
public class NeighborQueue {
private static enum Order {
private enum Order {
NATURAL {
@Override
long apply(long v) {

View File

@ -31,6 +31,7 @@ import org.apache.lucene.util.ArrayUtil;
public final class OnHeapHnswGraph extends HnswGraph {
private final int maxConn;
private final boolean similarityReversed;
private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level
@ -49,8 +50,9 @@ public final class OnHeapHnswGraph extends HnswGraph {
private int upto;
private NeighborArray cur;
OnHeapHnswGraph(int maxConn, int levelOfFirstNode) {
OnHeapHnswGraph(int maxConn, int levelOfFirstNode, boolean similarityReversed) {
this.maxConn = maxConn;
this.similarityReversed = similarityReversed;
this.numLevels = levelOfFirstNode + 1;
this.graph = new ArrayList<>(numLevels);
this.entryNode = 0;
@ -59,7 +61,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
// Typically with diversity criteria we see nodes not fully occupied;
// average fanout seems to be about 1/2 maxConn.
// There is some indexing time penalty for under-allocating, but saves RAM
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4), similarityReversed == false));
}
this.nodesByLevel = new ArrayList<>(numLevels);
@ -120,7 +122,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
}
}
graph.get(level).add(new NeighborArray(maxConn + 1));
graph.get(level).add(new NeighborArray(maxConn + 1, similarityReversed == false));
}
@Override

View File

@ -276,7 +276,8 @@ public class KnnGraphTester {
for (int i = 0; i < hnsw.size(); i++) {
NeighborArray neighbors = hnsw.getNeighbors(0, i);
System.out.printf(Locale.ROOT, "%5d", i);
NeighborArray sorted = new NeighborArray(neighbors.size());
NeighborArray sorted =
new NeighborArray(neighbors.size(), similarityFunction.reversed == false);
for (int j = 0; j < neighbors.size(); j++) {
int node = neighbors.node[j];
float score = neighbors.score[j];

View File

@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import org.apache.lucene.tests.util.LuceneTestCase;
public class TestNeighborArray extends LuceneTestCase {
public void testScoresDescOrder() {
NeighborArray neighbors = new NeighborArray(10, true);
neighbors.add(0, 1);
neighbors.add(1, 0.8f);
AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.9f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
neighbors.insertSorted(3, 0.9f);
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
asserNodesEqual(new int[] {0, 3, 1}, neighbors);
neighbors.insertSorted(4, 1f);
assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors);
asserNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
neighbors.insertSorted(5, 1.1f);
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors);
asserNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
neighbors.insertSorted(6, 0.8f);
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors);
asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
neighbors.insertSorted(7, 0.8f);
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
neighbors.removeIndex(2);
assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
asserNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
neighbors.removeIndex(0);
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
asserNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
neighbors.removeIndex(4);
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors);
asserNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
neighbors.removeLast();
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
asserNodesEqual(new int[] {0, 3, 1}, neighbors);
neighbors.insertSorted(8, 0.9f);
assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors);
asserNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
}
public void testScoresAscOrder() {
NeighborArray neighbors = new NeighborArray(10, false);
neighbors.add(0, 0.1f);
neighbors.add(1, 0.3f);
AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.15f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
neighbors.insertSorted(3, 0.3f);
assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);
asserNodesEqual(new int[] {0, 1, 3}, neighbors);
neighbors.insertSorted(4, 0.2f);
assertScoresEqual(new float[] {0.1f, 0.2f, 0.3f, 0.3f}, neighbors);
asserNodesEqual(new int[] {0, 4, 1, 3}, neighbors);
neighbors.insertSorted(5, 0.05f);
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.3f, 0.3f}, neighbors);
asserNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors);
neighbors.insertSorted(6, 0.2f);
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
asserNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors);
neighbors.insertSorted(7, 0.2f);
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
asserNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors);
neighbors.removeIndex(2);
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
asserNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors);
neighbors.removeIndex(0);
assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
asserNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors);
neighbors.removeIndex(4);
assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f}, neighbors);
asserNodesEqual(new int[] {0, 6, 7, 1}, neighbors);
neighbors.removeLast();
assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f}, neighbors);
asserNodesEqual(new int[] {0, 6, 7}, neighbors);
neighbors.insertSorted(8, 0.01f);
assertScoresEqual(new float[] {0.01f, 0.1f, 0.2f, 0.2f}, neighbors);
asserNodesEqual(new int[] {8, 0, 6, 7}, neighbors);
}
private void assertScoresEqual(float[] scores, NeighborArray neighbors) {
for (int i = 0; i < scores.length; i++) {
assertEquals(scores[i], neighbors.score[i], 0.01f);
}
}
private void asserNodesEqual(int[] nodes, NeighborArray neighbors) {
for (int i = 0; i < nodes.length; i++) {
assertEquals(nodes[i], neighbors.node[i]);
}
}
}