mirror of https://github.com/apache/lucene.git
LUCENE-10593: VectorSimilarityFunction reverse removal (#926)
* Vector Similarity Function reverse property removed * NeighborQueue tie-breaking fixed (node id + node score encoding) * NeighborQueue readability refactor * BoundChecker removal (now it's only in backward-codecs)
This commit is contained in:
parent
3e74ebbc0d
commit
8cf694fed2
|
@ -107,6 +107,8 @@ Optimizations
|
|||
|
||||
* LUCENE-10606: For KnnVectorQuery, optimize case where filter is backed by BitSetIterator (Kaival Parikh)
|
||||
|
||||
* LUCENE-10593: Vector similarity function and NeighborQueue reverse removal. (Alessandro Benedetti)
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -14,17 +14,19 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
package org.apache.lucene.backward_codecs.lucene90;
|
||||
|
||||
/**
|
||||
* A helper class for an hnsw graph that serves as a comparator of the currently set bound value
|
||||
* with a new value.
|
||||
*/
|
||||
public abstract class BoundsChecker {
|
||||
public abstract class Lucene90BoundsChecker {
|
||||
|
||||
float bound;
|
||||
|
||||
/** Default Constructor */
|
||||
public Lucene90BoundsChecker() {}
|
||||
|
||||
/** Update the bound if sample is better */
|
||||
public abstract void update(float sample);
|
||||
|
||||
|
@ -33,10 +35,21 @@ public abstract class BoundsChecker {
|
|||
bound = sample;
|
||||
}
|
||||
|
||||
/** @return whether the sample exceeds (is worse than) the bound */
|
||||
/**
|
||||
* Check the sample
|
||||
*
|
||||
* @param sample a score
|
||||
* @return whether the sample exceeds (is worse than) the bound
|
||||
*/
|
||||
public abstract boolean check(float sample);
|
||||
|
||||
public static BoundsChecker create(boolean reversed) {
|
||||
/**
|
||||
* Create a min or max bound checker
|
||||
*
|
||||
* @param reversed true for the min and false for the max
|
||||
* @return the bound checker
|
||||
*/
|
||||
public static Lucene90BoundsChecker create(boolean reversed) {
|
||||
if (reversed) {
|
||||
return new Min();
|
||||
} else {
|
||||
|
@ -48,7 +61,7 @@ public abstract class BoundsChecker {
|
|||
* A helper class for an hnsw graph that serves as a comparator of the currently set maximum value
|
||||
* with a new value.
|
||||
*/
|
||||
public static class Max extends BoundsChecker {
|
||||
public static class Max extends Lucene90BoundsChecker {
|
||||
Max() {
|
||||
bound = Float.NEGATIVE_INFINITY;
|
||||
}
|
||||
|
@ -70,7 +83,7 @@ public abstract class BoundsChecker {
|
|||
* A helper class for an hnsw graph that serves as a comparator of the currently set minimum value
|
||||
* with a new value.
|
||||
*/
|
||||
public static class Min extends BoundsChecker {
|
||||
public static class Min extends Lucene90BoundsChecker {
|
||||
|
||||
Min() {
|
||||
bound = Float.POSITIVE_INFINITY;
|
|
@ -25,7 +25,6 @@ import org.apache.lucene.index.RandomAccessVectorValues;
|
|||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.hnsw.BoundsChecker;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
||||
/**
|
||||
|
@ -51,7 +50,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final RandomAccessVectorValues vectorValues;
|
||||
private final SplittableRandom random;
|
||||
private final BoundsChecker bound;
|
||||
private final Lucene90BoundsChecker bound;
|
||||
final Lucene90OnHeapHnswGraph hnsw;
|
||||
|
||||
private InfoStream infoStream = InfoStream.getDefault();
|
||||
|
@ -91,7 +90,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
this.maxConn = maxConn;
|
||||
this.beamWidth = beamWidth;
|
||||
this.hnsw = new Lucene90OnHeapHnswGraph(maxConn);
|
||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||
bound = Lucene90BoundsChecker.create(false);
|
||||
random = new SplittableRandom(seed);
|
||||
scratch = new Lucene90NeighborArray(Math.max(beamWidth, maxConn + 1));
|
||||
}
|
||||
|
@ -234,9 +233,9 @@ public final class Lucene90HnswGraphBuilder {
|
|||
throws IOException {
|
||||
bound.set(score);
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
float diversityCheck =
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node()[i]));
|
||||
if (bound.check(diversityCheck) == false) {
|
||||
if (bound.check(neighborSimilarity) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -267,13 +266,14 @@ public final class Lucene90HnswGraphBuilder {
|
|||
for (int i = neighbors.size() - 1; i >= 0; i--) {
|
||||
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
|
||||
// them, drop it
|
||||
int nbrNode = neighbors.node()[i];
|
||||
int neighborId = neighbors.node()[i];
|
||||
bound.set(neighbors.score()[i]);
|
||||
float[] nbrVector = vectorValues.vectorValue(nbrNode);
|
||||
float[] neighborVector = vectorValues.vectorValue(neighborId);
|
||||
for (int j = maxConn; j > i; j--) {
|
||||
float diversityCheck =
|
||||
similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node()[j]));
|
||||
if (bound.check(diversityCheck) == false) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(
|
||||
neighborVector, buildVectors.vectorValue(neighbors.node()[j]));
|
||||
if (bound.check(neighborSimilarity) == false) {
|
||||
// node j is too similar to node i given its score relative to the base node
|
||||
// replace it with the new node, which is at [maxConn]
|
||||
return i;
|
||||
|
|
|
@ -266,9 +266,9 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
|
||||
while (results.size() > 0) {
|
||||
int node = results.topNode();
|
||||
float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
|
||||
float minSimilarity = results.topScore();
|
||||
results.pop();
|
||||
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
|
||||
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], minSimilarity);
|
||||
}
|
||||
TotalHits.Relation relation =
|
||||
results.incomplete()
|
||||
|
|
|
@ -27,7 +27,6 @@ import org.apache.lucene.index.RandomAccessVectorValues;
|
|||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.SparseFixedBitSet;
|
||||
import org.apache.lucene.util.hnsw.BoundsChecker;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
||||
|
@ -85,9 +84,9 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
|||
int size = graphValues.size();
|
||||
|
||||
// MIN heap, holding the top results
|
||||
NeighborQueue results = new NeighborQueue(numSeed, similarityFunction.reversed);
|
||||
NeighborQueue results = new NeighborQueue(numSeed, false);
|
||||
// MAX heap, from which to pull the candidate nodes
|
||||
NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed);
|
||||
NeighborQueue candidates = new NeighborQueue(numSeed, true);
|
||||
|
||||
int numVisited = 0;
|
||||
// set of ordinals that have been visited by search on this layer, used to avoid backtracking
|
||||
|
@ -114,13 +113,13 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
|||
// Set the bound to the worst current result and below reject any newly-generated candidates
|
||||
// failing
|
||||
// to exceed this bound
|
||||
BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
|
||||
Lucene90BoundsChecker bound = Lucene90BoundsChecker.create(false);
|
||||
bound.set(results.topScore());
|
||||
while (candidates.size() > 0 && results.incomplete() == false) {
|
||||
// get the best candidate (closest or best scoring)
|
||||
float topCandidateScore = candidates.topScore();
|
||||
float topCandidateSimilarity = candidates.topScore();
|
||||
if (results.size() >= topK) {
|
||||
if (bound.check(topCandidateScore)) {
|
||||
if (bound.check(topCandidateSimilarity)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -138,11 +137,11 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
|||
break;
|
||||
}
|
||||
|
||||
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
|
||||
if (results.size() < numSeed || bound.check(score) == false) {
|
||||
candidates.add(friendOrd, score);
|
||||
float friendSimilarity = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
|
||||
if (results.size() < numSeed || bound.check(friendSimilarity) == false) {
|
||||
candidates.add(friendOrd, friendSimilarity);
|
||||
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
|
||||
results.insertWithOverflow(friendOrd, score);
|
||||
results.insertWithOverflow(friendOrd, friendSimilarity);
|
||||
bound.set(results.topScore());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.backward_codecs.lucene91;
|
||||
|
||||
/**
|
||||
* A helper class for an hnsw graph that serves as a comparator of the currently set bound value
|
||||
* with a new value.
|
||||
*/
|
||||
public abstract class Lucene91BoundsChecker {
|
||||
|
||||
float bound;
|
||||
|
||||
/** Default Constructor */
|
||||
public Lucene91BoundsChecker() {}
|
||||
|
||||
/** Update the bound if sample is better */
|
||||
public abstract void update(float sample);
|
||||
|
||||
/** Update the bound unconditionally */
|
||||
public void set(float sample) {
|
||||
bound = sample;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check the sample
|
||||
*
|
||||
* @param sample a score
|
||||
* @return whether the sample exceeds (is worse than) the bound
|
||||
*/
|
||||
public abstract boolean check(float sample);
|
||||
|
||||
/**
|
||||
* Create a min or max bound checker
|
||||
*
|
||||
* @param reversed true for the min and false for the max
|
||||
* @return the bound checker
|
||||
*/
|
||||
public static Lucene91BoundsChecker create(boolean reversed) {
|
||||
if (reversed) {
|
||||
return new Min();
|
||||
} else {
|
||||
return new Max();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A helper class for an hnsw graph that serves as a comparator of the currently set maximum value
|
||||
* with a new value.
|
||||
*/
|
||||
public static class Max extends Lucene91BoundsChecker {
|
||||
Max() {
|
||||
bound = Float.NEGATIVE_INFINITY;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void update(float sample) {
|
||||
if (sample > bound) {
|
||||
bound = sample;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean check(float sample) {
|
||||
return sample < bound;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A helper class for an hnsw graph that serves as a comparator of the currently set minimum value
|
||||
* with a new value.
|
||||
*/
|
||||
public static class Min extends Lucene91BoundsChecker {
|
||||
|
||||
Min() {
|
||||
bound = Float.POSITIVE_INFINITY;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void update(float sample) {
|
||||
if (sample < bound) {
|
||||
bound = sample;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean check(float sample) {
|
||||
return sample > bound;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -28,7 +28,6 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
|||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.hnsw.BoundsChecker;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
@ -55,7 +54,7 @@ public final class Lucene91HnswGraphBuilder {
|
|||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final RandomAccessVectorValues vectorValues;
|
||||
private final SplittableRandom random;
|
||||
private final BoundsChecker bound;
|
||||
private final Lucene91BoundsChecker bound;
|
||||
private final HnswGraphSearcher graphSearcher;
|
||||
|
||||
final Lucene91OnHeapHnswGraph hnsw;
|
||||
|
@ -104,9 +103,9 @@ public final class Lucene91HnswGraphBuilder {
|
|||
this.graphSearcher =
|
||||
new HnswGraphSearcher(
|
||||
similarityFunction,
|
||||
new NeighborQueue(beamWidth, similarityFunction.reversed == false),
|
||||
new NeighborQueue(beamWidth, true),
|
||||
new FixedBitSet(vectorValues.size()));
|
||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||
bound = Lucene91BoundsChecker.create(false);
|
||||
scratch = new Lucene91NeighborArray(Math.max(beamWidth, maxConn + 1));
|
||||
}
|
||||
|
||||
|
@ -231,8 +230,8 @@ public final class Lucene91HnswGraphBuilder {
|
|||
// extract all the Neighbors from the queue into an array; these will now be
|
||||
// sorted from worst to best
|
||||
for (int i = 0; i < candidateCount; i++) {
|
||||
float score = candidates.topScore();
|
||||
scratch.add(candidates.pop(), score);
|
||||
float similarity = candidates.topScore();
|
||||
scratch.add(candidates.pop(), similarity);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -253,9 +252,9 @@ public final class Lucene91HnswGraphBuilder {
|
|||
throws IOException {
|
||||
bound.set(score);
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
float diversityCheck =
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
|
||||
if (bound.check(diversityCheck) == false) {
|
||||
if (bound.check(neighborSimilarity) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -286,13 +285,13 @@ public final class Lucene91HnswGraphBuilder {
|
|||
for (int i = neighbors.size() - 1; i >= 0; i--) {
|
||||
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
|
||||
// them, drop it
|
||||
int nbrNode = neighbors.node[i];
|
||||
int neighborId = neighbors.node[i];
|
||||
bound.set(neighbors.score[i]);
|
||||
float[] nbrVector = vectorValues.vectorValue(nbrNode);
|
||||
float[] neighborVector = vectorValues.vectorValue(neighborId);
|
||||
for (int j = maxConn; j > i; j--) {
|
||||
float diversityCheck =
|
||||
similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
|
||||
if (bound.check(diversityCheck) == false) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(neighborVector, buildVectors.vectorValue(neighbors.node[j]));
|
||||
if (bound.check(neighborSimilarity) == false) {
|
||||
// node j is too similar to node i given its score relative to the base node
|
||||
// replace it with the new node, which is at [maxConn]
|
||||
return i;
|
||||
|
|
|
@ -253,9 +253,9 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
|
||||
while (results.size() > 0) {
|
||||
int node = results.topNode();
|
||||
float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
|
||||
float minSimilarity = results.topScore();
|
||||
results.pop();
|
||||
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), score);
|
||||
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), minSimilarity);
|
||||
}
|
||||
|
||||
TotalHits.Relation relation =
|
||||
|
|
|
@ -246,9 +246,9 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
|
||||
while (results.size() > 0) {
|
||||
int node = results.topNode();
|
||||
float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
|
||||
float minSimilarity = results.topScore();
|
||||
results.pop();
|
||||
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
|
||||
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), minSimilarity);
|
||||
}
|
||||
|
||||
TotalHits.Relation relation =
|
||||
|
|
|
@ -170,7 +170,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
float[] vector = values.vectorValue();
|
||||
float score = vectorSimilarity.convertToScore(vectorSimilarity.compare(vector, target));
|
||||
float score = vectorSimilarity.compare(vector, target);
|
||||
topK.insertWithOverflow(new ScoreDoc(doc, score));
|
||||
numVisited++;
|
||||
}
|
||||
|
|
|
@ -246,7 +246,7 @@ public final class Lucene93HnswVectorsReader extends KnnVectorsReader {
|
|||
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
|
||||
while (results.size() > 0) {
|
||||
int node = results.topNode();
|
||||
float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
|
||||
float score = results.topScore();
|
||||
results.pop();
|
||||
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
|
||||
}
|
||||
|
|
|
@ -26,15 +26,10 @@ import static org.apache.lucene.util.VectorUtil.*;
|
|||
public enum VectorSimilarityFunction {
|
||||
|
||||
/** Euclidean distance */
|
||||
EUCLIDEAN(true) {
|
||||
EUCLIDEAN {
|
||||
@Override
|
||||
public float compare(float[] v1, float[] v2) {
|
||||
return squareDistance(v1, v2);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float convertToScore(float similarity) {
|
||||
return 1 / (1 + similarity);
|
||||
return 1 / (1 + squareDistance(v1, v2));
|
||||
}
|
||||
},
|
||||
|
||||
|
@ -47,12 +42,7 @@ public enum VectorSimilarityFunction {
|
|||
DOT_PRODUCT {
|
||||
@Override
|
||||
public float compare(float[] v1, float[] v2) {
|
||||
return dotProduct(v1, v2);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float convertToScore(float similarity) {
|
||||
return (1 + similarity) / 2;
|
||||
return (1 + dotProduct(v1, v2)) / 2;
|
||||
}
|
||||
},
|
||||
|
||||
|
@ -60,50 +50,22 @@ public enum VectorSimilarityFunction {
|
|||
* Cosine similarity. NOTE: the preferred way to perform cosine similarity is to normalize all
|
||||
* vectors to unit length, and instead use {@link VectorSimilarityFunction#DOT_PRODUCT}. You
|
||||
* should only use this function if you need to preserve the original vectors and cannot normalize
|
||||
* them in advance.
|
||||
* them in advance. The similarity score is normalised to assure it is positive.
|
||||
*/
|
||||
COSINE {
|
||||
@Override
|
||||
public float compare(float[] v1, float[] v2) {
|
||||
return cosine(v1, v2);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float convertToScore(float similarity) {
|
||||
return (1 + similarity) / 2;
|
||||
return (1 + cosine(v1, v2)) / 2;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* If true, the scores associated with vector comparisons are nonnegative and in reverse order;
|
||||
* that is, lower scores represent more similar vectors. Otherwise, if false, higher scores
|
||||
* represent more similar vectors, and scores may be negative or positive.
|
||||
*/
|
||||
public final boolean reversed;
|
||||
|
||||
VectorSimilarityFunction(boolean reversed) {
|
||||
this.reversed = reversed;
|
||||
}
|
||||
|
||||
VectorSimilarityFunction() {
|
||||
reversed = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates a similarity score between the two vectors with a specified function.
|
||||
* Calculates a similarity score between the two vectors with a specified function. Higher
|
||||
* similarity scores correspond to closer vectors.
|
||||
*
|
||||
* @param v1 a vector
|
||||
* @param v2 another vector, of the same dimension
|
||||
* @return the value of the similarity function applied to the two vectors
|
||||
*/
|
||||
public abstract float compare(float[] v1, float[] v2);
|
||||
|
||||
/**
|
||||
* Converts similarity scores used (may be negative, reversed, etc) into document scores, which
|
||||
* must be positive, with higher scores representing better matches.
|
||||
*
|
||||
* @param similarity the raw internal score as returned by {@link #compare(float[], float[])}.
|
||||
* @return normalizedSimilarity
|
||||
*/
|
||||
public abstract float convertToScore(float similarity);
|
||||
}
|
||||
|
|
|
@ -197,7 +197,7 @@ public class KnnVectorQuery extends Query {
|
|||
assert vectorDoc == doc;
|
||||
float[] vector = vectorValues.vectorValue();
|
||||
|
||||
float score = similarityFunction.convertToScore(similarityFunction.compare(vector, target));
|
||||
float score = similarityFunction.compare(vector, target);
|
||||
if (score >= topDoc.score) {
|
||||
topDoc.score = score;
|
||||
topDoc.doc = doc;
|
||||
|
|
|
@ -51,7 +51,6 @@ public final class HnswGraphBuilder {
|
|||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final RandomAccessVectorValues vectorValues;
|
||||
private final SplittableRandom random;
|
||||
private final BoundsChecker bound;
|
||||
private final HnswGraphSearcher graphSearcher;
|
||||
|
||||
final OnHeapHnswGraph hnsw;
|
||||
|
@ -96,15 +95,14 @@ public final class HnswGraphBuilder {
|
|||
this.ml = 1 / Math.log(1.0 * M);
|
||||
this.random = new SplittableRandom(seed);
|
||||
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
||||
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode, similarityFunction.reversed);
|
||||
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
|
||||
this.graphSearcher =
|
||||
new HnswGraphSearcher(
|
||||
similarityFunction,
|
||||
new NeighborQueue(beamWidth, similarityFunction.reversed == false),
|
||||
new NeighborQueue(beamWidth, true),
|
||||
new FixedBitSet(vectorValues.size()));
|
||||
bound = BoundsChecker.create(similarityFunction.reversed);
|
||||
// in scratch we store candidates in reverse order: worse candidates are first
|
||||
scratch = new NeighborArray(Math.max(beamWidth, M + 1), similarityFunction.reversed);
|
||||
scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -225,8 +223,8 @@ public final class HnswGraphBuilder {
|
|||
// extract all the Neighbors from the queue into an array; these will now be
|
||||
// sorted from worst to best
|
||||
for (int i = 0; i < candidateCount; i++) {
|
||||
float score = candidates.topScore();
|
||||
scratch.add(candidates.pop(), score);
|
||||
float maxSimilarity = candidates.topScore();
|
||||
scratch.add(candidates.pop(), maxSimilarity);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -245,11 +243,10 @@ public final class HnswGraphBuilder {
|
|||
NeighborArray neighbors,
|
||||
RandomAccessVectorValues vectorValues)
|
||||
throws IOException {
|
||||
bound.set(score);
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
float diversityCheck =
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
|
||||
if (bound.check(diversityCheck) == false) {
|
||||
if (neighborSimilarity >= score) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -261,16 +258,17 @@ public final class HnswGraphBuilder {
|
|||
* neighbours
|
||||
*/
|
||||
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
|
||||
float minAcceptedSimilarity;
|
||||
for (int i = neighbors.size() - 1; i > 0; i--) {
|
||||
int cNode = neighbors.node[i];
|
||||
float[] cVector = vectorValues.vectorValue(cNode);
|
||||
bound.set(neighbors.score[i]);
|
||||
minAcceptedSimilarity = neighbors.score[i];
|
||||
// check the candidate against its better-scoring neighbors
|
||||
for (int j = i - 1; j >= 0; j--) {
|
||||
float diversityCheck =
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
|
||||
// node i is too similar to node j given its score relative to the base node
|
||||
if (bound.check(diversityCheck) == false) {
|
||||
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -80,7 +80,7 @@ public final class HnswGraphSearcher {
|
|||
HnswGraphSearcher graphSearcher =
|
||||
new HnswGraphSearcher(
|
||||
similarityFunction,
|
||||
new NeighborQueue(topK, similarityFunction.reversed == false),
|
||||
new NeighborQueue(topK, true),
|
||||
new SparseFixedBitSet(vectors.size()));
|
||||
NeighborQueue results;
|
||||
int[] eps = new int[] {graph.entryNode()};
|
||||
|
@ -139,7 +139,7 @@ public final class HnswGraphSearcher {
|
|||
int visitedLimit)
|
||||
throws IOException {
|
||||
int size = graph.size();
|
||||
NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
|
||||
NeighborQueue results = new NeighborQueue(topK, false);
|
||||
clearScratchState();
|
||||
|
||||
int numVisited = 0;
|
||||
|
@ -160,14 +160,14 @@ public final class HnswGraphSearcher {
|
|||
|
||||
// A bound that holds the minimum similarity to the query vector that a candidate vector must
|
||||
// have to be considered.
|
||||
BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
|
||||
float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
|
||||
if (results.size() >= topK) {
|
||||
bound.set(results.topScore());
|
||||
minAcceptedSimilarity = results.topScore();
|
||||
}
|
||||
while (candidates.size() > 0 && results.incomplete() == false) {
|
||||
// get the best candidate (closest or best scoring)
|
||||
float topCandidateScore = candidates.topScore();
|
||||
if (bound.check(topCandidateScore)) {
|
||||
float topCandidateSimilarity = candidates.topScore();
|
||||
if (topCandidateSimilarity < minAcceptedSimilarity) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -184,13 +184,13 @@ public final class HnswGraphSearcher {
|
|||
results.markIncomplete();
|
||||
break;
|
||||
}
|
||||
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
|
||||
float friendSimilarity = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
|
||||
numVisited++;
|
||||
if (bound.check(score) == false) {
|
||||
candidates.add(friendOrd, score);
|
||||
if (friendSimilarity >= minAcceptedSimilarity) {
|
||||
candidates.add(friendOrd, friendSimilarity);
|
||||
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
|
||||
if (results.insertWithOverflow(friendOrd, score) && results.size() >= topK) {
|
||||
bound.set(results.topScore());
|
||||
if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) {
|
||||
minAcceptedSimilarity = results.topScore();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,13 +30,13 @@ import org.apache.lucene.util.NumericUtils;
|
|||
public class NeighborQueue {
|
||||
|
||||
private enum Order {
|
||||
NATURAL {
|
||||
MIN_HEAP {
|
||||
@Override
|
||||
long apply(long v) {
|
||||
return v;
|
||||
}
|
||||
},
|
||||
REVERSED {
|
||||
MAX_HEAP {
|
||||
@Override
|
||||
long apply(long v) {
|
||||
// This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It
|
||||
|
@ -56,9 +56,9 @@ public class NeighborQueue {
|
|||
// Whether the search stopped early because it reached the visited nodes limit
|
||||
private boolean incomplete;
|
||||
|
||||
public NeighborQueue(int initialSize, boolean reversed) {
|
||||
public NeighborQueue(int initialSize, boolean maxHeap) {
|
||||
this.heap = new LongHeap(initialSize);
|
||||
this.order = reversed ? Order.REVERSED : Order.NATURAL;
|
||||
this.order = maxHeap ? Order.MAX_HEAP : Order.MIN_HEAP;
|
||||
}
|
||||
|
||||
/** @return the number of elements in the heap */
|
||||
|
@ -89,32 +89,66 @@ public class NeighborQueue {
|
|||
return heap.insertWithOverflow(encode(newNode, newScore));
|
||||
}
|
||||
|
||||
/**
|
||||
* Encodes the node ID and its similarity score as long, preserving the Lucene tie-breaking rule
|
||||
* that when two scores are equals, the smaller node ID must win.
|
||||
*
|
||||
* <p>The most significant 32 bits represent the float score, encoded as a sortable int.
|
||||
*
|
||||
* <p>The less significant 32 bits represent the node ID.
|
||||
*
|
||||
* <p>The bits representing the node ID are complemented to guarantee the win for the smaller node
|
||||
* Id.
|
||||
*
|
||||
* <p>The AND with 0xFFFFFFFFL (a long with first 32 bit as 1) is necessary to obtain a long that
|
||||
* has
|
||||
*
|
||||
* <p>The most significant 32 bits to 0
|
||||
*
|
||||
* <p>The less significant 32 bits represent the node ID.
|
||||
*
|
||||
* @param node the node ID
|
||||
* @param score the node score
|
||||
* @return the encoded score, node ID
|
||||
*/
|
||||
private long encode(int node, float score) {
|
||||
return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | node);
|
||||
return order.apply(
|
||||
(((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
|
||||
}
|
||||
|
||||
private float decodeScore(long heapValue) {
|
||||
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
|
||||
}
|
||||
|
||||
private int decodeNodeId(long heapValue) {
|
||||
return (int) ~(order.apply(heapValue));
|
||||
}
|
||||
|
||||
/** Removes the top element and returns its node id. */
|
||||
public int pop() {
|
||||
return (int) order.apply(heap.pop());
|
||||
return decodeNodeId(heap.pop());
|
||||
}
|
||||
|
||||
public int[] nodes() {
|
||||
int size = size();
|
||||
int[] nodes = new int[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
nodes[i] = (int) order.apply(heap.get(i + 1));
|
||||
nodes[i] = decodeNodeId(heap.get(i + 1));
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
|
||||
/** Returns the top element's node id. */
|
||||
public int topNode() {
|
||||
return (int) order.apply(heap.top());
|
||||
return decodeNodeId(heap.top());
|
||||
}
|
||||
|
||||
/** Returns the top element's node score. */
|
||||
/**
|
||||
* Returns the top element's node score. For the min heap this is the minimum score. For the max
|
||||
* heap this is the maximum score.
|
||||
*/
|
||||
public float topScore() {
|
||||
return NumericUtils.sortableIntToFloat((int) (order.apply(heap.top()) >> 32));
|
||||
return decodeScore(heap.top());
|
||||
}
|
||||
|
||||
public void clear() {
|
||||
|
|
|
@ -30,7 +30,6 @@ import org.apache.lucene.util.ArrayUtil;
|
|||
*/
|
||||
public final class OnHeapHnswGraph extends HnswGraph {
|
||||
|
||||
private final boolean similarityReversed;
|
||||
private int numLevels; // the current number of levels in the graph
|
||||
private int entryNode; // the current graph entry node on the top level
|
||||
|
||||
|
@ -52,8 +51,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
|
|||
private int upto;
|
||||
private NeighborArray cur;
|
||||
|
||||
OnHeapHnswGraph(int M, int levelOfFirstNode, boolean similarityReversed) {
|
||||
this.similarityReversed = similarityReversed;
|
||||
OnHeapHnswGraph(int M, int levelOfFirstNode) {
|
||||
this.numLevels = levelOfFirstNode + 1;
|
||||
this.graph = new ArrayList<>(numLevels);
|
||||
this.entryNode = 0;
|
||||
|
@ -63,7 +61,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
|
|||
this.nsize0 = (M * 2 + 1);
|
||||
for (int l = 0; l < numLevels; l++) {
|
||||
graph.add(new ArrayList<>());
|
||||
graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, similarityReversed == false));
|
||||
graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, true));
|
||||
}
|
||||
|
||||
this.nodesByLevel = new ArrayList<>(numLevels);
|
||||
|
@ -123,9 +121,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
|
|||
}
|
||||
}
|
||||
}
|
||||
graph
|
||||
.get(level)
|
||||
.add(new NeighborArray(level == 0 ? nsize0 : nsize, similarityReversed == false));
|
||||
graph.get(level).add(new NeighborArray(level == 0 ? nsize0 : nsize, true));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -273,8 +273,7 @@ public class KnnGraphTester {
|
|||
for (int i = 0; i < hnsw.size(); i++) {
|
||||
NeighborArray neighbors = hnsw.getNeighbors(0, i);
|
||||
System.out.printf(Locale.ROOT, "%5d", i);
|
||||
NeighborArray sorted =
|
||||
new NeighborArray(neighbors.size(), similarityFunction.reversed == false);
|
||||
NeighborArray sorted = new NeighborArray(neighbors.size(), true);
|
||||
for (int j = 0; j < neighbors.size(); j++) {
|
||||
int node = neighbors.node[j];
|
||||
float score = neighbors.score[j];
|
||||
|
@ -555,7 +554,7 @@ public class KnnGraphTester {
|
|||
.order(ByteOrder.LITTLE_ENDIAN)
|
||||
.asFloatBuffer();
|
||||
offset += blockSize;
|
||||
NeighborQueue queue = new NeighborQueue(topK, similarityFunction.reversed);
|
||||
NeighborQueue queue = new NeighborQueue(topK, false);
|
||||
for (; j < numDocs && vectors.hasRemaining(); j++) {
|
||||
vectors.get(vector);
|
||||
float d = similarityFunction.compare(query, vector);
|
||||
|
|
|
@ -309,30 +309,6 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
assertTrue(nn.visitedCount() <= visitedLimit);
|
||||
}
|
||||
|
||||
public void testBoundsCheckerMax() {
|
||||
BoundsChecker max = BoundsChecker.create(false);
|
||||
float f = random().nextFloat() - 0.5f;
|
||||
// any float > -MAX_VALUE is in bounds
|
||||
assertFalse(max.check(f));
|
||||
// f is now the bound (minus some delta)
|
||||
max.update(f);
|
||||
assertFalse(max.check(f)); // f is not out of bounds
|
||||
assertFalse(max.check(f + 1)); // anything greater than f is in bounds
|
||||
assertTrue(max.check(f - 1e-5f)); // delta is zero initially
|
||||
}
|
||||
|
||||
public void testBoundsCheckerMin() {
|
||||
BoundsChecker min = BoundsChecker.create(true);
|
||||
float f = random().nextFloat() - 0.5f;
|
||||
// any float < MAX_VALUE is in bounds
|
||||
assertFalse(min.check(f));
|
||||
// f is now the bound (minus some delta)
|
||||
min.update(f);
|
||||
assertFalse(min.check(f)); // f is not out of bounds
|
||||
assertFalse(min.check(f - 1)); // anything less than f is in bounds
|
||||
assertTrue(min.check(f + 1e-5f)); // delta is zero initially
|
||||
}
|
||||
|
||||
public void testHnswGraphBuilderInvalid() {
|
||||
expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, null, 0, 0, 0));
|
||||
expectThrows(
|
||||
|
@ -441,7 +417,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
while (actual.size() > topK) {
|
||||
actual.pop();
|
||||
}
|
||||
NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
|
||||
NeighborQueue expected = new NeighborQueue(topK, false);
|
||||
for (int j = 0; j < size; j++) {
|
||||
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
|
||||
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
|
||||
|
|
Loading…
Reference in New Issue