mirror of https://github.com/apache/lucene.git
LUCENE-10606: For KnnVectorQuery, optimize case where filter is backed by BitSetIterator (#951)
Instead of collecting hit-by-hit using a `LeafCollector`, we break down the search by instantiating a weight, creating scorers, and checking the underlying iterator. If it is backed by a `BitSet`, we directly update the reference (as we won't be editing the `Bits`). Else we can create a new `BitSet` from the iterator using `BitSet.of`.
This commit is contained in:
parent
9338909373
commit
03846b468e
|
@ -105,6 +105,8 @@ Optimizations
|
|||
|
||||
* LUCENE-10618: Implement BooleanQuery rewrite rules based for minimumShouldMatch. (Fang Hou)
|
||||
|
||||
* LUCENE-10606: For KnnVectorQuery, optimize case where filter is backed by BitSetIterator (Kaival Parikh)
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -32,7 +32,6 @@ import org.apache.lucene.index.VectorValues;
|
|||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
|
||||
/**
|
||||
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
|
||||
|
@ -92,20 +91,20 @@ public class KnnVectorQuery extends Query {
|
|||
public Query rewrite(IndexReader reader) throws IOException {
|
||||
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
|
||||
|
||||
BitSetCollector filterCollector = null;
|
||||
Weight filterWeight = null;
|
||||
if (filter != null) {
|
||||
filterCollector = new BitSetCollector(reader.leaves().size());
|
||||
IndexSearcher indexSearcher = new IndexSearcher(reader);
|
||||
BooleanQuery booleanQuery =
|
||||
new BooleanQuery.Builder()
|
||||
.add(filter, BooleanClause.Occur.FILTER)
|
||||
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
|
||||
.build();
|
||||
indexSearcher.search(booleanQuery, filterCollector);
|
||||
Query rewritten = indexSearcher.rewrite(booleanQuery);
|
||||
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
|
||||
}
|
||||
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
TopDocs results = searchLeaf(ctx, filterCollector);
|
||||
TopDocs results = searchLeaf(ctx, filterWeight);
|
||||
if (ctx.docBase > 0) {
|
||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||
scoreDoc.doc += ctx.docBase;
|
||||
|
@ -121,18 +120,22 @@ public class KnnVectorQuery extends Query {
|
|||
return createRewrittenQuery(reader, topK);
|
||||
}
|
||||
|
||||
private TopDocs searchLeaf(LeafReaderContext ctx, BitSetCollector filterCollector)
|
||||
throws IOException {
|
||||
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
|
||||
Bits liveDocs = ctx.reader().getLiveDocs();
|
||||
int maxDoc = ctx.reader().maxDoc();
|
||||
|
||||
if (filterCollector == null) {
|
||||
Bits acceptDocs = ctx.reader().getLiveDocs();
|
||||
return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE);
|
||||
} else {
|
||||
BitSetIterator filterIterator = filterCollector.getIterator(ctx.ord);
|
||||
if (filterIterator == null || filterIterator.cost() == 0) {
|
||||
if (filterWeight == null) {
|
||||
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
|
||||
}
|
||||
|
||||
Scorer scorer = filterWeight.scorer(ctx);
|
||||
if (scorer == null) {
|
||||
return NO_RESULTS;
|
||||
}
|
||||
|
||||
BitSet bitSet = createBitSet(scorer.iterator(), liveDocs, maxDoc);
|
||||
BitSetIterator filterIterator = new BitSetIterator(bitSet, bitSet.cardinality());
|
||||
|
||||
if (filterIterator.cost() <= k) {
|
||||
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
|
||||
// must always visit at least k documents
|
||||
|
@ -140,10 +143,7 @@ public class KnnVectorQuery extends Query {
|
|||
}
|
||||
|
||||
// Perform the approximate kNN search
|
||||
Bits acceptDocs =
|
||||
filterIterator.getBitSet(); // The filter iterator already incorporates live docs
|
||||
int visitedLimit = (int) filterIterator.cost();
|
||||
TopDocs results = approximateSearch(ctx, acceptDocs, visitedLimit);
|
||||
TopDocs results = approximateSearch(ctx, bitSet, (int) filterIterator.cost());
|
||||
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
|
||||
return results;
|
||||
} else {
|
||||
|
@ -151,6 +151,23 @@ public class KnnVectorQuery extends Query {
|
|||
return exactSearch(ctx, filterIterator);
|
||||
}
|
||||
}
|
||||
|
||||
private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
|
||||
throws IOException {
|
||||
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
|
||||
// If we already have a BitSet and no deletions, reuse the BitSet
|
||||
return bitSetIterator.getBitSet();
|
||||
} else {
|
||||
// Create a new BitSet from matching and live docs
|
||||
FilteredDocIdSetIterator filterIterator =
|
||||
new FilteredDocIdSetIterator(iterator) {
|
||||
@Override
|
||||
protected boolean match(int doc) {
|
||||
return liveDocs == null || liveDocs.get(doc);
|
||||
}
|
||||
};
|
||||
return BitSet.of(filterIterator, maxDoc);
|
||||
}
|
||||
}
|
||||
|
||||
private TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
|
||||
|
@ -202,47 +219,6 @@ public class KnnVectorQuery extends Query {
|
|||
return new TopDocs(totalHits, topScoreDocs);
|
||||
}
|
||||
|
||||
private static class BitSetCollector extends SimpleCollector {
|
||||
|
||||
private final BitSet[] bitSets;
|
||||
private final int[] cost;
|
||||
private int ord;
|
||||
|
||||
private BitSetCollector(int numLeaves) {
|
||||
this.bitSets = new BitSet[numLeaves];
|
||||
this.cost = new int[bitSets.length];
|
||||
}
|
||||
|
||||
/**
|
||||
* Return an iterator whose {@link BitSet} contains the matching documents, and whose {@link
|
||||
* BitSetIterator#cost()} is the exact cardinality. If the leaf was never visited, then return
|
||||
* null.
|
||||
*/
|
||||
public BitSetIterator getIterator(int contextOrd) {
|
||||
if (bitSets[contextOrd] == null) {
|
||||
return null;
|
||||
}
|
||||
return new BitSetIterator(bitSets[contextOrd], cost[contextOrd]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void collect(int doc) throws IOException {
|
||||
bitSets[ord].set(doc);
|
||||
cost[ord]++;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doSetNextReader(LeafReaderContext context) throws IOException {
|
||||
bitSets[context.ord] = new FixedBitSet(context.reader().maxDoc());
|
||||
ord = context.ord;
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.apache.lucene.search.ScoreMode scoreMode() {
|
||||
return org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
|
||||
}
|
||||
}
|
||||
|
||||
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
|
||||
int len = topK.scoreDocs.length;
|
||||
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
|
||||
|
|
|
@ -45,7 +45,10 @@ import org.apache.lucene.store.Directory;
|
|||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
/** TestKnnVectorQuery tests KnnVectorQuery. */
|
||||
|
@ -699,6 +702,36 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Test that KnnVectorQuery optimizes the case where the filter query is backed by {@link
|
||||
* BitSetIterator}.
|
||||
*/
|
||||
public void testBitSetQuery() throws IOException {
|
||||
IndexWriterConfig iwc = newIndexWriterConfig();
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter w = new IndexWriter(dir, iwc)) {
|
||||
final int numDocs = 100;
|
||||
final int dim = 30;
|
||||
for (int i = 0; i < numDocs; ++i) {
|
||||
Document d = new Document();
|
||||
d.add(new KnnVectorField("vector", randomVector(dim)));
|
||||
w.addDocument(d);
|
||||
}
|
||||
w.commit();
|
||||
|
||||
try (DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
|
||||
Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
|
||||
expectThrows(
|
||||
UnsupportedOperationException.class,
|
||||
() ->
|
||||
searcher.search(
|
||||
new KnnVectorQuery("vector", randomVector(dim), 10, filter), numDocs));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
|
||||
private Directory getIndexStore(String field, float[]... contents) throws IOException {
|
||||
Directory indexStore = newDirectory();
|
||||
|
@ -797,4 +830,54 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||
return in.getCoreCacheHelper();
|
||||
}
|
||||
}
|
||||
|
||||
private static class ThrowingBitSetQuery extends Query {
|
||||
|
||||
private final FixedBitSet docs;
|
||||
|
||||
ThrowingBitSetQuery(FixedBitSet docs) {
|
||||
this.docs = docs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
|
||||
throws IOException {
|
||||
return new ConstantScoreWeight(this, boost) {
|
||||
@Override
|
||||
public Scorer scorer(LeafReaderContext context) throws IOException {
|
||||
BitSetIterator bitSetIterator =
|
||||
new BitSetIterator(docs, docs.approximateCardinality()) {
|
||||
@Override
|
||||
public BitSet getBitSet() {
|
||||
throw new UnsupportedOperationException("reusing BitSet is not supported");
|
||||
}
|
||||
};
|
||||
return new ConstantScoreScorer(this, score(), scoreMode, bitSetIterator);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCacheable(LeafReaderContext ctx) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(QueryVisitor visitor) {}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return "throwingBitSetQuery";
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
return sameClassAs(other) && docs.equals(((ThrowingBitSetQuery) other).docs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 * classHash() + docs.hashCode();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue