diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index ac16876c098..79c7266bb37 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -105,6 +105,8 @@ Optimizations * LUCENE-10618: Implement BooleanQuery rewrite rules based for minimumShouldMatch. (Fang Hou) +* LUCENE-10606: For KnnVectorQuery, optimize case where filter is backed by BitSetIterator (Kaival Parikh) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java index f8743d8e75b..ffd73d0ab68 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java @@ -32,7 +32,6 @@ import org.apache.lucene.index.VectorValues; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.FixedBitSet; /** * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search. @@ -92,20 +91,20 @@ public class KnnVectorQuery extends Query { public Query rewrite(IndexReader reader) throws IOException { TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()]; - BitSetCollector filterCollector = null; + Weight filterWeight = null; if (filter != null) { - filterCollector = new BitSetCollector(reader.leaves().size()); IndexSearcher indexSearcher = new IndexSearcher(reader); BooleanQuery booleanQuery = new BooleanQuery.Builder() .add(filter, BooleanClause.Occur.FILTER) .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER) .build(); - indexSearcher.search(booleanQuery, filterCollector); + Query rewritten = indexSearcher.rewrite(booleanQuery); + filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f); } for (LeafReaderContext ctx : reader.leaves()) { - TopDocs results = searchLeaf(ctx, filterCollector); + TopDocs results = searchLeaf(ctx, filterWeight); if (ctx.docBase > 0) { for (ScoreDoc scoreDoc : results.scoreDocs) { scoreDoc.doc += ctx.docBase; @@ -121,35 +120,53 @@ public class KnnVectorQuery extends Query { return createRewrittenQuery(reader, topK); } - private TopDocs searchLeaf(LeafReaderContext ctx, BitSetCollector filterCollector) - throws IOException { + private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException { + Bits liveDocs = ctx.reader().getLiveDocs(); + int maxDoc = ctx.reader().maxDoc(); - if (filterCollector == null) { - Bits acceptDocs = ctx.reader().getLiveDocs(); - return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE); + if (filterWeight == null) { + return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE); + } + + Scorer scorer = filterWeight.scorer(ctx); + if (scorer == null) { + return NO_RESULTS; + } + + BitSet bitSet = createBitSet(scorer.iterator(), liveDocs, maxDoc); + BitSetIterator filterIterator = new BitSetIterator(bitSet, bitSet.cardinality()); + + if (filterIterator.cost() <= k) { + // If there are <= k possible matches, short-circuit and perform exact search, since HNSW + // must always visit at least k documents + return exactSearch(ctx, filterIterator); + } + + // Perform the approximate kNN search + TopDocs results = approximateSearch(ctx, bitSet, (int) filterIterator.cost()); + if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) { + return results; } else { - BitSetIterator filterIterator = filterCollector.getIterator(ctx.ord); - if (filterIterator == null || filterIterator.cost() == 0) { - return NO_RESULTS; - } + // We stopped the kNN search because it visited too many nodes, so fall back to exact search + return exactSearch(ctx, filterIterator); + } + } - if (filterIterator.cost() <= k) { - // If there are <= k possible matches, short-circuit and perform exact search, since HNSW - // must always visit at least k documents - return exactSearch(ctx, filterIterator); - } - - // Perform the approximate kNN search - Bits acceptDocs = - filterIterator.getBitSet(); // The filter iterator already incorporates live docs - int visitedLimit = (int) filterIterator.cost(); - TopDocs results = approximateSearch(ctx, acceptDocs, visitedLimit); - if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) { - return results; - } else { - // We stopped the kNN search because it visited too many nodes, so fall back to exact search - return exactSearch(ctx, filterIterator); - } + private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) + throws IOException { + if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) { + // If we already have a BitSet and no deletions, reuse the BitSet + return bitSetIterator.getBitSet(); + } else { + // Create a new BitSet from matching and live docs + FilteredDocIdSetIterator filterIterator = + new FilteredDocIdSetIterator(iterator) { + @Override + protected boolean match(int doc) { + return liveDocs == null || liveDocs.get(doc); + } + }; + return BitSet.of(filterIterator, maxDoc); } } @@ -202,47 +219,6 @@ public class KnnVectorQuery extends Query { return new TopDocs(totalHits, topScoreDocs); } - private static class BitSetCollector extends SimpleCollector { - - private final BitSet[] bitSets; - private final int[] cost; - private int ord; - - private BitSetCollector(int numLeaves) { - this.bitSets = new BitSet[numLeaves]; - this.cost = new int[bitSets.length]; - } - - /** - * Return an iterator whose {@link BitSet} contains the matching documents, and whose {@link - * BitSetIterator#cost()} is the exact cardinality. If the leaf was never visited, then return - * null. - */ - public BitSetIterator getIterator(int contextOrd) { - if (bitSets[contextOrd] == null) { - return null; - } - return new BitSetIterator(bitSets[contextOrd], cost[contextOrd]); - } - - @Override - public void collect(int doc) throws IOException { - bitSets[ord].set(doc); - cost[ord]++; - } - - @Override - protected void doSetNextReader(LeafReaderContext context) throws IOException { - bitSets[context.ord] = new FixedBitSet(context.reader().maxDoc()); - ord = context.ord; - } - - @Override - public org.apache.lucene.search.ScoreMode scoreMode() { - return org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES; - } - } - private Query createRewrittenQuery(IndexReader reader, TopDocs topK) { int len = topK.scoreDocs.length; Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java index ba9e6b5b5a7..74ecf23c292 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java @@ -45,7 +45,10 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.VectorUtil; /** TestKnnVectorQuery tests KnnVectorQuery. */ @@ -699,6 +702,36 @@ public class TestKnnVectorQuery extends LuceneTestCase { } } + /** + * Test that KnnVectorQuery optimizes the case where the filter query is backed by {@link + * BitSetIterator}. + */ + public void testBitSetQuery() throws IOException { + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, iwc)) { + final int numDocs = 100; + final int dim = 30; + for (int i = 0; i < numDocs; ++i) { + Document d = new Document(); + d.add(new KnnVectorField("vector", randomVector(dim))); + w.addDocument(d); + } + w.commit(); + + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + + Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs)); + expectThrows( + UnsupportedOperationException.class, + () -> + searcher.search( + new KnnVectorQuery("vector", randomVector(dim), 10, filter), numDocs)); + } + } + } + /** Creates a new directory and adds documents with the given vectors as kNN vector fields */ private Directory getIndexStore(String field, float[]... contents) throws IOException { Directory indexStore = newDirectory(); @@ -797,4 +830,54 @@ public class TestKnnVectorQuery extends LuceneTestCase { return in.getCoreCacheHelper(); } } + + private static class ThrowingBitSetQuery extends Query { + + private final FixedBitSet docs; + + ThrowingBitSetQuery(FixedBitSet docs) { + this.docs = docs; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + return new ConstantScoreWeight(this, boost) { + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + BitSetIterator bitSetIterator = + new BitSetIterator(docs, docs.approximateCardinality()) { + @Override + public BitSet getBitSet() { + throw new UnsupportedOperationException("reusing BitSet is not supported"); + } + }; + return new ConstantScoreScorer(this, score(), scoreMode, bitSetIterator); + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + }; + } + + @Override + public void visit(QueryVisitor visitor) {} + + @Override + public String toString(String field) { + return "throwingBitSetQuery"; + } + + @Override + public boolean equals(Object other) { + return sameClassAs(other) && docs.equals(((ThrowingBitSetQuery) other).docs); + } + + @Override + public int hashCode() { + return 31 * classHash() + docs.hashCode(); + } + } }