mirror of https://github.com/apache/lucene.git
LUCENE-10382: Ensure kNN filtering works with other codecs (#700)
The original PR that added kNN filtering support overlooked non-default codecs. This follow-up ensures that other codecs work with the new filtering logic:

* Make sure to check the visited-nodes limit in `SimpleTextKnnVectorsReader` and `Lucene90HnswVectorsReader`
* Add a test to `BaseKnnVectorsFormatTestCase` to cover this case
* Fix failures in `TestKnnVectorQuery#testRandomWithFilter`, whose assumptions don't hold when SimpleText is used

This PR also clarifies the limit-checking logic for `Lucene91HnswVectorsReader`: the limit is now checked before visiting each new node, whereas before it was only checked in an outer loop.
This commit is contained in:
parent 4364bdd63e
commit b40a750aa8
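The core pattern this commit applies in each reader, distilled below as a self-contained toy traversal (the class and method are illustrative only, not Lucene's actual HnswGraphSearcher API): the visit budget is tested before a node is scored, never after, and the result is flagged incomplete so the outer loop stops promptly.

import java.util.ArrayDeque;
import java.util.BitSet;
import java.util.Deque;
import java.util.List;

class VisitBudgetSketch {
  /** Returns the number of nodes visited; the invariant is numVisited <= visitedLimit. */
  static int traverse(List<int[]> neighbors, int entryPoint, int visitedLimit) {
    boolean incomplete = false;
    int numVisited = 0;
    BitSet visited = new BitSet(neighbors.size());
    Deque<Integer> candidates = new ArrayDeque<>();

    if (visitedLimit <= 0) {
      return 0; // no budget: even the entry point may not be visited
    }
    visited.set(entryPoint);
    candidates.push(entryPoint);
    numVisited++;

    // The outer condition re-checks the incomplete flag, so a break in the
    // inner loop ends the whole search, not just one expansion step.
    while (!candidates.isEmpty() && !incomplete) {
      int node = candidates.pop();
      for (int friend : neighbors.get(node)) {
        if (visited.get(friend)) {
          continue; // already-visited neighbors don't consume budget
        }
        if (numVisited >= visitedLimit) { // checked before every visit
          incomplete = true; // analogous to results.markIncomplete()
          break;
        }
        visited.set(friend);
        candidates.push(friend);
        numVisited++;
      }
    }
    return numVisited;
  }
}

For example, traverse(List.of(new int[] {1, 2}, new int[] {0, 2}, new int[] {0, 1}), 0, 2) visits the entry point and exactly one neighbor before stopping.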
@@ -144,7 +144,15 @@ public final class Lucene90HnswGraphBuilder {
     // We pass 'null' for acceptOrds because there are no deletions while building the graph
     NeighborQueue candidates =
         Lucene90OnHeapHnswGraph.search(
-            value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
+            value,
+            beamWidth,
+            beamWidth,
+            vectorValues,
+            similarityFunction,
+            hnsw,
+            null,
+            Integer.MAX_VALUE,
+            random);
 
     int node = hnsw.addNode();
 
@@ -252,6 +252,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
             fieldEntry.similarityFunction,
             getGraphValues(fieldEntry),
             getAcceptOrds(acceptDocs, fieldEntry),
+            visitedLimit,
             random);
     int i = 0;
     ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
@@ -261,11 +262,11 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
       results.pop();
       scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
     }
-    // always return >= the case where we can assert == is only when there are fewer than topK
-    // vectors in the index
-    return new TopDocs(
-        new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
-        scoreDocs);
+    TotalHits.Relation relation =
+        results.incomplete()
+            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+            : TotalHits.Relation.EQUAL_TO;
+    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
   }
 
   private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
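For downstream code, the relation computed above is what distinguishes an exhaustive search from one that was cut off. A caller-side sketch (TopDocs and TotalHits are the real Lucene classes; the helper itself is hypothetical):

import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

class RelationSketch {
  // EQUAL_TO: every candidate vector was visited, so totalHits.value is exact.
  // GREATER_THAN_OR_EQUAL_TO: the visitedLimit fired and the value is only a lower bound.
  static boolean wasExhaustive(TopDocs results) {
    return results.totalHits.relation == TotalHits.Relation.EQUAL_TO;
  }
}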
@@ -80,6 +80,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
       VectorSimilarityFunction similarityFunction,
       HnswGraph graphValues,
       Bits acceptOrds,
+      int visitedLimit,
       SplittableRandom random)
       throws IOException {
     int size = graphValues.size();
@@ -89,6 +90,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
     // MAX heap, from which to pull the candidate nodes
     NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed);
 
+    int numVisited = 0;
     // set of ordinals that have been visited by search on this layer, used to avoid backtracking
    SparseFixedBitSet visited = new SparseFixedBitSet(size);
     // get initial candidates at random
@@ -96,12 +98,17 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
     for (int i = 0; i < boundedNumSeed; i++) {
       int entryPoint = random.nextInt(size);
       if (visited.getAndSet(entryPoint) == false) {
+        if (numVisited >= visitedLimit) {
+          results.markIncomplete();
+          break;
+        }
         // explore the topK starting points of some random numSeed probes
         float score = similarityFunction.compare(query, vectors.vectorValue(entryPoint));
         candidates.add(entryPoint, score);
         if (acceptOrds == null || acceptOrds.get(entryPoint)) {
           results.add(entryPoint, score);
         }
+        numVisited++;
       }
     }
 
@@ -110,7 +117,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
     // to exceed this bound
     BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
     bound.set(results.topScore());
-    while (candidates.size() > 0) {
+    while (candidates.size() > 0 && results.incomplete() == false) {
       // get the best candidate (closest or best scoring)
       float topCandidateScore = candidates.topScore();
       if (results.size() >= topK) {
@@ -127,6 +134,11 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
           continue;
         }
 
+        if (numVisited >= visitedLimit) {
+          results.markIncomplete();
+          break;
+        }
+
         float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
         if (results.size() < numSeed || bound.check(score) == false) {
           candidates.add(friendOrd, score);
@@ -135,12 +147,13 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
             bound.set(results.topScore());
           }
         }
+        numVisited++;
       }
     }
     while (results.size() > topK) {
       results.pop();
     }
-    results.setVisitedCount(visited.approximateCardinality());
+    results.setVisitedCount(numVisited);
     return results;
   }
 
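The hunks above lean on a small amount of result-side state; a minimal sketch of that contract (not the real NeighborQueue):

final class ResultStateSketch {
  private boolean incomplete;
  private int visitedCount;

  void markIncomplete() {
    incomplete = true; // the search stopped early; hit counts are lower bounds
  }

  boolean incomplete() {
    return incomplete;
  }

  void setVisitedCount(int numVisited) {
    visitedCount = numVisited; // an exact count, never an estimate
  }

  int visitedCount() {
    return visitedCount;
  }
}

Switching setVisitedCount from visited.approximateCardinality() to the exact numVisited tally is what allows the strict assertion visitedCount() <= visitedLimit in the TestHnswGraph change further down.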
@@ -154,20 +154,31 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
     FieldInfo info = readState.fieldInfos.fieldInfo(field);
     VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
     HitQueue topK = new HitQueue(k, false);
+
+    int numVisited = 0;
+    TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
+
     int doc;
     while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
       if (acceptDocs != null && acceptDocs.get(doc) == false) {
         continue;
       }
+
+      if (numVisited >= visitedLimit) {
+        relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
+        break;
+      }
+
       float[] vector = values.vectorValue();
       float score = vectorSimilarity.convertToScore(vectorSimilarity.compare(vector, target));
       topK.insertWithOverflow(new ScoreDoc(doc, score));
+      numVisited++;
     }
     ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
     for (int i = topScoreDocs.length - 1; i >= 0; i--) {
       topScoreDocs[i] = topK.pop();
     }
-    return new TopDocs(new TotalHits(values.size(), TotalHits.Relation.EQUAL_TO), topScoreDocs);
+    return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
   }
 
   @Override
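The SimpleText change is an exhaustive scan with a budget rather than a graph search. The same shape as a self-contained sketch (toy types and a plain negated-distance score, not the SimpleText reader itself):

import java.util.PriorityQueue;

class ExhaustiveBudgetSketch {
  static final class Hits {
    final PriorityQueue<float[]> topK; // (doc, score) pairs
    final int numVisited;
    final boolean exact; // mirrors TotalHits.Relation.EQUAL_TO

    Hits(PriorityQueue<float[]> topK, int numVisited, boolean exact) {
      this.topK = topK;
      this.numVisited = numVisited;
      this.exact = exact;
    }
  }

  static Hits search(float[][] vectors, boolean[] accept, float[] target, int k, int visitedLimit) {
    // min-heap on score so the current worst hit can be dropped beyond size k
    PriorityQueue<float[]> topK = new PriorityQueue<>((a, b) -> Float.compare(a[1], b[1]));
    int numVisited = 0;
    boolean exact = true;
    for (int doc = 0; doc < vectors.length; doc++) {
      if (accept != null && !accept[doc]) {
        continue; // filtered-out docs don't consume the visit budget
      }
      if (numVisited >= visitedLimit) {
        exact = false; // the relation becomes GREATER_THAN_OR_EQUAL_TO
        break;
      }
      float score = negSquaredDistance(target, vectors[doc]);
      topK.add(new float[] {doc, score});
      if (topK.size() > k) {
        topK.poll();
      }
      numVisited++;
    }
    return new Hits(topK, numVisited, exact);
  }

  static float negSquaredDistance(float[] a, float[] b) {
    float sum = 0;
    for (int i = 0; i < a.length; i++) {
      float d = a[i] - b[i];
      sum += d * d;
    }
    return -sum; // higher is better, like a similarity score
  }
}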
@@ -140,12 +140,16 @@ public final class HnswGraphSearcher {
     int numVisited = 0;
     for (int ep : eps) {
       if (visited.getAndSet(ep) == false) {
+        if (numVisited >= visitedLimit) {
+          results.markIncomplete();
+          break;
+        }
         float score = similarityFunction.compare(query, vectors.vectorValue(ep));
-        numVisited++;
         candidates.add(ep, score);
         if (acceptOrds == null || acceptOrds.get(ep)) {
           results.add(ep, score);
         }
+        numVisited++;
       }
     }
 
@@ -155,18 +159,13 @@ public final class HnswGraphSearcher {
     if (results.size() >= topK) {
       bound.set(results.topScore());
     }
-    while (candidates.size() > 0) {
+    while (candidates.size() > 0 && results.incomplete() == false) {
       // get the best candidate (closest or best scoring)
       float topCandidateScore = candidates.topScore();
       if (bound.check(topCandidateScore)) {
         break;
       }
 
-      if (numVisited >= visitedLimit) {
-        results.markIncomplete();
-        break;
-      }
-
       int topCandidateNode = candidates.pop();
       graph.seek(level, topCandidateNode);
       int friendOrd;
@@ -176,6 +175,10 @@ public final class HnswGraphSearcher {
           continue;
         }
+        if (numVisited >= visitedLimit) {
+          results.markIncomplete();
+          break;
+        }
         float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
         numVisited++;
         if (bound.check(score) == false) {
@@ -44,6 +44,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.tests.util.TestUtil;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.VectorUtil;
 
@@ -491,7 +492,11 @@ public class TestKnnVectorQuery extends LuceneTestCase {
     int dimension = atLeast(5);
     int numIters = atLeast(10);
     try (Directory d = newDirectory()) {
-      RandomIndexWriter w = new RandomIndexWriter(random(), d);
+      // Always use the default kNN format to have predictable behavior around when it hits
+      // visitedLimit. This is fine since the test targets KnnVectorQuery logic, not the kNN format
+      // implementation.
+      IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
+      RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
       for (int i = 0; i < numDocs; i++) {
         Document doc = new Document();
         doc.add(new KnnVectorField("field", randomVector(dimension)));
@@ -690,8 +695,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
     }
 
     @Override
-    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
-        throws IOException {
+    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
       throw new UnsupportedOperationException("exact search is not supported");
     }
   }
@@ -309,8 +309,8 @@ public class TestHnswGraph extends LuceneTestCase {
             createRandomAcceptOrds(0, vectors.size),
             visitedLimit);
     assertTrue(nn.incomplete());
-    // The visited count shouldn't be much over the limit
-    assertTrue(nn.visitedCount() < visitedLimit + 3);
+    // The visited count shouldn't exceed the limit
+    assertTrue(nn.visitedCount() <= visitedLimit);
   }
 
   public void testBoundsCheckerMax() {
@@ -24,6 +24,7 @@ import java.util.Arrays;
 import java.util.Set;
 import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.FieldType;
@@ -45,6 +46,7 @@ import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
 import org.apache.lucene.tests.util.TestUtil;
@@ -821,6 +823,68 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
     }
   }
 
+  /**
+   * Tests whether {@link KnnVectorsReader#search} implementations obey the limit on the number of
+   * visited vectors. This test is a best-effort attempt to capture the right behavior, and isn't
+   * meant to define a strict requirement on behavior.
+   */
+  public void testSearchWithVisitedLimit() throws Exception {
+    IndexWriterConfig iwc = newIndexWriterConfig();
+    String fieldName = "field";
+    try (Directory dir = newDirectory();
+        IndexWriter iw = new IndexWriter(dir, iwc)) {
+      int numDoc = atLeast(300);
+      int dimension = atLeast(10);
+      for (int i = 0; i < numDoc; i++) {
+        float[] value;
+        if (random().nextInt(7) != 3) {
+          // usually index a vector value for a doc
+          value = randomVector(dimension);
+        } else {
+          value = null;
+        }
+        add(iw, fieldName, i, value, VectorSimilarityFunction.EUCLIDEAN);
+      }
+      iw.forceMerge(1);
+
+      // randomly delete some documents
+      for (int i = 0; i < 30; i++) {
+        int idToDelete = random().nextInt(numDoc);
+        iw.deleteDocuments(new Term("id", Integer.toString(idToDelete)));
+      }
+
+      try (IndexReader reader = DirectoryReader.open(iw)) {
+        for (LeafReaderContext ctx : reader.leaves()) {
+          Bits liveDocs = ctx.reader().getLiveDocs();
+          VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
+          if (vectorValues == null) {
+            continue;
+          }
+
+          // check the limit is hit when it's very small
+          int k = 5 + random().nextInt(45);
+          int visitedLimit = k + random().nextInt(5);
+          TopDocs results =
+              ctx.reader()
+                  .searchNearestVectors(
+                      fieldName, randomVector(dimension), k, liveDocs, visitedLimit);
+          assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation);
+          assertEquals(visitedLimit, results.totalHits.value);
+
+          // check the limit is not hit when it clearly exceeds the number of vectors
+          k = vectorValues.size();
+          visitedLimit = k + 30;
+          results =
+              ctx.reader()
+                  .searchNearestVectors(
+                      fieldName, randomVector(dimension), k, liveDocs, visitedLimit);
+          assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation);
+          assertTrue(results.totalHits.value <= visitedLimit);
+        }
+      }
+    }
+  }
+
   /**
    * Index random vectors, sometimes skipping documents, sometimes updating a document, sometimes
    * merging, sometimes sorting the index, using an HNSW similarity function so as to also produce a