Add multi-leaf optimizations for diversifying children collector (#13121)

This adds multi-leaf optimizations for the diversifying children collector. This means that, as child vectors are collected within a block join, we can share information between leaves to speed up vector search.

To make this happen, I refactored the multi-leaf collector slightly. Now, instead of inheriting from TopKnnCollector, we inject an inner collector.
This commit is contained in:
Benjamin Trent 2024-03-05 09:02:49 -05:00 committed by GitHub
parent 51122f8b2e
commit 012b959b05
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 92 additions and 37 deletions

View File

@ -215,6 +215,9 @@ Optimizations
* GITHUB#13085: Remove unnecessary toString() / substring() calls to save some String allocations (Dmitry Cherniachenko) * GITHUB#13085: Remove unnecessary toString() / substring() calls to save some String allocations (Dmitry Cherniachenko)
* GITHUB#13121: Speed up multi-segment HNSW graph search for diversifying child kNN queries. Builds on GITHUB#12962.
(Ben Trent)
Bug Fixes Bug Fixes
--------------------- ---------------------

View File

@ -61,6 +61,8 @@ public abstract class AbstractKnnCollector implements KnnCollector {
@Override @Override
public abstract boolean collect(int docId, float similarity); public abstract boolean collect(int docId, float similarity);
/**
 * Returns the number of results this collector currently holds. The implementations visible in
 * this change report the size of their backing structure (queue, list or heap).
 *
 * @return the count of collected results
 */
public abstract int numCollected();
@Override @Override
public abstract float minCompetitiveSimilarity(); public abstract float minCompetitiveSimilarity();

View File

@ -63,6 +63,11 @@ public class TopKnnCollector extends AbstractKnnCollector {
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs); return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
} }
/**
 * Reports how many results are currently held, i.e. the current size of the local queue.
 *
 * @return {@code queue.size()}
 */
@Override
public int numCollected() {
return queue.size();
}
@Override @Override
public String toString() { public String toString() {
return "TopKnnCollector[k=" + k() + ", size=" + queue.size() + "]"; return "TopKnnCollector[k=" + k() + ", size=" + queue.size() + "]";

View File

@ -75,4 +75,9 @@ class VectorSimilarityCollector extends AbstractKnnCollector {
return new TopDocs( return new TopDocs(
new TotalHits(visitedCount(), relation), scoreDocList.toArray(ScoreDoc[]::new)); new TotalHits(visitedCount(), relation), scoreDocList.toArray(ScoreDoc[]::new));
} }
/**
 * Reports how many results are currently held, i.e. the number of entries in the score-doc
 * list.
 *
 * @return {@code scoreDocList.size()}
 */
@Override
public int numCollected() {
return scoreDocList.size();
}
} }

View File

@ -17,20 +17,19 @@
package org.apache.lucene.search.knn; package org.apache.lucene.search.knn;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.AbstractKnnCollector;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.hnsw.BlockingFloatHeap; import org.apache.lucene.util.hnsw.BlockingFloatHeap;
import org.apache.lucene.util.hnsw.FloatHeap; import org.apache.lucene.util.hnsw.FloatHeap;
/** /**
* MultiLeafTopKnnCollector is a specific KnnCollector that can exchange the top collected results * MultiLeafKnnCollector is a specific KnnCollector that can exchange the top collected results
* across segments through a shared global queue. * across segments through a shared global queue.
* *
* @lucene.experimental * @lucene.experimental
*/ */
public final class MultiLeafTopKnnCollector extends TopKnnCollector { public final class MultiLeafKnnCollector implements KnnCollector {
// greediness of globally non-competitive search: (0,1] // greediness of globally non-competitive search: (0,1]
private static final float DEFAULT_GREEDINESS = 0.9f; private static final float DEFAULT_GREEDINESS = 0.9f;
@ -46,23 +45,55 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector {
private final int interval = 0xff; // 255 private final int interval = 0xff; // 255
private boolean kResultsCollected = false; private boolean kResultsCollected = false;
private float cachedGlobalMinSim = Float.NEGATIVE_INFINITY; private float cachedGlobalMinSim = Float.NEGATIVE_INFINITY;
private final AbstractKnnCollector subCollector;
/** /**
* Create a new MultiLeafKnnCollector.
*
* @param k the number of neighbors to collect * @param k the number of neighbors to collect
* @param visitLimit how many vector nodes the results are allowed to visit * @param globalSimilarityQueue the global queue of the highest similarities collected so far
* across all segments
* @param subCollector the local collector
*/ */
public MultiLeafTopKnnCollector(int k, int visitLimit, BlockingFloatHeap globalSimilarityQueue) { public MultiLeafKnnCollector(
super(k, visitLimit); int k, BlockingFloatHeap globalSimilarityQueue, AbstractKnnCollector subCollector) {
this.greediness = DEFAULT_GREEDINESS; this.greediness = DEFAULT_GREEDINESS;
this.subCollector = subCollector;
this.globalSimilarityQueue = globalSimilarityQueue; this.globalSimilarityQueue = globalSimilarityQueue;
this.nonCompetitiveQueue = new FloatHeap(Math.max(1, Math.round((1 - greediness) * k))); this.nonCompetitiveQueue = new FloatHeap(Math.max(1, Math.round((1 - greediness) * k)));
this.updatesQueue = new FloatHeap(k); this.updatesQueue = new FloatHeap(k);
} }
/**
 * Delegates to the wrapped sub-collector.
 *
 * @return whatever the sub-collector reports for early termination
 */
@Override
public boolean earlyTerminated() {
return subCollector.earlyTerminated();
}
/**
 * Forwards the visited-count increment to the wrapped sub-collector; this class keeps no
 * visited counter of its own.
 *
 * @param count the amount passed through to the sub-collector
 */
@Override
public void incVisitedCount(int count) {
subCollector.incVisitedCount(count);
}
/**
 * Delegates to the wrapped sub-collector.
 *
 * @return the visited count as tracked by the sub-collector
 */
@Override
public long visitedCount() {
return subCollector.visitedCount();
}
/**
 * Delegates to the wrapped sub-collector.
 *
 * @return the visit limit configured on the sub-collector
 */
@Override
public long visitLimit() {
return subCollector.visitLimit();
}
/**
 * Delegates to the wrapped sub-collector.
 *
 * @return the {@code k} reported by the sub-collector
 */
@Override
public int k() {
return subCollector.k();
}
@Override @Override
public boolean collect(int docId, float similarity) { public boolean collect(int docId, float similarity) {
boolean localSimUpdated = queue.insertWithOverflow(docId, similarity); boolean localSimUpdated = subCollector.collect(docId, similarity);
boolean firstKResultsCollected = (kResultsCollected == false && queue.size() == k()); boolean firstKResultsCollected =
(kResultsCollected == false && subCollector.numCollected() == k());
if (firstKResultsCollected) { if (firstKResultsCollected) {
kResultsCollected = true; kResultsCollected = true;
} }
@ -71,7 +102,7 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector {
if (kResultsCollected) { if (kResultsCollected) {
// as we've collected k results, we can start do periodic updates with the global queue // as we've collected k results, we can start do periodic updates with the global queue
if (firstKResultsCollected || (visitedCount & interval) == 0) { if (firstKResultsCollected || (subCollector.visitedCount() & interval) == 0) {
cachedGlobalMinSim = globalSimilarityQueue.offer(updatesQueue.getHeap()); cachedGlobalMinSim = globalSimilarityQueue.offer(updatesQueue.getHeap());
updatesQueue.clear(); updatesQueue.clear();
globalSimUpdated = true; globalSimUpdated = true;
@ -85,26 +116,18 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector {
if (kResultsCollected == false) { if (kResultsCollected == false) {
return Float.NEGATIVE_INFINITY; return Float.NEGATIVE_INFINITY;
} }
return Math.max(queue.topScore(), Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim)); return Math.max(
subCollector.minCompetitiveSimilarity(),
Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim));
} }
@Override @Override
public TopDocs topDocs() { public TopDocs topDocs() {
assert queue.size() <= k() : "Tried to collect more results than the maximum number allowed"; return subCollector.topDocs();
ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
for (int i = 1; i <= scoreDocs.length; i++) {
scoreDocs[scoreDocs.length - i] = new ScoreDoc(queue.topNode(), queue.topScore());
queue.pop();
}
TotalHits.Relation relation =
earlyTerminated()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
} }
@Override @Override
public String toString() { public String toString() {
return "MultiLeafTopKnnCollector[k=" + k() + ", size=" + queue.size() + "]"; return "MultiLeafKnnCollector[subCollector=" + subCollector + "]";
} }
} }

View File

@ -20,6 +20,7 @@ package org.apache.lucene.search.knn;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.util.hnsw.BlockingFloatHeap; import org.apache.lucene.util.hnsw.BlockingFloatHeap;
@ -48,12 +49,11 @@ public class TopKnnCollectorManager implements KnnCollectorManager {
* @param context the leaf reader context * @param context the leaf reader context
*/ */
@Override @Override
public TopKnnCollector newCollector(int visitedLimit, LeafReaderContext context) public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
throws IOException {
if (globalScoreQueue == null) { if (globalScoreQueue == null) {
return new TopKnnCollector(k, visitedLimit); return new TopKnnCollector(k, visitedLimit);
} else { } else {
return new MultiLeafTopKnnCollector(k, visitedLimit, globalScoreQueue); return new MultiLeafKnnCollector(k, globalScoreQueue, new TopKnnCollector(k, visitedLimit));
} }
} }
} }

View File

@ -126,7 +126,7 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
@Override @Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter); return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher);
} }
@Override @Override
@ -136,12 +136,10 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
int visitedLimit, int visitedLimit,
KnnCollectorManager knnCollectorManager) KnnCollectorManager knnCollectorManager)
throws IOException { throws IOException {
BitSet parentBitSet = parentsFilter.getBitSet(context); KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
if (parentBitSet == null) { if (collector == null) {
return NO_RESULTS; return NO_RESULTS;
} }
KnnCollector collector =
new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
context.reader().searchNearestVectors(field, query, collector, acceptDocs); context.reader().searchNearestVectors(field, query, collector, acceptDocs);
return collector.topDocs(); return collector.topDocs();
} }

View File

@ -126,7 +126,7 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
@Override @Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter); return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher);
} }
@Override @Override

View File

@ -96,6 +96,11 @@ class DiversifyingNearestChildrenKnnCollector extends AbstractKnnCollector {
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs); return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
} }
/**
 * Reports how many results are currently held, i.e. the current size of the heap.
 *
 * @return {@code heap.size()}
 */
@Override
public int numCollected() {
return heap.size();
}
/** /**
* This is a minimum binary heap, inspired by {@link org.apache.lucene.util.LongHeap}. But instead * This is a minimum binary heap, inspired by {@link org.apache.lucene.util.LongHeap}. But instead
* of encoding and using `long` values. Node ids and scores are kept separate. Additionally, this * of encoding and using `long` values. Node ids and scores are kept separate. Additionally, this

View File

@ -19,8 +19,12 @@ package org.apache.lucene.search.join;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.MultiLeafKnnCollector;
import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
/** /**
* DiversifyingNearestChildrenKnnCollectorManager responsible for creating {@link * DiversifyingNearestChildrenKnnCollectorManager responsible for creating {@link
@ -32,6 +36,7 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec
private final int k; private final int k;
// filter identifying the parent documents. // filter identifying the parent documents.
private final BitSetProducer parentsFilter; private final BitSetProducer parentsFilter;
private final BlockingFloatHeap globalScoreQueue;
/** /**
* Constructor * Constructor
@ -39,9 +44,12 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec
* @param k - the number of top k vectors to collect * @param k - the number of top k vectors to collect
* @param parentsFilter Filter identifying the parent documents. * @param parentsFilter Filter identifying the parent documents.
*/ */
public DiversifyingNearestChildrenKnnCollectorManager(int k, BitSetProducer parentsFilter) { public DiversifyingNearestChildrenKnnCollectorManager(
int k, BitSetProducer parentsFilter, IndexSearcher indexSearcher) {
this.k = k; this.k = k;
this.parentsFilter = parentsFilter; this.parentsFilter = parentsFilter;
this.globalScoreQueue =
indexSearcher.getIndexReader().leaves().size() > 1 ? new BlockingFloatHeap(k) : null;
} }
/** /**
@ -51,12 +59,18 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec
* @param context the leaf reader context * @param context the leaf reader context
*/ */
@Override @Override
public DiversifyingNearestChildrenKnnCollector newCollector( public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
int visitedLimit, LeafReaderContext context) throws IOException {
BitSet parentBitSet = parentsFilter.getBitSet(context); BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) { if (parentBitSet == null) {
return null; return null;
} }
return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet); if (globalScoreQueue == null) {
return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
} else {
return new MultiLeafKnnCollector(
k,
globalScoreQueue,
new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet));
}
} }
} }