LUCENE-10606: For KnnVectorQuery, optimize case where filter is backed by BitSetIterator (#951)

Instead of collecting hit-by-hit using a `LeafCollector`, we break down the
search by instantiating a weight, creating scorers, and checking the underlying
iterator. If it is backed by a `BitSet`, we directly update the reference (as
we won't be editing the `Bits`). Else we can create a new `BitSet` from the
iterator using `BitSet.of`.
This commit is contained in:
Kaival Parikh 2022-06-27 12:22:52 +05:30 committed by GitHub
parent 9338909373
commit 03846b468e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 133 additions and 72 deletions

View File

@ -105,6 +105,8 @@ Optimizations
* LUCENE-10618: Implement BooleanQuery rewrite rules based for minimumShouldMatch. (Fang Hou)
* LUCENE-10606: For KnnVectorQuery, optimize case where filter is backed by BitSetIterator (Kaival Parikh)
Bug Fixes
---------------------

View File

@ -32,7 +32,6 @@ import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/**
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
@ -92,20 +91,20 @@ public class KnnVectorQuery extends Query {
public Query rewrite(IndexReader reader) throws IOException {
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
BitSetCollector filterCollector = null;
Weight filterWeight = null;
if (filter != null) {
filterCollector = new BitSetCollector(reader.leaves().size());
IndexSearcher indexSearcher = new IndexSearcher(reader);
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
.build();
indexSearcher.search(booleanQuery, filterCollector);
Query rewritten = indexSearcher.rewrite(booleanQuery);
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
}
for (LeafReaderContext ctx : reader.leaves()) {
TopDocs results = searchLeaf(ctx, filterCollector);
TopDocs results = searchLeaf(ctx, filterWeight);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
@ -121,35 +120,53 @@ public class KnnVectorQuery extends Query {
return createRewrittenQuery(reader, topK);
}
private TopDocs searchLeaf(LeafReaderContext ctx, BitSetCollector filterCollector)
throws IOException {
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();
if (filterCollector == null) {
Bits acceptDocs = ctx.reader().getLiveDocs();
return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE);
if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
}
Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
return NO_RESULTS;
}
BitSet bitSet = createBitSet(scorer.iterator(), liveDocs, maxDoc);
BitSetIterator filterIterator = new BitSetIterator(bitSet, bitSet.cardinality());
if (filterIterator.cost() <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, filterIterator);
}
// Perform the approximate kNN search
TopDocs results = approximateSearch(ctx, bitSet, (int) filterIterator.cost());
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results;
} else {
BitSetIterator filterIterator = filterCollector.getIterator(ctx.ord);
if (filterIterator == null || filterIterator.cost() == 0) {
return NO_RESULTS;
}
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, filterIterator);
}
}
if (filterIterator.cost() <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, filterIterator);
}
// Perform the approximate kNN search
Bits acceptDocs =
filterIterator.getBitSet(); // The filter iterator already incorporates live docs
int visitedLimit = (int) filterIterator.cost();
TopDocs results = approximateSearch(ctx, acceptDocs, visitedLimit);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, filterIterator);
}
private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
throws IOException {
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return bitSetIterator.getBitSet();
} else {
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator =
new FilteredDocIdSetIterator(iterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
return BitSet.of(filterIterator, maxDoc);
}
}
@ -202,47 +219,6 @@ public class KnnVectorQuery extends Query {
return new TopDocs(totalHits, topScoreDocs);
}
private static class BitSetCollector extends SimpleCollector {
private final BitSet[] bitSets;
private final int[] cost;
private int ord;
private BitSetCollector(int numLeaves) {
this.bitSets = new BitSet[numLeaves];
this.cost = new int[bitSets.length];
}
/**
* Return an iterator whose {@link BitSet} contains the matching documents, and whose {@link
* BitSetIterator#cost()} is the exact cardinality. If the leaf was never visited, then return
* null.
*/
public BitSetIterator getIterator(int contextOrd) {
if (bitSets[contextOrd] == null) {
return null;
}
return new BitSetIterator(bitSets[contextOrd], cost[contextOrd]);
}
@Override
public void collect(int doc) throws IOException {
bitSets[ord].set(doc);
cost[ord]++;
}
@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
bitSets[context.ord] = new FixedBitSet(context.reader().maxDoc());
ord = context.ord;
}
@Override
public org.apache.lucene.search.ScoreMode scoreMode() {
return org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
}
}
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));

View File

@ -45,7 +45,10 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil;
/** TestKnnVectorQuery tests KnnVectorQuery. */
@ -699,6 +702,36 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
/**
* Test that KnnVectorQuery optimizes the case where the filter query is backed by {@link
* BitSetIterator}.
*/
public void testBitSetQuery() throws IOException {
IndexWriterConfig iwc = newIndexWriterConfig();
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, iwc)) {
final int numDocs = 100;
final int dim = 30;
for (int i = 0; i < numDocs; ++i) {
Document d = new Document();
d.add(new KnnVectorField("vector", randomVector(dim)));
w.addDocument(d);
}
w.commit();
try (DirectoryReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
expectThrows(
UnsupportedOperationException.class,
() ->
searcher.search(
new KnnVectorQuery("vector", randomVector(dim), 10, filter), numDocs));
}
}
}
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
private Directory getIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
@ -797,4 +830,54 @@ public class TestKnnVectorQuery extends LuceneTestCase {
return in.getCoreCacheHelper();
}
}
private static class ThrowingBitSetQuery extends Query {
private final FixedBitSet docs;
ThrowingBitSetQuery(FixedBitSet docs) {
this.docs = docs;
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
return new ConstantScoreWeight(this, boost) {
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
BitSetIterator bitSetIterator =
new BitSetIterator(docs, docs.approximateCardinality()) {
@Override
public BitSet getBitSet() {
throw new UnsupportedOperationException("reusing BitSet is not supported");
}
};
return new ConstantScoreScorer(this, score(), scoreMode, bitSetIterator);
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}
};
}
@Override
public void visit(QueryVisitor visitor) {}
@Override
public String toString(String field) {
return "throwingBitSetQuery";
}
@Override
public boolean equals(Object other) {
return sameClassAs(other) && docs.equals(((ThrowingBitSetQuery) other).docs);
}
@Override
public int hashCode() {
return 31 * classHash() + docs.hashCode();
}
}
}