Utilize exact kNN search when gathering k > numVectors in a segment (#12806)

When requesting for k >= numVectors, it doesn't make sense to go through the HNSW graph. Even without a user supplied filter, we should not explore the HNSW graph if it contains fewer than k vectors.

One scenario where we may still explore the graph if k >= numVectors is when not every document has a vector and there are deleted docs. But, this commit significantly improves things regardless.
This commit is contained in:
Benjamin Trent 2023-11-15 12:56:15 -05:00 committed by GitHub
parent 5afc17d4b5
commit 05a336ea69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 13 deletions

View File

@ -285,6 +285,8 @@ Optimizations
* GITHUB#12784: Cache buckets to speed up BytesRefHash#sort. (Guo Feng)
* GITHUB#12806: Utilize exact kNN search when gathering k >= numVectors in a segment (Ben Trent)
Changes in runtime behavior
---------------------

View File

@ -227,12 +227,22 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return;
}
RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc),
getGraph(fieldEntry),
scorer.getAcceptOrds(acceptDocs));
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
if (knnCollector.k() < scorer.maxOrd()) {
HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds);
} else {
// if k is larger than the number of vectors, we can just iterate over all vectors
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}
}
}
}
@Override
@ -245,12 +255,22 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return;
}
RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc),
getGraph(fieldEntry),
scorer.getAcceptOrds(acceptDocs));
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
if (knnCollector.k() < scorer.maxOrd()) {
HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds);
} else {
// if k is larger than the number of vectors, we can just iterate over all vectors
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}
}
}
}
@Override

View File

@ -17,6 +17,8 @@
package org.apache.lucene.search;
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomBoolean;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@ -222,7 +224,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 1);
IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
@ -779,6 +781,16 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
doc.add(getKnnVectorField(field, contents[i], vectorSimilarityFunction));
doc.add(new StringField("id", "id" + i, Field.Store.YES));
writer.addDocument(doc);
if (randomBoolean()) {
// Add some documents without a vector
for (int j = 0; j < randomIntBetween(1, 5); j++) {
doc = new Document();
doc.add(new StringField("other", "value", Field.Store.NO));
// Add fields that will be matched by our test filters but won't have vectors
doc.add(new StringField("id", "id" + j, Field.Store.YES));
writer.addDocument(doc);
}
}
}
// Add some documents without a vector
for (int i = 0; i < 5; i++) {