LUCENE-10593: VectorSimilarityFunction reverse removal (#926)

* Vector Similarity Function reverse property removed

* NeighborQueue tie-breaking fixed (node id + node score encoding)

* NeighborQueue readability refactor

* BoundChecker removal (now it's only in backward-codecs)
This commit is contained in:
Alessandro Benedetti 2022-06-28 15:33:11 +02:00 committed by GitHub
parent 3e74ebbc0d
commit 8cf694fed2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 245 additions and 163 deletions

View File

@ -107,6 +107,8 @@ Optimizations
* LUCENE-10606: For KnnVectorQuery, optimize case where filter is backed by BitSetIterator (Kaival Parikh)
* LUCENE-10593: Vector similarity function and NeighborQueue reverse removal. (Alessandro Benedetti)
Bug Fixes
---------------------

View File

@ -14,17 +14,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
package org.apache.lucene.backward_codecs.lucene90;
/**
* A helper class for an hnsw graph that serves as a comparator of the currently set bound value
* with a new value.
*/
public abstract class BoundsChecker {
public abstract class Lucene90BoundsChecker {
float bound;
/** Default Constructor */
public Lucene90BoundsChecker() {}
/** Update the bound if sample is better */
public abstract void update(float sample);
@ -33,10 +35,21 @@ public abstract class BoundsChecker {
bound = sample;
}
/** @return whether the sample exceeds (is worse than) the bound */
/**
* Check the sample
*
* @param sample a score
* @return whether the sample exceeds (is worse than) the bound
*/
public abstract boolean check(float sample);
public static BoundsChecker create(boolean reversed) {
/**
* Create a min or max bound checker
*
* @param reversed true for the min and false for the max
* @return the bound checker
*/
public static Lucene90BoundsChecker create(boolean reversed) {
if (reversed) {
return new Min();
} else {
@ -48,7 +61,7 @@ public abstract class BoundsChecker {
* A helper class for an hnsw graph that serves as a comparator of the currently set maximum value
* with a new value.
*/
public static class Max extends BoundsChecker {
public static class Max extends Lucene90BoundsChecker {
Max() {
bound = Float.NEGATIVE_INFINITY;
}
@ -70,7 +83,7 @@ public abstract class BoundsChecker {
* A helper class for an hnsw graph that serves as a comparator of the currently set minimum value
* with a new value.
*/
public static class Min extends BoundsChecker {
public static class Min extends Lucene90BoundsChecker {
Min() {
bound = Float.POSITIVE_INFINITY;

View File

@ -25,7 +25,6 @@ import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.BoundsChecker;
import org.apache.lucene.util.hnsw.NeighborQueue;
/**
@ -51,7 +50,7 @@ public final class Lucene90HnswGraphBuilder {
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random;
private final BoundsChecker bound;
private final Lucene90BoundsChecker bound;
final Lucene90OnHeapHnswGraph hnsw;
private InfoStream infoStream = InfoStream.getDefault();
@ -91,7 +90,7 @@ public final class Lucene90HnswGraphBuilder {
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.hnsw = new Lucene90OnHeapHnswGraph(maxConn);
bound = BoundsChecker.create(similarityFunction.reversed);
bound = Lucene90BoundsChecker.create(false);
random = new SplittableRandom(seed);
scratch = new Lucene90NeighborArray(Math.max(beamWidth, maxConn + 1));
}
@ -234,9 +233,9 @@ public final class Lucene90HnswGraphBuilder {
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {
float diversityCheck =
float neighborSimilarity =
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node()[i]));
if (bound.check(diversityCheck) == false) {
if (bound.check(neighborSimilarity) == false) {
return false;
}
}
@ -267,13 +266,14 @@ public final class Lucene90HnswGraphBuilder {
for (int i = neighbors.size() - 1; i >= 0; i--) {
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
// them, drop it
int nbrNode = neighbors.node()[i];
int neighborId = neighbors.node()[i];
bound.set(neighbors.score()[i]);
float[] nbrVector = vectorValues.vectorValue(nbrNode);
float[] neighborVector = vectorValues.vectorValue(neighborId);
for (int j = maxConn; j > i; j--) {
float diversityCheck =
similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node()[j]));
if (bound.check(diversityCheck) == false) {
float neighborSimilarity =
similarityFunction.compare(
neighborVector, buildVectors.vectorValue(neighbors.node()[j]));
if (bound.check(neighborSimilarity) == false) {
// node j is too similar to node i given its score relative to the base node
// replace it with the new node, which is at [maxConn]
return i;

View File

@ -266,9 +266,9 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
float minSimilarity = results.topScore();
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], minSimilarity);
}
TotalHits.Relation relation =
results.incomplete()

View File

@ -27,7 +27,6 @@ import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.lucene.util.hnsw.BoundsChecker;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
@ -85,9 +84,9 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
int size = graphValues.size();
// MIN heap, holding the top results
NeighborQueue results = new NeighborQueue(numSeed, similarityFunction.reversed);
NeighborQueue results = new NeighborQueue(numSeed, false);
// MAX heap, from which to pull the candidate nodes
NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed);
NeighborQueue candidates = new NeighborQueue(numSeed, true);
int numVisited = 0;
// set of ordinals that have been visited by search on this layer, used to avoid backtracking
@ -114,13 +113,13 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
// Set the bound to the worst current result and below reject any newly-generated candidates
// failing
// to exceed this bound
BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
Lucene90BoundsChecker bound = Lucene90BoundsChecker.create(false);
bound.set(results.topScore());
while (candidates.size() > 0 && results.incomplete() == false) {
// get the best candidate (closest or best scoring)
float topCandidateScore = candidates.topScore();
float topCandidateSimilarity = candidates.topScore();
if (results.size() >= topK) {
if (bound.check(topCandidateScore)) {
if (bound.check(topCandidateSimilarity)) {
break;
}
}
@ -138,11 +137,11 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
break;
}
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
if (results.size() < numSeed || bound.check(score) == false) {
candidates.add(friendOrd, score);
float friendSimilarity = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
if (results.size() < numSeed || bound.check(friendSimilarity) == false) {
candidates.add(friendOrd, friendSimilarity);
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
results.insertWithOverflow(friendOrd, score);
results.insertWithOverflow(friendOrd, friendSimilarity);
bound.set(results.topScore());
}
}

View File

@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene91;
/**
* A helper class for an hnsw graph that serves as a comparator of the currently set bound value
* with a new value.
*/
public abstract class Lucene91BoundsChecker {
float bound;
/** Default Constructor */
public Lucene91BoundsChecker() {}
/** Update the bound if sample is better */
public abstract void update(float sample);
/** Update the bound unconditionally */
public void set(float sample) {
bound = sample;
}
/**
* Check the sample
*
* @param sample a score
* @return whether the sample exceeds (is worse than) the bound
*/
public abstract boolean check(float sample);
/**
* Create a min or max bound checker
*
* @param reversed true for the min and false for the max
* @return the bound checker
*/
public static Lucene91BoundsChecker create(boolean reversed) {
if (reversed) {
return new Min();
} else {
return new Max();
}
}
/**
* A helper class for an hnsw graph that serves as a comparator of the currently set maximum value
* with a new value.
*/
public static class Max extends Lucene91BoundsChecker {
Max() {
bound = Float.NEGATIVE_INFINITY;
}
@Override
public void update(float sample) {
if (sample > bound) {
bound = sample;
}
}
@Override
public boolean check(float sample) {
return sample < bound;
}
}
/**
* A helper class for an hnsw graph that serves as a comparator of the currently set minimum value
* with a new value.
*/
public static class Min extends Lucene91BoundsChecker {
Min() {
bound = Float.POSITIVE_INFINITY;
}
@Override
public void update(float sample) {
if (sample < bound) {
bound = sample;
}
}
@Override
public boolean check(float sample) {
return sample > bound;
}
}
}

View File

@ -28,7 +28,6 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.BoundsChecker;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
@ -55,7 +54,7 @@ public final class Lucene91HnswGraphBuilder {
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random;
private final BoundsChecker bound;
private final Lucene91BoundsChecker bound;
private final HnswGraphSearcher graphSearcher;
final Lucene91OnHeapHnswGraph hnsw;
@ -104,9 +103,9 @@ public final class Lucene91HnswGraphBuilder {
this.graphSearcher =
new HnswGraphSearcher(
similarityFunction,
new NeighborQueue(beamWidth, similarityFunction.reversed == false),
new NeighborQueue(beamWidth, true),
new FixedBitSet(vectorValues.size()));
bound = BoundsChecker.create(similarityFunction.reversed);
bound = Lucene91BoundsChecker.create(false);
scratch = new Lucene91NeighborArray(Math.max(beamWidth, maxConn + 1));
}
@ -231,8 +230,8 @@ public final class Lucene91HnswGraphBuilder {
// extract all the Neighbors from the queue into an array; these will now be
// sorted from worst to best
for (int i = 0; i < candidateCount; i++) {
float score = candidates.topScore();
scratch.add(candidates.pop(), score);
float similarity = candidates.topScore();
scratch.add(candidates.pop(), similarity);
}
}
@ -253,9 +252,9 @@ public final class Lucene91HnswGraphBuilder {
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {
float diversityCheck =
float neighborSimilarity =
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
if (bound.check(diversityCheck) == false) {
if (bound.check(neighborSimilarity) == false) {
return false;
}
}
@ -286,13 +285,13 @@ public final class Lucene91HnswGraphBuilder {
for (int i = neighbors.size() - 1; i >= 0; i--) {
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
// them, drop it
int nbrNode = neighbors.node[i];
int neighborId = neighbors.node[i];
bound.set(neighbors.score[i]);
float[] nbrVector = vectorValues.vectorValue(nbrNode);
float[] neighborVector = vectorValues.vectorValue(neighborId);
for (int j = maxConn; j > i; j--) {
float diversityCheck =
similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
if (bound.check(diversityCheck) == false) {
float neighborSimilarity =
similarityFunction.compare(neighborVector, buildVectors.vectorValue(neighbors.node[j]));
if (bound.check(neighborSimilarity) == false) {
// node j is too similar to node i given its score relative to the base node
// replace it with the new node, which is at [maxConn]
return i;

View File

@ -253,9 +253,9 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
float minSimilarity = results.topScore();
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), score);
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), minSimilarity);
}
TotalHits.Relation relation =

View File

@ -246,9 +246,9 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
float minSimilarity = results.topScore();
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), minSimilarity);
}
TotalHits.Relation relation =

View File

@ -170,7 +170,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
float[] vector = values.vectorValue();
float score = vectorSimilarity.convertToScore(vectorSimilarity.compare(vector, target));
float score = vectorSimilarity.compare(vector, target);
topK.insertWithOverflow(new ScoreDoc(doc, score));
numVisited++;
}

View File

@ -246,7 +246,7 @@ public final class Lucene93HnswVectorsReader extends KnnVectorsReader {
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
float score = results.topScore();
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
}

View File

@ -26,15 +26,10 @@ import static org.apache.lucene.util.VectorUtil.*;
public enum VectorSimilarityFunction {
/** Euclidean distance */
EUCLIDEAN(true) {
EUCLIDEAN {
@Override
public float compare(float[] v1, float[] v2) {
return squareDistance(v1, v2);
}
@Override
public float convertToScore(float similarity) {
return 1 / (1 + similarity);
return 1 / (1 + squareDistance(v1, v2));
}
},
@ -47,12 +42,7 @@ public enum VectorSimilarityFunction {
DOT_PRODUCT {
@Override
public float compare(float[] v1, float[] v2) {
return dotProduct(v1, v2);
}
@Override
public float convertToScore(float similarity) {
return (1 + similarity) / 2;
return (1 + dotProduct(v1, v2)) / 2;
}
},
@ -60,50 +50,22 @@ public enum VectorSimilarityFunction {
* Cosine similarity. NOTE: the preferred way to perform cosine similarity is to normalize all
* vectors to unit length, and instead use {@link VectorSimilarityFunction#DOT_PRODUCT}. You
* should only use this function if you need to preserve the original vectors and cannot normalize
* them in advance.
* them in advance. The similarity score is normalised to assure it is positive.
*/
COSINE {
@Override
public float compare(float[] v1, float[] v2) {
return cosine(v1, v2);
}
@Override
public float convertToScore(float similarity) {
return (1 + similarity) / 2;
return (1 + cosine(v1, v2)) / 2;
}
};
/**
* If true, the scores associated with vector comparisons are nonnegative and in reverse order;
* that is, lower scores represent more similar vectors. Otherwise, if false, higher scores
* represent more similar vectors, and scores may be negative or positive.
*/
public final boolean reversed;
VectorSimilarityFunction(boolean reversed) {
this.reversed = reversed;
}
VectorSimilarityFunction() {
reversed = false;
}
/**
* Calculates a similarity score between the two vectors with a specified function.
* Calculates a similarity score between the two vectors with a specified function. Higher
* similarity scores correspond to closer vectors.
*
* @param v1 a vector
* @param v2 another vector, of the same dimension
* @return the value of the similarity function applied to the two vectors
*/
public abstract float compare(float[] v1, float[] v2);
/**
* Converts similarity scores used (may be negative, reversed, etc) into document scores, which
* must be positive, with higher scores representing better matches.
*
* @param similarity the raw internal score as returned by {@link #compare(float[], float[])}.
* @return normalizedSimilarity
*/
public abstract float convertToScore(float similarity);
}

View File

@ -197,7 +197,7 @@ public class KnnVectorQuery extends Query {
assert vectorDoc == doc;
float[] vector = vectorValues.vectorValue();
float score = similarityFunction.convertToScore(similarityFunction.compare(vector, target));
float score = similarityFunction.compare(vector, target);
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;

View File

@ -51,7 +51,6 @@ public final class HnswGraphBuilder {
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random;
private final BoundsChecker bound;
private final HnswGraphSearcher graphSearcher;
final OnHeapHnswGraph hnsw;
@ -96,15 +95,14 @@ public final class HnswGraphBuilder {
this.ml = 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode, similarityFunction.reversed);
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
this.graphSearcher =
new HnswGraphSearcher(
similarityFunction,
new NeighborQueue(beamWidth, similarityFunction.reversed == false),
new NeighborQueue(beamWidth, true),
new FixedBitSet(vectorValues.size()));
bound = BoundsChecker.create(similarityFunction.reversed);
// in scratch we store candidates in reverse order: worse candidates are first
scratch = new NeighborArray(Math.max(beamWidth, M + 1), similarityFunction.reversed);
scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
}
/**
@ -225,8 +223,8 @@ public final class HnswGraphBuilder {
// extract all the Neighbors from the queue into an array; these will now be
// sorted from worst to best
for (int i = 0; i < candidateCount; i++) {
float score = candidates.topScore();
scratch.add(candidates.pop(), score);
float maxSimilarity = candidates.topScore();
scratch.add(candidates.pop(), maxSimilarity);
}
}
@ -245,11 +243,10 @@ public final class HnswGraphBuilder {
NeighborArray neighbors,
RandomAccessVectorValues vectorValues)
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {
float diversityCheck =
float neighborSimilarity =
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
if (bound.check(diversityCheck) == false) {
if (neighborSimilarity >= score) {
return false;
}
}
@ -261,16 +258,17 @@ public final class HnswGraphBuilder {
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
float minAcceptedSimilarity;
for (int i = neighbors.size() - 1; i > 0; i--) {
int cNode = neighbors.node[i];
float[] cVector = vectorValues.vectorValue(cNode);
bound.set(neighbors.score[i]);
minAcceptedSimilarity = neighbors.score[i];
// check the candidate against its better-scoring neighbors
for (int j = i - 1; j >= 0; j--) {
float diversityCheck =
float neighborSimilarity =
similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
// node i is too similar to node j given its score relative to the base node
if (bound.check(diversityCheck) == false) {
if (neighborSimilarity >= minAcceptedSimilarity) {
return i;
}
}

View File

@ -80,7 +80,7 @@ public final class HnswGraphSearcher {
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
similarityFunction,
new NeighborQueue(topK, similarityFunction.reversed == false),
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
int[] eps = new int[] {graph.entryNode()};
@ -139,7 +139,7 @@ public final class HnswGraphSearcher {
int visitedLimit)
throws IOException {
int size = graph.size();
NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
NeighborQueue results = new NeighborQueue(topK, false);
clearScratchState();
int numVisited = 0;
@ -160,14 +160,14 @@ public final class HnswGraphSearcher {
// A bound that holds the minimum similarity to the query vector that a candidate vector must
// have to be considered.
BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
if (results.size() >= topK) {
bound.set(results.topScore());
minAcceptedSimilarity = results.topScore();
}
while (candidates.size() > 0 && results.incomplete() == false) {
// get the best candidate (closest or best scoring)
float topCandidateScore = candidates.topScore();
if (bound.check(topCandidateScore)) {
float topCandidateSimilarity = candidates.topScore();
if (topCandidateSimilarity < minAcceptedSimilarity) {
break;
}
@ -184,13 +184,13 @@ public final class HnswGraphSearcher {
results.markIncomplete();
break;
}
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
float friendSimilarity = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
numVisited++;
if (bound.check(score) == false) {
candidates.add(friendOrd, score);
if (friendSimilarity >= minAcceptedSimilarity) {
candidates.add(friendOrd, friendSimilarity);
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
if (results.insertWithOverflow(friendOrd, score) && results.size() >= topK) {
bound.set(results.topScore());
if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) {
minAcceptedSimilarity = results.topScore();
}
}
}

View File

@ -30,13 +30,13 @@ import org.apache.lucene.util.NumericUtils;
public class NeighborQueue {
private enum Order {
NATURAL {
MIN_HEAP {
@Override
long apply(long v) {
return v;
}
},
REVERSED {
MAX_HEAP {
@Override
long apply(long v) {
// This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It
@ -56,9 +56,9 @@ public class NeighborQueue {
// Whether the search stopped early because it reached the visited nodes limit
private boolean incomplete;
public NeighborQueue(int initialSize, boolean reversed) {
public NeighborQueue(int initialSize, boolean maxHeap) {
this.heap = new LongHeap(initialSize);
this.order = reversed ? Order.REVERSED : Order.NATURAL;
this.order = maxHeap ? Order.MAX_HEAP : Order.MIN_HEAP;
}
/** @return the number of elements in the heap */
@ -89,32 +89,66 @@ public class NeighborQueue {
return heap.insertWithOverflow(encode(newNode, newScore));
}
/**
* Encodes the node ID and its similarity score as long, preserving the Lucene tie-breaking rule
* that when two scores are equals, the smaller node ID must win.
*
* <p>The most significant 32 bits represent the float score, encoded as a sortable int.
*
* <p>The less significant 32 bits represent the node ID.
*
* <p>The bits representing the node ID are complemented to guarantee the win for the smaller node
* Id.
*
* <p>The AND with 0xFFFFFFFFL (a long with first 32 bit as 1) is necessary to obtain a long that
* has
*
* <p>The most significant 32 bits to 0
*
* <p>The less significant 32 bits represent the node ID.
*
* @param node the node ID
* @param score the node score
* @return the encoded score, node ID
*/
private long encode(int node, float score) {
return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | node);
return order.apply(
(((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
}
private float decodeScore(long heapValue) {
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
}
private int decodeNodeId(long heapValue) {
return (int) ~(order.apply(heapValue));
}
/** Removes the top element and returns its node id. */
public int pop() {
return (int) order.apply(heap.pop());
return decodeNodeId(heap.pop());
}
public int[] nodes() {
int size = size();
int[] nodes = new int[size];
for (int i = 0; i < size; i++) {
nodes[i] = (int) order.apply(heap.get(i + 1));
nodes[i] = decodeNodeId(heap.get(i + 1));
}
return nodes;
}
/** Returns the top element's node id. */
public int topNode() {
return (int) order.apply(heap.top());
return decodeNodeId(heap.top());
}
/** Returns the top element's node score. */
/**
* Returns the top element's node score. For the min heap this is the minimum score. For the max
* heap this is the maximum score.
*/
public float topScore() {
return NumericUtils.sortableIntToFloat((int) (order.apply(heap.top()) >> 32));
return decodeScore(heap.top());
}
public void clear() {

View File

@ -30,7 +30,6 @@ import org.apache.lucene.util.ArrayUtil;
*/
public final class OnHeapHnswGraph extends HnswGraph {
private final boolean similarityReversed;
private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level
@ -52,8 +51,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
private int upto;
private NeighborArray cur;
OnHeapHnswGraph(int M, int levelOfFirstNode, boolean similarityReversed) {
this.similarityReversed = similarityReversed;
OnHeapHnswGraph(int M, int levelOfFirstNode) {
this.numLevels = levelOfFirstNode + 1;
this.graph = new ArrayList<>(numLevels);
this.entryNode = 0;
@ -63,7 +61,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
this.nsize0 = (M * 2 + 1);
for (int l = 0; l < numLevels; l++) {
graph.add(new ArrayList<>());
graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, similarityReversed == false));
graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, true));
}
this.nodesByLevel = new ArrayList<>(numLevels);
@ -123,9 +121,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
}
}
}
graph
.get(level)
.add(new NeighborArray(level == 0 ? nsize0 : nsize, similarityReversed == false));
graph.get(level).add(new NeighborArray(level == 0 ? nsize0 : nsize, true));
}
@Override

View File

@ -273,8 +273,7 @@ public class KnnGraphTester {
for (int i = 0; i < hnsw.size(); i++) {
NeighborArray neighbors = hnsw.getNeighbors(0, i);
System.out.printf(Locale.ROOT, "%5d", i);
NeighborArray sorted =
new NeighborArray(neighbors.size(), similarityFunction.reversed == false);
NeighborArray sorted = new NeighborArray(neighbors.size(), true);
for (int j = 0; j < neighbors.size(); j++) {
int node = neighbors.node[j];
float score = neighbors.score[j];
@ -555,7 +554,7 @@ public class KnnGraphTester {
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
offset += blockSize;
NeighborQueue queue = new NeighborQueue(topK, similarityFunction.reversed);
NeighborQueue queue = new NeighborQueue(topK, false);
for (; j < numDocs && vectors.hasRemaining(); j++) {
vectors.get(vector);
float d = similarityFunction.compare(query, vector);

View File

@ -309,30 +309,6 @@ public class TestHnswGraph extends LuceneTestCase {
assertTrue(nn.visitedCount() <= visitedLimit);
}
public void testBoundsCheckerMax() {
BoundsChecker max = BoundsChecker.create(false);
float f = random().nextFloat() - 0.5f;
// any float > -MAX_VALUE is in bounds
assertFalse(max.check(f));
// f is now the bound (minus some delta)
max.update(f);
assertFalse(max.check(f)); // f is not out of bounds
assertFalse(max.check(f + 1)); // anything greater than f is in bounds
assertTrue(max.check(f - 1e-5f)); // delta is zero initially
}
public void testBoundsCheckerMin() {
BoundsChecker min = BoundsChecker.create(true);
float f = random().nextFloat() - 0.5f;
// any float < MAX_VALUE is in bounds
assertFalse(min.check(f));
// f is now the bound (minus some delta)
min.update(f);
assertFalse(min.check(f)); // f is not out of bounds
assertFalse(min.check(f - 1)); // anything less than f is in bounds
assertTrue(min.check(f + 1e-5f)); // delta is zero initially
}
public void testHnswGraphBuilderInvalid() {
expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, null, 0, 0, 0));
expectThrows(
@ -441,7 +417,7 @@ public class TestHnswGraph extends LuceneTestCase {
while (actual.size() > topK) {
actual.pop();
}
NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
NeighborQueue expected = new NeighborQueue(topK, false);
for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));