mirror of https://github.com/apache/lucene.git
LUCENE-10606: For KnnVectorQuery, optimize case where filter is backed by BitSetIterator (#951)
Instead of collecting hit-by-hit using a `LeafCollector`, we break down the search by instantiating a weight, creating a scorer per leaf, and checking the underlying iterator. If it is backed by a `BitSet` and the leaf has no deleted documents, we reuse that `BitSet` directly by reference (this is safe because we never modify the `Bits`). Otherwise, we create a new `BitSet` from the iterator (restricted to live documents) using `BitSet.of`.
This commit is contained in:
parent
9338909373
commit
03846b468e
|
@ -105,6 +105,8 @@ Optimizations
|
||||||
|
|
||||||
* LUCENE-10618: Implement BooleanQuery rewrite rules based on minimumShouldMatch. (Fang Hou)
|
* LUCENE-10618: Implement BooleanQuery rewrite rules based on minimumShouldMatch. (Fang Hou)
|
||||||
|
|
||||||
|
* LUCENE-10606: For KnnVectorQuery, optimize case where filter is backed by BitSetIterator (Kaival Parikh)
|
||||||
|
|
||||||
Bug Fixes
|
Bug Fixes
|
||||||
---------------------
|
---------------------
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,6 @@ import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.util.BitSet;
|
import org.apache.lucene.util.BitSet;
|
||||||
import org.apache.lucene.util.BitSetIterator;
|
import org.apache.lucene.util.BitSetIterator;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.FixedBitSet;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
|
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
|
||||||
|
@ -92,20 +91,20 @@ public class KnnVectorQuery extends Query {
|
||||||
public Query rewrite(IndexReader reader) throws IOException {
|
public Query rewrite(IndexReader reader) throws IOException {
|
||||||
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
|
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
|
||||||
|
|
||||||
BitSetCollector filterCollector = null;
|
Weight filterWeight = null;
|
||||||
if (filter != null) {
|
if (filter != null) {
|
||||||
filterCollector = new BitSetCollector(reader.leaves().size());
|
|
||||||
IndexSearcher indexSearcher = new IndexSearcher(reader);
|
IndexSearcher indexSearcher = new IndexSearcher(reader);
|
||||||
BooleanQuery booleanQuery =
|
BooleanQuery booleanQuery =
|
||||||
new BooleanQuery.Builder()
|
new BooleanQuery.Builder()
|
||||||
.add(filter, BooleanClause.Occur.FILTER)
|
.add(filter, BooleanClause.Occur.FILTER)
|
||||||
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
|
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
|
||||||
.build();
|
.build();
|
||||||
indexSearcher.search(booleanQuery, filterCollector);
|
Query rewritten = indexSearcher.rewrite(booleanQuery);
|
||||||
|
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (LeafReaderContext ctx : reader.leaves()) {
|
for (LeafReaderContext ctx : reader.leaves()) {
|
||||||
TopDocs results = searchLeaf(ctx, filterCollector);
|
TopDocs results = searchLeaf(ctx, filterWeight);
|
||||||
if (ctx.docBase > 0) {
|
if (ctx.docBase > 0) {
|
||||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||||
scoreDoc.doc += ctx.docBase;
|
scoreDoc.doc += ctx.docBase;
|
||||||
|
@ -121,35 +120,53 @@ public class KnnVectorQuery extends Query {
|
||||||
return createRewrittenQuery(reader, topK);
|
return createRewrittenQuery(reader, topK);
|
||||||
}
|
}
|
||||||
|
|
||||||
private TopDocs searchLeaf(LeafReaderContext ctx, BitSetCollector filterCollector)
|
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
|
||||||
throws IOException {
|
Bits liveDocs = ctx.reader().getLiveDocs();
|
||||||
|
int maxDoc = ctx.reader().maxDoc();
|
||||||
|
|
||||||
if (filterCollector == null) {
|
if (filterWeight == null) {
|
||||||
Bits acceptDocs = ctx.reader().getLiveDocs();
|
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
|
||||||
return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE);
|
}
|
||||||
|
|
||||||
|
Scorer scorer = filterWeight.scorer(ctx);
|
||||||
|
if (scorer == null) {
|
||||||
|
return NO_RESULTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
BitSet bitSet = createBitSet(scorer.iterator(), liveDocs, maxDoc);
|
||||||
|
BitSetIterator filterIterator = new BitSetIterator(bitSet, bitSet.cardinality());
|
||||||
|
|
||||||
|
if (filterIterator.cost() <= k) {
|
||||||
|
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
|
||||||
|
// must always visit at least k documents
|
||||||
|
return exactSearch(ctx, filterIterator);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform the approximate kNN search
|
||||||
|
TopDocs results = approximateSearch(ctx, bitSet, (int) filterIterator.cost());
|
||||||
|
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
|
||||||
|
return results;
|
||||||
} else {
|
} else {
|
||||||
BitSetIterator filterIterator = filterCollector.getIterator(ctx.ord);
|
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
|
||||||
if (filterIterator == null || filterIterator.cost() == 0) {
|
return exactSearch(ctx, filterIterator);
|
||||||
return NO_RESULTS;
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (filterIterator.cost() <= k) {
|
private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
|
||||||
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
|
throws IOException {
|
||||||
// must always visit at least k documents
|
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
|
||||||
return exactSearch(ctx, filterIterator);
|
// If we already have a BitSet and no deletions, reuse the BitSet
|
||||||
}
|
return bitSetIterator.getBitSet();
|
||||||
|
} else {
|
||||||
// Perform the approximate kNN search
|
// Create a new BitSet from matching and live docs
|
||||||
Bits acceptDocs =
|
FilteredDocIdSetIterator filterIterator =
|
||||||
filterIterator.getBitSet(); // The filter iterator already incorporates live docs
|
new FilteredDocIdSetIterator(iterator) {
|
||||||
int visitedLimit = (int) filterIterator.cost();
|
@Override
|
||||||
TopDocs results = approximateSearch(ctx, acceptDocs, visitedLimit);
|
protected boolean match(int doc) {
|
||||||
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
|
return liveDocs == null || liveDocs.get(doc);
|
||||||
return results;
|
}
|
||||||
} else {
|
};
|
||||||
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
|
return BitSet.of(filterIterator, maxDoc);
|
||||||
return exactSearch(ctx, filterIterator);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,47 +219,6 @@ public class KnnVectorQuery extends Query {
|
||||||
return new TopDocs(totalHits, topScoreDocs);
|
return new TopDocs(totalHits, topScoreDocs);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class BitSetCollector extends SimpleCollector {
|
|
||||||
|
|
||||||
private final BitSet[] bitSets;
|
|
||||||
private final int[] cost;
|
|
||||||
private int ord;
|
|
||||||
|
|
||||||
private BitSetCollector(int numLeaves) {
|
|
||||||
this.bitSets = new BitSet[numLeaves];
|
|
||||||
this.cost = new int[bitSets.length];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return an iterator whose {@link BitSet} contains the matching documents, and whose {@link
|
|
||||||
* BitSetIterator#cost()} is the exact cardinality. If the leaf was never visited, then return
|
|
||||||
* null.
|
|
||||||
*/
|
|
||||||
public BitSetIterator getIterator(int contextOrd) {
|
|
||||||
if (bitSets[contextOrd] == null) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return new BitSetIterator(bitSets[contextOrd], cost[contextOrd]);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void collect(int doc) throws IOException {
|
|
||||||
bitSets[ord].set(doc);
|
|
||||||
cost[ord]++;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void doSetNextReader(LeafReaderContext context) throws IOException {
|
|
||||||
bitSets[context.ord] = new FixedBitSet(context.reader().maxDoc());
|
|
||||||
ord = context.ord;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public org.apache.lucene.search.ScoreMode scoreMode() {
|
|
||||||
return org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
|
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
|
||||||
int len = topK.scoreDocs.length;
|
int len = topK.scoreDocs.length;
|
||||||
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
|
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
|
||||||
|
|
|
@ -45,7 +45,10 @@ import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
import org.apache.lucene.tests.util.TestUtil;
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
|
import org.apache.lucene.util.BitSet;
|
||||||
|
import org.apache.lucene.util.BitSetIterator;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
|
import org.apache.lucene.util.FixedBitSet;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
|
||||||
/** TestKnnVectorQuery tests KnnVectorQuery. */
|
/** TestKnnVectorQuery tests KnnVectorQuery. */
|
||||||
|
@ -699,6 +702,36 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test that KnnVectorQuery optimizes the case where the filter query is backed by {@link
|
||||||
|
* BitSetIterator}.
|
||||||
|
*/
|
||||||
|
public void testBitSetQuery() throws IOException {
|
||||||
|
IndexWriterConfig iwc = newIndexWriterConfig();
|
||||||
|
try (Directory dir = newDirectory();
|
||||||
|
IndexWriter w = new IndexWriter(dir, iwc)) {
|
||||||
|
final int numDocs = 100;
|
||||||
|
final int dim = 30;
|
||||||
|
for (int i = 0; i < numDocs; ++i) {
|
||||||
|
Document d = new Document();
|
||||||
|
d.add(new KnnVectorField("vector", randomVector(dim)));
|
||||||
|
w.addDocument(d);
|
||||||
|
}
|
||||||
|
w.commit();
|
||||||
|
|
||||||
|
try (DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||||
|
IndexSearcher searcher = new IndexSearcher(reader);
|
||||||
|
|
||||||
|
Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
|
||||||
|
expectThrows(
|
||||||
|
UnsupportedOperationException.class,
|
||||||
|
() ->
|
||||||
|
searcher.search(
|
||||||
|
new KnnVectorQuery("vector", randomVector(dim), 10, filter), numDocs));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
|
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
|
||||||
private Directory getIndexStore(String field, float[]... contents) throws IOException {
|
private Directory getIndexStore(String field, float[]... contents) throws IOException {
|
||||||
Directory indexStore = newDirectory();
|
Directory indexStore = newDirectory();
|
||||||
|
@ -797,4 +830,54 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
||||||
return in.getCoreCacheHelper();
|
return in.getCoreCacheHelper();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class ThrowingBitSetQuery extends Query {
|
||||||
|
|
||||||
|
private final FixedBitSet docs;
|
||||||
|
|
||||||
|
ThrowingBitSetQuery(FixedBitSet docs) {
|
||||||
|
this.docs = docs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
|
||||||
|
throws IOException {
|
||||||
|
return new ConstantScoreWeight(this, boost) {
|
||||||
|
@Override
|
||||||
|
public Scorer scorer(LeafReaderContext context) throws IOException {
|
||||||
|
BitSetIterator bitSetIterator =
|
||||||
|
new BitSetIterator(docs, docs.approximateCardinality()) {
|
||||||
|
@Override
|
||||||
|
public BitSet getBitSet() {
|
||||||
|
throw new UnsupportedOperationException("reusing BitSet is not supported");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return new ConstantScoreScorer(this, score(), scoreMode, bitSetIterator);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isCacheable(LeafReaderContext ctx) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void visit(QueryVisitor visitor) {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(String field) {
|
||||||
|
return "throwingBitSetQuery";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object other) {
|
||||||
|
return sameClassAs(other) && docs.equals(((ThrowingBitSetQuery) other).docs);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return 31 * classHash() + docs.hashCode();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue