LUCENE-10382: Ensure kNN filtering works with other codecs (#700)

The original PR that added kNN filtering support overlooked non-default codecs.
This follow-up ensures that other codecs work with the new filtering logic:
* Make sure to check the visited nodes limit in `SimpleTextKnnVectorsReader`
and `Lucene90HnswVectorsReader`
* Add a test to `BaseKnnVectorsFormatTestCase` to cover this case
* Fix failures in `TestKnnVectorQuery#testRandomWithFilter`, whose assumptions
don't hold when SimpleText is used

This PR also clarifies the limit checking logic for
`Lucene91HnswVectorsReader`. Now we always check the limit before visiting a
new node, whereas before we only checked it in an outer loop.
This commit is contained in:
Julie Tibshirani 2022-02-23 14:58:27 -08:00 committed by GitHub
parent 4364bdd63e
commit b40a750aa8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 125 additions and 21 deletions

View File

@ -144,7 +144,15 @@ public final class Lucene90HnswGraphBuilder {
// We pass 'null' for acceptOrds because there are no deletions while building the graph // We pass 'null' for acceptOrds because there are no deletions while building the graph
NeighborQueue candidates = NeighborQueue candidates =
Lucene90OnHeapHnswGraph.search( Lucene90OnHeapHnswGraph.search(
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random); value,
beamWidth,
beamWidth,
vectorValues,
similarityFunction,
hnsw,
null,
Integer.MAX_VALUE,
random);
int node = hnsw.addNode(); int node = hnsw.addNode();

View File

@ -252,6 +252,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
fieldEntry.similarityFunction, fieldEntry.similarityFunction,
getGraphValues(fieldEntry), getGraphValues(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry), getAcceptOrds(acceptDocs, fieldEntry),
visitedLimit,
random); random);
int i = 0; int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
@ -261,11 +262,11 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
results.pop(); results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score); scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
} }
// always return >= the case where we can assert == is only when there are fewer than topK TotalHits.Relation relation =
// vectors in the index results.incomplete()
return new TopDocs( ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), : TotalHits.Relation.EQUAL_TO;
scoreDocs); return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
} }
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException { private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {

View File

@ -80,6 +80,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
HnswGraph graphValues, HnswGraph graphValues,
Bits acceptOrds, Bits acceptOrds,
int visitedLimit,
SplittableRandom random) SplittableRandom random)
throws IOException { throws IOException {
int size = graphValues.size(); int size = graphValues.size();
@ -89,6 +90,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
// MAX heap, from which to pull the candidate nodes // MAX heap, from which to pull the candidate nodes
NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed); NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed);
int numVisited = 0;
// set of ordinals that have been visited by search on this layer, used to avoid backtracking // set of ordinals that have been visited by search on this layer, used to avoid backtracking
SparseFixedBitSet visited = new SparseFixedBitSet(size); SparseFixedBitSet visited = new SparseFixedBitSet(size);
// get initial candidates at random // get initial candidates at random
@ -96,12 +98,17 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
for (int i = 0; i < boundedNumSeed; i++) { for (int i = 0; i < boundedNumSeed; i++) {
int entryPoint = random.nextInt(size); int entryPoint = random.nextInt(size);
if (visited.getAndSet(entryPoint) == false) { if (visited.getAndSet(entryPoint) == false) {
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
// explore the topK starting points of some random numSeed probes // explore the topK starting points of some random numSeed probes
float score = similarityFunction.compare(query, vectors.vectorValue(entryPoint)); float score = similarityFunction.compare(query, vectors.vectorValue(entryPoint));
candidates.add(entryPoint, score); candidates.add(entryPoint, score);
if (acceptOrds == null || acceptOrds.get(entryPoint)) { if (acceptOrds == null || acceptOrds.get(entryPoint)) {
results.add(entryPoint, score); results.add(entryPoint, score);
} }
numVisited++;
} }
} }
@ -110,7 +117,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
// to exceed this bound // to exceed this bound
BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed); BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
bound.set(results.topScore()); bound.set(results.topScore());
while (candidates.size() > 0) { while (candidates.size() > 0 && results.incomplete() == false) {
// get the best candidate (closest or best scoring) // get the best candidate (closest or best scoring)
float topCandidateScore = candidates.topScore(); float topCandidateScore = candidates.topScore();
if (results.size() >= topK) { if (results.size() >= topK) {
@ -127,6 +134,11 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
continue; continue;
} }
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd)); float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
if (results.size() < numSeed || bound.check(score) == false) { if (results.size() < numSeed || bound.check(score) == false) {
candidates.add(friendOrd, score); candidates.add(friendOrd, score);
@ -135,12 +147,13 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
bound.set(results.topScore()); bound.set(results.topScore());
} }
} }
numVisited++;
} }
} }
while (results.size() > topK) { while (results.size() > topK) {
results.pop(); results.pop();
} }
results.setVisitedCount(visited.approximateCardinality()); results.setVisitedCount(numVisited);
return results; return results;
} }

View File

@ -154,20 +154,31 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
FieldInfo info = readState.fieldInfos.fieldInfo(field); FieldInfo info = readState.fieldInfos.fieldInfo(field);
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction(); VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
HitQueue topK = new HitQueue(k, false); HitQueue topK = new HitQueue(k, false);
int numVisited = 0;
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
int doc; int doc;
while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (acceptDocs != null && acceptDocs.get(doc) == false) { if (acceptDocs != null && acceptDocs.get(doc) == false) {
continue; continue;
} }
if (numVisited >= visitedLimit) {
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
break;
}
float[] vector = values.vectorValue(); float[] vector = values.vectorValue();
float score = vectorSimilarity.convertToScore(vectorSimilarity.compare(vector, target)); float score = vectorSimilarity.convertToScore(vectorSimilarity.compare(vector, target));
topK.insertWithOverflow(new ScoreDoc(doc, score)); topK.insertWithOverflow(new ScoreDoc(doc, score));
numVisited++;
} }
ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()]; ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) { for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = topK.pop(); topScoreDocs[i] = topK.pop();
} }
return new TopDocs(new TotalHits(values.size(), TotalHits.Relation.EQUAL_TO), topScoreDocs); return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
} }
@Override @Override

View File

@ -140,12 +140,16 @@ public final class HnswGraphSearcher {
int numVisited = 0; int numVisited = 0;
for (int ep : eps) { for (int ep : eps) {
if (visited.getAndSet(ep) == false) { if (visited.getAndSet(ep) == false) {
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
float score = similarityFunction.compare(query, vectors.vectorValue(ep)); float score = similarityFunction.compare(query, vectors.vectorValue(ep));
numVisited++;
candidates.add(ep, score); candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) { if (acceptOrds == null || acceptOrds.get(ep)) {
results.add(ep, score); results.add(ep, score);
} }
numVisited++;
} }
} }
@ -155,18 +159,13 @@ public final class HnswGraphSearcher {
if (results.size() >= topK) { if (results.size() >= topK) {
bound.set(results.topScore()); bound.set(results.topScore());
} }
while (candidates.size() > 0) { while (candidates.size() > 0 && results.incomplete() == false) {
// get the best candidate (closest or best scoring) // get the best candidate (closest or best scoring)
float topCandidateScore = candidates.topScore(); float topCandidateScore = candidates.topScore();
if (bound.check(topCandidateScore)) { if (bound.check(topCandidateScore)) {
break; break;
} }
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
int topCandidateNode = candidates.pop(); int topCandidateNode = candidates.pop();
graph.seek(level, topCandidateNode); graph.seek(level, topCandidateNode);
int friendOrd; int friendOrd;
@ -176,6 +175,10 @@ public final class HnswGraphSearcher {
continue; continue;
} }
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd)); float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
numVisited++; numVisited++;
if (bound.check(score) == false) { if (bound.check(score) == false) {

View File

@ -44,6 +44,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.VectorUtil;
@ -491,7 +492,11 @@ public class TestKnnVectorQuery extends LuceneTestCase {
int dimension = atLeast(5); int dimension = atLeast(5);
int numIters = atLeast(10); int numIters = atLeast(10);
try (Directory d = newDirectory()) { try (Directory d = newDirectory()) {
RandomIndexWriter w = new RandomIndexWriter(random(), d); // Always use the default kNN format to have predictable behavior around when it hits
// visitedLimit. This is fine since the test targets KnnVectorQuery logic, not the kNN format
// implementation.
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
for (int i = 0; i < numDocs; i++) { for (int i = 0; i < numDocs; i++) {
Document doc = new Document(); Document doc = new Document();
doc.add(new KnnVectorField("field", randomVector(dimension))); doc.add(new KnnVectorField("field", randomVector(dimension)));
@ -690,8 +695,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
} }
@Override @Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
throws IOException {
throw new UnsupportedOperationException("exact search is not supported"); throw new UnsupportedOperationException("exact search is not supported");
} }
} }

View File

@ -309,8 +309,8 @@ public class TestHnswGraph extends LuceneTestCase {
createRandomAcceptOrds(0, vectors.size), createRandomAcceptOrds(0, vectors.size),
visitedLimit); visitedLimit);
assertTrue(nn.incomplete()); assertTrue(nn.incomplete());
// The visited count shouldn't be much over the limit // The visited count shouldn't exceed the limit
assertTrue(nn.visitedCount() < visitedLimit + 3); assertTrue(nn.visitedCount() <= visitedLimit);
} }
public void testBoundsCheckerMax() { public void testBoundsCheckerMax() {

View File

@ -24,6 +24,7 @@ import java.util.Arrays;
import java.util.Set; import java.util.Set;
import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType; import org.apache.lucene.document.FieldType;
@ -45,6 +46,7 @@ import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TestUtil;
@ -821,6 +823,68 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
} }
} }
/**
* Tests whether {@link KnnVectorsReader#search} implementations obey the limit on the number of
* visited vectors. This test is a best-effort attempt to capture the right behavior, and isn't
* meant to define a strict requirement on behavior.
*
* <p>For every leaf that has vector values for the field, two searches are run with the leaf's
* live docs as the accept filter:
*
* <ol>
*   <li>a search whose {@code visitedLimit} is only slightly above {@code k}; the result is
*       expected to report {@link TotalHits.Relation#GREATER_THAN_OR_EQUAL_TO} with a hit count
*       equal to the limit;
*   <li>a search with {@code k} equal to the number of vectors and a limit of {@code k + 30}; the
*       result is expected to report {@link TotalHits.Relation#EQUAL_TO} with a hit count no
*       larger than the limit.
* </ol>
*/
public void testSearchWithVisitedLimit() throws Exception {
IndexWriterConfig iwc = newIndexWriterConfig();
String fieldName = "field";
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(300);
int dimension = atLeast(10);
for (int i = 0; i < numDoc; i++) {
float[] value;
// nextInt(7) != 3 holds for six of the seven possible draws
if (random().nextInt(7) != 3) {
// usually index a vector value for a doc
value = randomVector(dimension);
} else {
// NOTE(review): presumably add(...) indexes the doc without a vector when value is null;
// the helper is not visible here -- confirm
value = null;
}
add(iw, fieldName, i, value, VectorSimilarityFunction.EUCLIDEAN);
}
// NOTE(review): the deletes below are issued after this merge, presumably so that they show up
// as live docs on the reader instead of being merged away -- confirm
iw.forceMerge(1);
// randomly delete some documents
// ids are drawn independently, so the same id may be picked more than once and fewer than 30
// distinct ids may be targeted
for (int i = 0; i < 30; i++) {
int idToDelete = random().nextInt(numDoc);
// NOTE(review): assumes add(...) stores i under an "id" term -- confirm against the helper
iw.deleteDocuments(new Term("id", Integer.toString(idToDelete)));
}
try (IndexReader reader = DirectoryReader.open(iw)) {
for (LeafReaderContext ctx : reader.leaves()) {
// passed below as the accept filter for both searches
Bits liveDocs = ctx.reader().getLiveDocs();
VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
if (vectorValues == null) {
// nothing to search in this leaf
continue;
}
// check the limit is hit when it's very small
// k is in [5, 49] and visitedLimit is in [k, k + 4]
int k = 5 + random().nextInt(45);
int visitedLimit = k + random().nextInt(5);
TopDocs results =
ctx.reader()
.searchNearestVectors(
fieldName, randomVector(dimension), k, liveDocs, visitedLimit);
// hitting the limit must be reported as a lower bound, with exactly visitedLimit visits
assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation);
assertEquals(visitedLimit, results.totalHits.value);
// check the limit is not hit when it clearly exceeds the number of vectors
k = vectorValues.size();
visitedLimit = k + 30;
results =
ctx.reader()
.searchNearestVectors(
fieldName, randomVector(dimension), k, liveDocs, visitedLimit);
// only an upper bound on the visit count is asserted here, not an exact value
assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation);
assertTrue(results.totalHits.value <= visitedLimit);
}
}
}
}
/** /**
* Index random vectors, sometimes skipping documents, sometimes updating a document, sometimes * Index random vectors, sometimes skipping documents, sometimes updating a document, sometimes
* merging, sometimes sorting the index, using an HNSW similarity function so as to also produce a * merging, sometimes sorting the index, using an HNSW similarity function so as to also produce a