diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java index b59ba3b4a7a..0e8afd822b2 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java @@ -26,7 +26,6 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.hnsw.BoundsChecker; -import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.NeighborQueue; /** @@ -47,7 +46,7 @@ public final class Lucene90HnswGraphBuilder { private final int maxConn; private final int beamWidth; - private final NeighborArray scratch; + private final Lucene90NeighborArray scratch; private final VectorSimilarityFunction similarityFunction; private final RandomAccessVectorValues vectorValues; @@ -93,7 +92,7 @@ public final class Lucene90HnswGraphBuilder { this.hnsw = new Lucene90OnHeapHnswGraph(maxConn); bound = BoundsChecker.create(similarityFunction.reversed); random = new SplittableRandom(seed); - scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1)); + scratch = new Lucene90NeighborArray(Math.max(beamWidth, maxConn + 1)); } /** @@ -173,7 +172,7 @@ public final class Lucene90HnswGraphBuilder { * is closer to target than it is to any of the already-selected neighbors (ie selected in this method, * since the node is new and has no prior neighbors). */ - NeighborArray neighbors = hnsw.getNeighbors(node); + Lucene90NeighborArray neighbors = hnsw.getNeighbors(node); assert neighbors.size() == 0; // new node popToScratch(candidates); selectDiverse(neighbors, scratch); @@ -183,7 +182,7 @@ public final class Lucene90HnswGraphBuilder { int size = neighbors.size(); for (int i = 0; i < size; i++) { int nbr = neighbors.node()[i]; - NeighborArray nbrNbr = hnsw.getNeighbors(nbr); + Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr); nbrNbr.add(node, neighbors.score()[i]); if (nbrNbr.size() > maxConn) { diversityUpdate(nbrNbr); @@ -191,7 +190,8 @@ public final class Lucene90HnswGraphBuilder { } } - private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException { + private void selectDiverse(Lucene90NeighborArray neighbors, Lucene90NeighborArray candidates) + throws IOException { // Select the best maxConn neighbors of the new node, applying the diversity heuristic for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) { // compare each neighbor (in distance order) against the closer neighbors selected so far, @@ -228,7 +228,7 @@ public final class Lucene90HnswGraphBuilder { private boolean diversityCheck( float[] candidate, float score, - NeighborArray neighbors, + Lucene90NeighborArray neighbors, RandomAccessVectorValues vectorValues) throws IOException { bound.set(score); @@ -242,7 +242,7 @@ public final class Lucene90HnswGraphBuilder { return true; } - private void diversityUpdate(NeighborArray neighbors) throws IOException { + private void diversityUpdate(Lucene90NeighborArray neighbors) throws IOException { assert neighbors.size() == maxConn + 1; int replacePoint = findNonDiverse(neighbors); if (replacePoint == -1) { @@ -262,7 +262,7 @@ public final class Lucene90HnswGraphBuilder { } // scan neighbors looking for diversity violations - private int findNonDiverse(NeighborArray neighbors) throws IOException { + private int findNonDiverse(Lucene90NeighborArray neighbors) throws IOException { for (int i = neighbors.size() - 1; i >= 0; i--) { // check each neighbor against its better-scoring neighbors. If it fails diversity check with // them, drop it diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90NeighborArray.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90NeighborArray.java new file mode 100644 index 00000000000..e2412fcd7da --- /dev/null +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90NeighborArray.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.backward_codecs.lucene90; + +import org.apache.lucene.util.ArrayUtil; + +/** + * NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair + * of growable arrays. + * + * @lucene.internal + */ +public class Lucene90NeighborArray { + + private int size; + + float[] score; + int[] node; + + /** Create a neighbour array with the given initial size */ + public Lucene90NeighborArray(int maxSize) { + node = new int[maxSize]; + score = new float[maxSize]; + } + + /** Add a new node with a score */ + public void add(int newNode, float newScore) { + if (size == node.length - 1) { + node = ArrayUtil.grow(node, (size + 1) * 3 / 2); + score = ArrayUtil.growExact(score, node.length); + } + node[size] = newNode; + score[size] = newScore; + ++size; + } + + /** Get the size, the number of nodes added so far */ + public int size() { + return size; + } + + /** + * Direct access to the internal list of node ids; provided for efficient writing of the graph + * + * @lucene.internal + */ + public int[] node() { + return node; + } + + /** + * Direct access to the internal list of scores + * + * @lucene.internal + */ + public float[] score() { + return score; + } + + /** Clear all the nodes in the array */ + public void clear() { + size = 0; + } + + /** Remove the last nodes from the array */ + public void removeLast() { + size--; + } + + @Override + public String toString() { + return "NeighborArray[" + size + "]"; + } +} diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java index 9de59301abb..6457b8071e9 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java @@ -29,7 +29,6 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.SparseFixedBitSet; import org.apache.lucene.util.hnsw.BoundsChecker; import org.apache.lucene.util.hnsw.HnswGraph; -import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.NeighborQueue; /** @@ -43,17 +42,17 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph { // Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to // HnswBuilder, and the // node values are the ordinals of those vectors. - private final List graph; + private final List graph; // KnnGraphValues iterator members private int upto; - private NeighborArray cur; + private Lucene90NeighborArray cur; Lucene90OnHeapHnswGraph(int maxConn) { graph = new ArrayList<>(); // Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be // about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM - graph.add(new NeighborArray(Math.max(32, maxConn / 4))); + graph.add(new Lucene90NeighborArray(Math.max(32, maxConn / 4))); this.maxConn = maxConn; } @@ -162,7 +161,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph { * * @param node the node whose neighbors are returned */ - public NeighborArray getNeighbors(int node) { + public Lucene90NeighborArray getNeighbors(int node) { return graph.get(node); } @@ -172,7 +171,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph { } int addNode() { - graph.add(new NeighborArray(maxConn + 1)); + graph.add(new Lucene90NeighborArray(maxConn + 1)); return graph.size() - 1; } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java index a71f5efb14f..44e46ab9b16 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java @@ -35,7 +35,6 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.hnsw.NeighborArray; /** * Writes vector values and knn graphs to index segments. @@ -247,7 +246,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter { // write graph offsets[ord] = graphData.getFilePointer() - graphDataOffset; - NeighborArray neighbors = graph.getNeighbors(ord); + Lucene90NeighborArray neighbors = graph.getNeighbors(ord); int size = neighbors.size(); // Destructively modify; it's ok we are discarding it after this diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index dcd1c25a77f..63fe10fba4a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -95,14 +95,15 @@ public final class HnswGraphBuilder { this.ml = 1 / Math.log(1.0 * maxConn); this.random = new SplittableRandom(seed); int levelOfFirstNode = getRandomGraphLevel(ml, random); - this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode); + this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode, similarityFunction.reversed); this.graphSearcher = new HnswGraphSearcher( similarityFunction, new NeighborQueue(beamWidth, similarityFunction.reversed == false), new FixedBitSet(vectorValues.size())); bound = BoundsChecker.create(similarityFunction.reversed); - scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1)); + // in scratch we store candidates in reverse order: worse candidates are first + scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1), similarityFunction.reversed); } /** @@ -176,11 +177,6 @@ public final class HnswGraphBuilder { return now; } - /* TODO: we are not maintaining nodes in strict score order; the forward links - * are added in sorted order, but the reverse implicit ones are not. Diversity heuristic should - * work better if we keep the neighbor arrays sorted. Possibly we should switch back to a heap? - * But first we should just see if sorting makes a significant difference. - */ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates) throws IOException { /* For each of the beamWidth nearest candidates (going from best to worst), select it only if it @@ -190,7 +186,7 @@ public final class HnswGraphBuilder { NeighborArray neighbors = hnsw.getNeighbors(level, node); assert neighbors.size() == 0; // new node popToScratch(candidates); - selectDiverse(neighbors, scratch); + selectAndLinkDiverse(neighbors, scratch); // Link the selected nodes to the new node, and the new node to the selected nodes (again // applying diversity heuristic) @@ -198,14 +194,16 @@ public final class HnswGraphBuilder { for (int i = 0; i < size; i++) { int nbr = neighbors.node[i]; NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr); - nbrNbr.add(node, neighbors.score[i]); + nbrNbr.insertSorted(node, neighbors.score[i]); if (nbrNbr.size() > maxConn) { - diversityUpdate(nbrNbr); + int indexToRemove = findWorstNonDiverse(nbrNbr); + nbrNbr.removeIndex(indexToRemove); } } } - private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException { + private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates) + throws IOException { // Select the best maxConn neighbors of the new node, applying the diversity heuristic for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) { // compare each neighbor (in distance order) against the closer neighbors selected so far, @@ -256,44 +254,26 @@ public final class HnswGraphBuilder { return true; } - private void diversityUpdate(NeighborArray neighbors) throws IOException { - assert neighbors.size() == maxConn + 1; - int replacePoint = findNonDiverse(neighbors); - if (replacePoint == -1) { - // none found; check score against worst existing neighbor - bound.set(neighbors.score[0]); - if (bound.check(neighbors.score[maxConn])) { - // drop the new neighbor; it is not competitive and there were no diversity failures - neighbors.removeLast(); - return; - } else { - replacePoint = 0; - } - } - neighbors.node[replacePoint] = neighbors.node[maxConn]; - neighbors.score[replacePoint] = neighbors.score[maxConn]; - neighbors.removeLast(); - } - - // scan neighbors looking for diversity violations - private int findNonDiverse(NeighborArray neighbors) throws IOException { - for (int i = neighbors.size() - 1; i >= 0; i--) { - // check each neighbor against its better-scoring neighbors. If it fails diversity check with - // them, drop it - int nbrNode = neighbors.node[i]; + /** + * Find first non-diverse neighbour among the list of neighbors starting from the most distant + * neighbours + */ + private int findWorstNonDiverse(NeighborArray neighbors) throws IOException { + for (int i = neighbors.size() - 1; i > 0; i--) { + int cNode = neighbors.node[i]; + float[] cVector = vectorValues.vectorValue(cNode); bound.set(neighbors.score[i]); - float[] nbrVector = vectorValues.vectorValue(nbrNode); - for (int j = maxConn; j > i; j--) { + // check the candidate against its better-scoring neighbors + for (int j = i - 1; j >= 0; j--) { float diversityCheck = - similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j])); + similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j])); + // node i is too similar to node j given its score relative to the base node if (bound.check(diversityCheck) == false) { - // node j is too similar to node i given its score relative to the base node - // replace it with the new node, which is at [maxConn] return i; } } } - return -1; + return neighbors.size() - 1; } private static int getRandomGraphLevel(double ml, SplittableRandom random) { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index 40125750309..78224ed2358 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -17,36 +17,67 @@ package org.apache.lucene.util.hnsw; +import java.util.Arrays; import org.apache.lucene.util.ArrayUtil; /** * NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair - * of growable arrays. + * of growable arrays. Nodes are arranged in the sorted order of their scores in descending order + * (if scoresDescOrder is true), or in the ascending order of their scores (if scoresDescOrder is + * false) * * @lucene.internal */ public class NeighborArray { - + private final boolean scoresDescOrder; private int size; float[] score; int[] node; - public NeighborArray(int maxSize) { + public NeighborArray(int maxSize, boolean descOrder) { node = new int[maxSize]; score = new float[maxSize]; + this.scoresDescOrder = descOrder; } + /** + * Add a new node to the NeighborArray. The new node must be worse than all previously stored + * nodes. + */ public void add(int newNode, float newScore) { if (size == node.length - 1) { node = ArrayUtil.grow(node, (size + 1) * 3 / 2); score = ArrayUtil.growExact(score, node.length); } + if (size > 0) { + float previousScore = score[size - 1]; + assert ((scoresDescOrder && (previousScore >= newScore)) + || (scoresDescOrder == false && (previousScore <= newScore))) + : "Nodes are added in the incorrect order!"; + } node[size] = newNode; score[size] = newScore; ++size; } + /** Add a new node to the NeighborArray into a correct sort position according to its score. */ + public void insertSorted(int newNode, float newScore) { + if (size == node.length - 1) { + node = ArrayUtil.grow(node, (size + 1) * 3 / 2); + score = ArrayUtil.growExact(score, node.length); + } + int insertionPoint = + scoresDescOrder + ? descSortFindRightMostInsertionPoint(newScore) + : ascSortFindRightMostInsertionPoint(newScore); + System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint); + System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint); + node[insertionPoint] = newNode; + score[insertionPoint] = newScore; + ++size; + } + public int size() { return size; } @@ -72,8 +103,39 @@ public class NeighborArray { size--; } + public void removeIndex(int idx) { + System.arraycopy(node, idx + 1, node, idx, size - idx); + System.arraycopy(score, idx + 1, score, idx, size - idx); + size--; + } + @Override public String toString() { return "NeighborArray[" + size + "]"; } + + private int ascSortFindRightMostInsertionPoint(float newScore) { + int insertionPoint = Arrays.binarySearch(score, 0, size, newScore); + if (insertionPoint >= 0) { + // find the right most position with the same score + while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) { + insertionPoint++; + } + insertionPoint++; + } else { + insertionPoint = -insertionPoint - 1; + } + return insertionPoint; + } + + private int descSortFindRightMostInsertionPoint(float newScore) { + int start = 0; + int end = size - 1; + while (start <= end) { + int mid = (start + end) / 2; + if (score[mid] < newScore) end = mid - 1; + else start = mid + 1; + } + return start; + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java index 2d18cbb99ca..cb58c608f61 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java @@ -29,7 +29,7 @@ import org.apache.lucene.util.NumericUtils; */ public class NeighborQueue { - private static enum Order { + private enum Order { NATURAL { @Override long apply(long v) { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index 09f8afa7aa7..08cecd1f8ac 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -31,6 +31,7 @@ import org.apache.lucene.util.ArrayUtil; public final class OnHeapHnswGraph extends HnswGraph { private final int maxConn; + private final boolean similarityReversed; private int numLevels; // the current number of levels in the graph private int entryNode; // the current graph entry node on the top level @@ -49,8 +50,9 @@ public final class OnHeapHnswGraph extends HnswGraph { private int upto; private NeighborArray cur; - OnHeapHnswGraph(int maxConn, int levelOfFirstNode) { + OnHeapHnswGraph(int maxConn, int levelOfFirstNode, boolean similarityReversed) { this.maxConn = maxConn; + this.similarityReversed = similarityReversed; this.numLevels = levelOfFirstNode + 1; this.graph = new ArrayList<>(numLevels); this.entryNode = 0; @@ -59,7 +61,7 @@ public final class OnHeapHnswGraph extends HnswGraph { // Typically with diversity criteria we see nodes not fully occupied; // average fanout seems to be about 1/2 maxConn. // There is some indexing time penalty for under-allocating, but saves RAM - graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4))); + graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4), similarityReversed == false)); } this.nodesByLevel = new ArrayList<>(numLevels); @@ -120,7 +122,7 @@ public final class OnHeapHnswGraph extends HnswGraph { } } - graph.get(level).add(new NeighborArray(maxConn + 1)); + graph.get(level).add(new NeighborArray(maxConn + 1, similarityReversed == false)); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java index 822ef78197c..57389d098c9 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java @@ -276,7 +276,8 @@ public class KnnGraphTester { for (int i = 0; i < hnsw.size(); i++) { NeighborArray neighbors = hnsw.getNeighbors(0, i); System.out.printf(Locale.ROOT, "%5d", i); - NeighborArray sorted = new NeighborArray(neighbors.size()); + NeighborArray sorted = + new NeighborArray(neighbors.size(), similarityFunction.reversed == false); for (int j = 0; j < neighbors.size(); j++) { int node = neighbors.node[j]; float score = neighbors.score[j]; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java new file mode 100644 index 00000000000..b8ae24f6200 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestNeighborArray extends LuceneTestCase { + + public void testScoresDescOrder() { + NeighborArray neighbors = new NeighborArray(10, true); + neighbors.add(0, 1); + neighbors.add(1, 0.8f); + + AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.9f)); + assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); + + neighbors.insertSorted(3, 0.9f); + assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); + asserNodesEqual(new int[] {0, 3, 1}, neighbors); + + neighbors.insertSorted(4, 1f); + assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors); + asserNodesEqual(new int[] {0, 4, 3, 1}, neighbors); + + neighbors.insertSorted(5, 1.1f); + assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors); + asserNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors); + + neighbors.insertSorted(6, 0.8f); + assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors); + asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors); + + neighbors.insertSorted(7, 0.8f); + assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); + asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors); + + neighbors.removeIndex(2); + assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); + asserNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors); + + neighbors.removeIndex(0); + assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); + asserNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors); + + neighbors.removeIndex(4); + assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors); + asserNodesEqual(new int[] {0, 3, 1, 6}, neighbors); + + neighbors.removeLast(); + assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); + asserNodesEqual(new int[] {0, 3, 1}, neighbors); + + neighbors.insertSorted(8, 0.9f); + assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors); + asserNodesEqual(new int[] {0, 3, 8, 1}, neighbors); + } + + public void testScoresAscOrder() { + NeighborArray neighbors = new NeighborArray(10, false); + neighbors.add(0, 0.1f); + neighbors.add(1, 0.3f); + + AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.15f)); + assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); + + neighbors.insertSorted(3, 0.3f); + assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors); + asserNodesEqual(new int[] {0, 1, 3}, neighbors); + + neighbors.insertSorted(4, 0.2f); + assertScoresEqual(new float[] {0.1f, 0.2f, 0.3f, 0.3f}, neighbors); + asserNodesEqual(new int[] {0, 4, 1, 3}, neighbors); + + neighbors.insertSorted(5, 0.05f); + assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.3f, 0.3f}, neighbors); + asserNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors); + + neighbors.insertSorted(6, 0.2f); + assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); + asserNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors); + + neighbors.insertSorted(7, 0.2f); + assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); + asserNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors); + + neighbors.removeIndex(2); + assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); + asserNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors); + + neighbors.removeIndex(0); + assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); + asserNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors); + + neighbors.removeIndex(4); + assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f}, neighbors); + asserNodesEqual(new int[] {0, 6, 7, 1}, neighbors); + + neighbors.removeLast(); + assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f}, neighbors); + asserNodesEqual(new int[] {0, 6, 7}, neighbors); + + neighbors.insertSorted(8, 0.01f); + assertScoresEqual(new float[] {0.01f, 0.1f, 0.2f, 0.2f}, neighbors); + asserNodesEqual(new int[] {8, 0, 6, 7}, neighbors); + } + + private void assertScoresEqual(float[] scores, NeighborArray neighbors) { + for (int i = 0; i < scores.length; i++) { + assertEquals(scores[i], neighbors.score[i], 0.01f); + } + } + + private void asserNodesEqual(int[] nodes, NeighborArray neighbors) { + for (int i = 0; i < nodes.length; i++) { + assertEquals(nodes[i], neighbors.node[i]); + } + } +}