Add multi-leaf optimizations for diversify children collector (#13121)

This adds multi-leaf optimizations for diversified children collector. This means as children vectors are collected within a block join, we can share information between leaves to speed up vector search.

To make this happen, I refactored the multi-leaf collector slightly. Now, instead of inheriting from TopKnnCollector, we inject an inner collector.
This commit is contained in:
Benjamin Trent 2024-03-05 09:02:49 -05:00 committed by GitHub
parent 51122f8b2e
commit 012b959b05
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 92 additions and 37 deletions

View File

@ -215,6 +215,9 @@ Optimizations
* GITHUB#13085: Remove unnecessary toString() / substring() calls to save some String allocations (Dmitry Cherniachenko)
* GITHUB#13121: Speedup multi-segment HNSW graph search for diversifying child kNN queries. Builds on GITHUB#12962.
(Ben Trent)
Bug Fixes
---------------------

View File

@ -61,6 +61,8 @@ public abstract class AbstractKnnCollector implements KnnCollector {
@Override
public abstract boolean collect(int docId, float similarity);
public abstract int numCollected();
@Override
public abstract float minCompetitiveSimilarity();

View File

@ -63,6 +63,11 @@ public class TopKnnCollector extends AbstractKnnCollector {
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
}
@Override
public int numCollected() {
return queue.size();
}
@Override
public String toString() {
return "TopKnnCollector[k=" + k() + ", size=" + queue.size() + "]";

View File

@ -75,4 +75,9 @@ class VectorSimilarityCollector extends AbstractKnnCollector {
return new TopDocs(
new TotalHits(visitedCount(), relation), scoreDocList.toArray(ScoreDoc[]::new));
}
@Override
public int numCollected() {
return scoreDocList.size();
}
}

View File

@ -17,20 +17,19 @@
package org.apache.lucene.search.knn;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.AbstractKnnCollector;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
import org.apache.lucene.util.hnsw.FloatHeap;
/**
* MultiLeafTopKnnCollector is a specific KnnCollector that can exchange the top collected results
* MultiLeafKnnCollector is a specific KnnCollector that can exchange the top collected results
* across segments through a shared global queue.
*
* @lucene.experimental
*/
public final class MultiLeafTopKnnCollector extends TopKnnCollector {
public final class MultiLeafKnnCollector implements KnnCollector {
// greediness of globally non-competitive search: (0,1]
private static final float DEFAULT_GREEDINESS = 0.9f;
@ -46,23 +45,55 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector {
private final int interval = 0xff; // 255
private boolean kResultsCollected = false;
private float cachedGlobalMinSim = Float.NEGATIVE_INFINITY;
private final AbstractKnnCollector subCollector;
/**
* Create a new MultiLeafKnnCollector.
*
* @param k the number of neighbors to collect
* @param visitLimit how many vector nodes the results are allowed to visit
* @param globalSimilarityQueue the global queue of the highest similarities collected so far
* across all segments
* @param subCollector the local collector
*/
public MultiLeafTopKnnCollector(int k, int visitLimit, BlockingFloatHeap globalSimilarityQueue) {
super(k, visitLimit);
public MultiLeafKnnCollector(
int k, BlockingFloatHeap globalSimilarityQueue, AbstractKnnCollector subCollector) {
this.greediness = DEFAULT_GREEDINESS;
this.subCollector = subCollector;
this.globalSimilarityQueue = globalSimilarityQueue;
this.nonCompetitiveQueue = new FloatHeap(Math.max(1, Math.round((1 - greediness) * k)));
this.updatesQueue = new FloatHeap(k);
}
@Override
public boolean earlyTerminated() {
return subCollector.earlyTerminated();
}
@Override
public void incVisitedCount(int count) {
subCollector.incVisitedCount(count);
}
@Override
public long visitedCount() {
return subCollector.visitedCount();
}
@Override
public long visitLimit() {
return subCollector.visitLimit();
}
@Override
public int k() {
return subCollector.k();
}
@Override
public boolean collect(int docId, float similarity) {
boolean localSimUpdated = queue.insertWithOverflow(docId, similarity);
boolean firstKResultsCollected = (kResultsCollected == false && queue.size() == k());
boolean localSimUpdated = subCollector.collect(docId, similarity);
boolean firstKResultsCollected =
(kResultsCollected == false && subCollector.numCollected() == k());
if (firstKResultsCollected) {
kResultsCollected = true;
}
@ -71,7 +102,7 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector {
if (kResultsCollected) {
// as we've collected k results, we can start doing periodic updates with the global queue
if (firstKResultsCollected || (visitedCount & interval) == 0) {
if (firstKResultsCollected || (subCollector.visitedCount() & interval) == 0) {
cachedGlobalMinSim = globalSimilarityQueue.offer(updatesQueue.getHeap());
updatesQueue.clear();
globalSimUpdated = true;
@ -85,26 +116,18 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector {
if (kResultsCollected == false) {
return Float.NEGATIVE_INFINITY;
}
return Math.max(queue.topScore(), Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim));
return Math.max(
subCollector.minCompetitiveSimilarity(),
Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim));
}
@Override
public TopDocs topDocs() {
assert queue.size() <= k() : "Tried to collect more results than the maximum number allowed";
ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
for (int i = 1; i <= scoreDocs.length; i++) {
scoreDocs[scoreDocs.length - i] = new ScoreDoc(queue.topNode(), queue.topScore());
queue.pop();
}
TotalHits.Relation relation =
earlyTerminated()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
return subCollector.topDocs();
}
@Override
public String toString() {
return "MultiLeafTopKnnCollector[k=" + k() + ", size=" + queue.size() + "]";
return "MultiLeafKnnCollector[subCollector=" + subCollector + "]";
}
}

View File

@ -20,6 +20,7 @@ package org.apache.lucene.search.knn;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
@ -48,12 +49,11 @@ public class TopKnnCollectorManager implements KnnCollectorManager {
* @param context the leaf reader context
*/
@Override
public TopKnnCollector newCollector(int visitedLimit, LeafReaderContext context)
throws IOException {
public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
if (globalScoreQueue == null) {
return new TopKnnCollector(k, visitedLimit);
} else {
return new MultiLeafTopKnnCollector(k, visitedLimit, globalScoreQueue);
return new MultiLeafKnnCollector(k, globalScoreQueue, new TopKnnCollector(k, visitedLimit));
}
}
}

View File

@ -126,7 +126,7 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter);
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher);
}
@Override
@ -136,12 +136,10 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) {
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
if (collector == null) {
return NO_RESULTS;
}
KnnCollector collector =
new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
context.reader().searchNearestVectors(field, query, collector, acceptDocs);
return collector.topDocs();
}

View File

@ -126,7 +126,7 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter);
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher);
}
@Override

View File

@ -96,6 +96,11 @@ class DiversifyingNearestChildrenKnnCollector extends AbstractKnnCollector {
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
}
@Override
public int numCollected() {
return heap.size();
}
/**
* This is a minimum binary heap, inspired by {@link org.apache.lucene.util.LongHeap}. But instead
* of encoding and using `long` values. Node ids and scores are kept separate. Additionally, this

View File

@ -19,8 +19,12 @@ package org.apache.lucene.search.join;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.MultiLeafKnnCollector;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
/**
* DiversifyingNearestChildrenKnnCollectorManager responsible for creating {@link
@ -32,6 +36,7 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec
private final int k;
// filter identifying the parent documents.
private final BitSetProducer parentsFilter;
private final BlockingFloatHeap globalScoreQueue;
/**
* Constructor
@ -39,9 +44,12 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec
* @param k - the number of top k vectors to collect
* @param parentsFilter Filter identifying the parent documents.
*/
public DiversifyingNearestChildrenKnnCollectorManager(int k, BitSetProducer parentsFilter) {
public DiversifyingNearestChildrenKnnCollectorManager(
int k, BitSetProducer parentsFilter, IndexSearcher indexSearcher) {
this.k = k;
this.parentsFilter = parentsFilter;
this.globalScoreQueue =
indexSearcher.getIndexReader().leaves().size() > 1 ? new BlockingFloatHeap(k) : null;
}
/**
@ -51,12 +59,18 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec
* @param context the leaf reader context
*/
@Override
public DiversifyingNearestChildrenKnnCollector newCollector(
int visitedLimit, LeafReaderContext context) throws IOException {
public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) {
return null;
}
return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
if (globalScoreQueue == null) {
return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
} else {
return new MultiLeafKnnCollector(
k,
globalScoreQueue,
new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet));
}
}
}