mirror of https://github.com/apache/lucene.git
LUCENE-10382: Ensure kNN filtering works with other codecs (#700)
The original PR that added kNN filtering support overlooked non-default codecs. This follow-up ensures that other codecs work with the new filtering logic:

* Make sure to check the visited nodes limit in `SimpleTextKnnVectorsReader` and `Lucene90HnswVectorsReader`
* Add a test to `BaseKnnVectorsFormatTestCase` to cover this case
* Fix failures in `TestKnnVectorQuery#testRandomWithFilter`, whose assumptions don't hold when SimpleText is used

This PR also clarifies the limit-checking logic for `Lucene91HnswVectorsReader`: the limit is now checked before visiting each new node, whereas before it was only checked in an outer loop.
parent 4364bdd63e
commit b40a750aa8
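The diffs below apply one pattern throughout: count how many vectors have been scored, check that count against the caller's budget before visiting each new node, and record whether the traversal was cut short. A minimal, self-contained sketch of that shape (not Lucene source; all names here are hypothetical stand-ins):

import java.util.ArrayDeque;
import java.util.Queue;

class VisitBudgetSketch {
  static final class Result {
    boolean incomplete; // true if the traversal stopped early at the limit
    int visitedCount;   // how many nodes were actually scored
  }

  // Visit up to visitedLimit nodes from a work queue of candidate ids.
  static Result traverse(Queue<Integer> candidates, int visitedLimit) {
    Result result = new Result();
    int numVisited = 0;
    while (!candidates.isEmpty() && !result.incomplete) {
      // The key point of this commit: check the limit before visiting a new
      // node, not just once per outer loop, so the budget is never exceeded.
      if (numVisited >= visitedLimit) {
        result.incomplete = true; // callers report this as GREATER_THAN_OR_EQUAL_TO
        break;
      }
      int node = candidates.poll();
      // ... score `node` against the query vector here ...
      numVisited++;
    }
    result.visitedCount = numVisited;
    return result;
  }

  public static void main(String[] args) {
    Queue<Integer> work = new ArrayDeque<>(java.util.List.of(1, 2, 3, 4, 5));
    Result r = traverse(work, 3);
    System.out.println(r.visitedCount + " visited, incomplete=" + r.incomplete);
    // prints: 3 visited, incomplete=true
  }
}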
@@ -144,7 +144,15 @@ public final class Lucene90HnswGraphBuilder {
     // We pass 'null' for acceptOrds because there are no deletions while building the graph
     NeighborQueue candidates =
         Lucene90OnHeapHnswGraph.search(
-            value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
+            value,
+            beamWidth,
+            beamWidth,
+            vectorValues,
+            similarityFunction,
+            hnsw,
+            null,
+            Integer.MAX_VALUE,
+            random);
     int node = hnsw.addNode();
@@ -252,6 +252,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
             fieldEntry.similarityFunction,
             getGraphValues(fieldEntry),
             getAcceptOrds(acceptDocs, fieldEntry),
+            visitedLimit,
             random);
     int i = 0;
     ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
@@ -261,11 +262,11 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
       results.pop();
       scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
     }
-    // always return >= the case where we can assert == is only when there are fewer than topK
-    // vectors in the index
-    return new TopDocs(
-        new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
-        scoreDocs);
+    TotalHits.Relation relation =
+        results.incomplete()
+            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+            : TotalHits.Relation.EQUAL_TO;
+    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
   }

   private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
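Since the reader above now derives the relation from `results.incomplete()`, a caller can tell whether a search exhausted its candidates or hit the visited budget. A hedged usage sketch against the public TopDocs/TotalHits API (the helper name is made up):

import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

class RelationCheckSketch {
  // Returns true if the kNN search stopped early at its visited limit.
  static boolean hitVisitedLimit(TopDocs topDocs) {
    // EQUAL_TO: every needed candidate was visited, so the count is exact.
    // GREATER_THAN_OR_EQUAL_TO: the search was truncated at the budget.
    return topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
  }
}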
@@ -80,6 +80,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
       VectorSimilarityFunction similarityFunction,
       HnswGraph graphValues,
       Bits acceptOrds,
+      int visitedLimit,
       SplittableRandom random)
       throws IOException {
     int size = graphValues.size();
@@ -89,6 +90,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
     // MAX heap, from which to pull the candidate nodes
     NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed);

+    int numVisited = 0;
     // set of ordinals that have been visited by search on this layer, used to avoid backtracking
     SparseFixedBitSet visited = new SparseFixedBitSet(size);
     // get initial candidates at random
@@ -96,12 +98,17 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
     for (int i = 0; i < boundedNumSeed; i++) {
       int entryPoint = random.nextInt(size);
       if (visited.getAndSet(entryPoint) == false) {
+        if (numVisited >= visitedLimit) {
+          results.markIncomplete();
+          break;
+        }
         // explore the topK starting points of some random numSeed probes
         float score = similarityFunction.compare(query, vectors.vectorValue(entryPoint));
         candidates.add(entryPoint, score);
         if (acceptOrds == null || acceptOrds.get(entryPoint)) {
           results.add(entryPoint, score);
         }
+        numVisited++;
       }
     }

@@ -110,7 +117,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
     // to exceed this bound
     BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
     bound.set(results.topScore());
-    while (candidates.size() > 0) {
+    while (candidates.size() > 0 && results.incomplete() == false) {
       // get the best candidate (closest or best scoring)
       float topCandidateScore = candidates.topScore();
       if (results.size() >= topK) {
@@ -127,6 +134,11 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
           continue;
         }

+        if (numVisited >= visitedLimit) {
+          results.markIncomplete();
+          break;
+        }
+
         float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
         if (results.size() < numSeed || bound.check(score) == false) {
           candidates.add(friendOrd, score);
@@ -135,12 +147,13 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
             bound.set(results.topScore());
           }
         }
+        numVisited++;
       }
     }
     while (results.size() > topK) {
       results.pop();
     }
-    results.setVisitedCount(visited.approximateCardinality());
+    results.setVisitedCount(numVisited);
     return results;
   }
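The last hunk above also swaps `visited.approximateCardinality()` for the exact `numVisited` counter. A small sketch of the distinction, assuming the bit set is kept only for deduplication: an approximate cardinality may drift from the true visit count, while an explicit counter supports the strict `<= visitedLimit` assertion that `TestHnswGraph` adopts below.

import org.apache.lucene.util.SparseFixedBitSet;

class ExactCountSketch {
  public static void main(String[] args) {
    SparseFixedBitSet visited = new SparseFixedBitSet(1 << 20);
    int numVisited = 0;
    for (int ord = 0; ord < 1000; ord += 7) {
      if (visited.getAndSet(ord) == false) {
        numVisited++; // exact, unlike visited.approximateCardinality()
      }
    }
    System.out.println(numVisited + " exact vs ~" + visited.approximateCardinality());
  }
}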
@@ -154,20 +154,31 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
     FieldInfo info = readState.fieldInfos.fieldInfo(field);
     VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
     HitQueue topK = new HitQueue(k, false);

+    int numVisited = 0;
+    TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
+
     int doc;
     while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
       if (acceptDocs != null && acceptDocs.get(doc) == false) {
         continue;
       }

+      if (numVisited >= visitedLimit) {
+        relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
+        break;
+      }
+
       float[] vector = values.vectorValue();
       float score = vectorSimilarity.convertToScore(vectorSimilarity.compare(vector, target));
       topK.insertWithOverflow(new ScoreDoc(doc, score));
+      numVisited++;
     }
     ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
     for (int i = topScoreDocs.length - 1; i >= 0; i--) {
       topScoreDocs[i] = topK.pop();
     }
-    return new TopDocs(new TotalHits(values.size(), TotalHits.Relation.EQUAL_TO), topScoreDocs);
+    return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
   }

   @Override
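SimpleText takes the brute-force path: a linear scan over live docs, a bounded heap for the top-k, and now an early exit at the visited budget. A self-contained sketch of that shape, with a precomputed score array standing in for scoring each stored vector (names are hypothetical):

import java.util.PriorityQueue;

class BruteForceKnnSketch {
  static void search(float[] scores, int k, int visitedLimit) {
    // min-heap ordered by score, so the worst of the current top-k is on top
    PriorityQueue<float[]> topK = new PriorityQueue<>((a, b) -> Float.compare(a[1], b[1]));
    int numVisited = 0;
    boolean incomplete = false;
    for (int doc = 0; doc < scores.length; doc++) {
      if (numVisited >= visitedLimit) {
        incomplete = true; // reported upstream as GREATER_THAN_OR_EQUAL_TO
        break;
      }
      topK.offer(new float[] {doc, scores[doc]});
      if (topK.size() > k) {
        topK.poll(); // drop the current worst hit, like HitQueue.insertWithOverflow
      }
      numVisited++;
    }
    System.out.println("visited=" + numVisited + ", incomplete=" + incomplete);
  }

  public static void main(String[] args) {
    search(new float[] {0.1f, 0.9f, 0.5f, 0.7f, 0.3f}, 2, 4);
    // prints: visited=4, incomplete=true
  }
}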
@@ -140,12 +140,16 @@ public final class HnswGraphSearcher {
     int numVisited = 0;
     for (int ep : eps) {
       if (visited.getAndSet(ep) == false) {
+        if (numVisited >= visitedLimit) {
+          results.markIncomplete();
+          break;
+        }
         float score = similarityFunction.compare(query, vectors.vectorValue(ep));
+        numVisited++;
         candidates.add(ep, score);
         if (acceptOrds == null || acceptOrds.get(ep)) {
           results.add(ep, score);
         }
-        numVisited++;
       }
     }

@@ -155,18 +159,13 @@ public final class HnswGraphSearcher {
     if (results.size() >= topK) {
       bound.set(results.topScore());
     }
-    while (candidates.size() > 0) {
+    while (candidates.size() > 0 && results.incomplete() == false) {
       // get the best candidate (closest or best scoring)
       float topCandidateScore = candidates.topScore();
       if (bound.check(topCandidateScore)) {
         break;
       }

-      if (numVisited >= visitedLimit) {
-        results.markIncomplete();
-        break;
-      }
-
       int topCandidateNode = candidates.pop();
       graph.seek(level, topCandidateNode);
       int friendOrd;
@@ -176,6 +175,10 @@ public final class HnswGraphSearcher {
           continue;
         }

+        if (numVisited >= visitedLimit) {
+          results.markIncomplete();
+          break;
+        }
         float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
         numVisited++;
         if (bound.check(score) == false) {
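The `HnswGraphSearcher` change above is what the commit message calls clarifying the limit-checking logic: the check moves from the outer candidate loop into the inner neighbor-expansion loop. A sketch contrasting the two placements over a hypothetical adjacency-array graph (not the Lucene API) shows why only the inner check cannot overshoot:

class CheckPlacementSketch {
  // Before: one check per candidate may visit a whole neighbor list past the budget.
  static int visitOuterCheck(int[][] neighborLists, int visitedLimit) {
    int numVisited = 0;
    for (int[] neighbors : neighborLists) {
      if (numVisited >= visitedLimit) break;
      for (int friend : neighbors) {
        numVisited++; // can overshoot by up to neighbors.length - 1
      }
    }
    return numVisited;
  }

  // After: one check per node keeps the count at or under the limit.
  static int visitInnerCheck(int[][] neighborLists, int visitedLimit) {
    int numVisited = 0;
    outer:
    for (int[] neighbors : neighborLists) {
      for (int friend : neighbors) {
        if (numVisited >= visitedLimit) break outer;
        numVisited++;
      }
    }
    return numVisited;
  }

  public static void main(String[] args) {
    int[][] graph = {{1, 2, 3, 4}, {5, 6, 7, 8}};
    System.out.println(visitOuterCheck(graph, 5)); // 8: overshoots the limit
    System.out.println(visitInnerCheck(graph, 5)); // 5: stays within the limit
  }
}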
@@ -44,6 +44,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.tests.util.TestUtil;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.VectorUtil;

@@ -491,7 +492,11 @@ public class TestKnnVectorQuery extends LuceneTestCase {
     int dimension = atLeast(5);
     int numIters = atLeast(10);
     try (Directory d = newDirectory()) {
-      RandomIndexWriter w = new RandomIndexWriter(random(), d);
+      // Always use the default kNN format to have predictable behavior around when it hits
+      // visitedLimit. This is fine since the test targets KnnVectorQuery logic, not the kNN format
+      // implementation.
+      IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
+      RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
       for (int i = 0; i < numDocs; i++) {
         Document doc = new Document();
         doc.add(new KnnVectorField("field", randomVector(dimension)));
@@ -690,8 +695,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
     }

     @Override
-    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
-        throws IOException {
+    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
       throw new UnsupportedOperationException("exact search is not supported");
     }
   }
@@ -309,8 +309,8 @@ public class TestHnswGraph extends LuceneTestCase {
         createRandomAcceptOrds(0, vectors.size),
         visitedLimit);
     assertTrue(nn.incomplete());
-    // The visited count shouldn't be much over the limit
-    assertTrue(nn.visitedCount() < visitedLimit + 3);
+    // The visited count shouldn't exceed the limit
+    assertTrue(nn.visitedCount() <= visitedLimit);
   }

   public void testBoundsCheckerMax() {
@@ -24,6 +24,7 @@ import java.util.Arrays;
 import java.util.Set;
 import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.FieldType;
@@ -45,6 +46,7 @@ import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
 import org.apache.lucene.tests.util.TestUtil;
@@ -821,6 +823,68 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
     }
   }

+  /**
+   * Tests whether {@link KnnVectorsReader#search} implementations obey the limit on the number of
+   * visited vectors. This test is a best-effort attempt to capture the right behavior, and isn't
+   * meant to define a strict requirement on behavior.
+   */
+  public void testSearchWithVisitedLimit() throws Exception {
+    IndexWriterConfig iwc = newIndexWriterConfig();
+    String fieldName = "field";
+    try (Directory dir = newDirectory();
+        IndexWriter iw = new IndexWriter(dir, iwc)) {
+      int numDoc = atLeast(300);
+      int dimension = atLeast(10);
+      for (int i = 0; i < numDoc; i++) {
+        float[] value;
+        if (random().nextInt(7) != 3) {
+          // usually index a vector value for a doc
+          value = randomVector(dimension);
+        } else {
+          value = null;
+        }
+        add(iw, fieldName, i, value, VectorSimilarityFunction.EUCLIDEAN);
+      }
+      iw.forceMerge(1);
+
+      // randomly delete some documents
+      for (int i = 0; i < 30; i++) {
+        int idToDelete = random().nextInt(numDoc);
+        iw.deleteDocuments(new Term("id", Integer.toString(idToDelete)));
+      }
+
+      try (IndexReader reader = DirectoryReader.open(iw)) {
+        for (LeafReaderContext ctx : reader.leaves()) {
+          Bits liveDocs = ctx.reader().getLiveDocs();
+          VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
+          if (vectorValues == null) {
+            continue;
+          }
+
+          // check the limit is hit when it's very small
+          int k = 5 + random().nextInt(45);
+          int visitedLimit = k + random().nextInt(5);
+          TopDocs results =
+              ctx.reader()
+                  .searchNearestVectors(
+                      fieldName, randomVector(dimension), k, liveDocs, visitedLimit);
+          assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation);
+          assertEquals(visitedLimit, results.totalHits.value);
+
+          // check the limit is not hit when it clearly exceeds the number of vectors
+          k = vectorValues.size();
+          visitedLimit = k + 30;
+          results =
+              ctx.reader()
+                  .searchNearestVectors(
+                      fieldName, randomVector(dimension), k, liveDocs, visitedLimit);
+          assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation);
+          assertTrue(results.totalHits.value <= visitedLimit);
+        }
+      }
+    }
+  }
+
   /**
    * Index random vectors, sometimes skipping documents, sometimes updating a document, sometimes
    * merging, sometimes sorting the index, using an HNSW similarity function so as to also produce a