mirror of https://github.com/apache/lucene.git
Add mult-leaf optimizations for diversify children collector (#13121)
This adds multi-leaf optimizations for diversified children collector. This means as children vectors are collected within a block join, we can share information between leaves to speed up vector search. To make this happen, I refactored the multi-leaf collector slightly. Now, instead of inheriting from TopKnnCollector, we inject a inner collector.
This commit is contained in:
parent
51122f8b2e
commit
012b959b05
|
@ -215,6 +215,9 @@ Optimizations
|
|||
|
||||
* GITHUB#13085: Remove unnecessary toString() / substring() calls to save some String allocations (Dmitry Cherniachenko)
|
||||
|
||||
* GITHUB#13121: Speedup multi-segment HNSW graph search for diversifying child kNN queries. Builds on GITHUB#12962.
|
||||
(Ben Trent)
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -61,6 +61,8 @@ public abstract class AbstractKnnCollector implements KnnCollector {
|
|||
@Override
|
||||
public abstract boolean collect(int docId, float similarity);
|
||||
|
||||
public abstract int numCollected();
|
||||
|
||||
@Override
|
||||
public abstract float minCompetitiveSimilarity();
|
||||
|
||||
|
|
|
@ -63,6 +63,11 @@ public class TopKnnCollector extends AbstractKnnCollector {
|
|||
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numCollected() {
|
||||
return queue.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "TopKnnCollector[k=" + k() + ", size=" + queue.size() + "]";
|
||||
|
|
|
@ -75,4 +75,9 @@ class VectorSimilarityCollector extends AbstractKnnCollector {
|
|||
return new TopDocs(
|
||||
new TotalHits(visitedCount(), relation), scoreDocList.toArray(ScoreDoc[]::new));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numCollected() {
|
||||
return scoreDocList.size();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,20 +17,19 @@
|
|||
|
||||
package org.apache.lucene.search.knn;
|
||||
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.AbstractKnnCollector;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopKnnCollector;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
|
||||
import org.apache.lucene.util.hnsw.FloatHeap;
|
||||
|
||||
/**
|
||||
* MultiLeafTopKnnCollector is a specific KnnCollector that can exchange the top collected results
|
||||
* MultiLeafKnnCollector is a specific KnnCollector that can exchange the top collected results
|
||||
* across segments through a shared global queue.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class MultiLeafTopKnnCollector extends TopKnnCollector {
|
||||
public final class MultiLeafKnnCollector implements KnnCollector {
|
||||
|
||||
// greediness of globally non-competitive search: (0,1]
|
||||
private static final float DEFAULT_GREEDINESS = 0.9f;
|
||||
|
@ -46,23 +45,55 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector {
|
|||
private final int interval = 0xff; // 255
|
||||
private boolean kResultsCollected = false;
|
||||
private float cachedGlobalMinSim = Float.NEGATIVE_INFINITY;
|
||||
private final AbstractKnnCollector subCollector;
|
||||
|
||||
/**
|
||||
* Create a new MultiLeafKnnCollector.
|
||||
*
|
||||
* @param k the number of neighbors to collect
|
||||
* @param visitLimit how many vector nodes the results are allowed to visit
|
||||
* @param globalSimilarityQueue the global queue of the highest similarities collected so far
|
||||
* across all segments
|
||||
* @param subCollector the local collector
|
||||
*/
|
||||
public MultiLeafTopKnnCollector(int k, int visitLimit, BlockingFloatHeap globalSimilarityQueue) {
|
||||
super(k, visitLimit);
|
||||
public MultiLeafKnnCollector(
|
||||
int k, BlockingFloatHeap globalSimilarityQueue, AbstractKnnCollector subCollector) {
|
||||
this.greediness = DEFAULT_GREEDINESS;
|
||||
this.subCollector = subCollector;
|
||||
this.globalSimilarityQueue = globalSimilarityQueue;
|
||||
this.nonCompetitiveQueue = new FloatHeap(Math.max(1, Math.round((1 - greediness) * k)));
|
||||
this.updatesQueue = new FloatHeap(k);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean earlyTerminated() {
|
||||
return subCollector.earlyTerminated();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void incVisitedCount(int count) {
|
||||
subCollector.incVisitedCount(count);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long visitedCount() {
|
||||
return subCollector.visitedCount();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long visitLimit() {
|
||||
return subCollector.visitLimit();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int k() {
|
||||
return subCollector.k();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean collect(int docId, float similarity) {
|
||||
boolean localSimUpdated = queue.insertWithOverflow(docId, similarity);
|
||||
boolean firstKResultsCollected = (kResultsCollected == false && queue.size() == k());
|
||||
boolean localSimUpdated = subCollector.collect(docId, similarity);
|
||||
boolean firstKResultsCollected =
|
||||
(kResultsCollected == false && subCollector.numCollected() == k());
|
||||
if (firstKResultsCollected) {
|
||||
kResultsCollected = true;
|
||||
}
|
||||
|
@ -71,7 +102,7 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector {
|
|||
|
||||
if (kResultsCollected) {
|
||||
// as we've collected k results, we can start do periodic updates with the global queue
|
||||
if (firstKResultsCollected || (visitedCount & interval) == 0) {
|
||||
if (firstKResultsCollected || (subCollector.visitedCount() & interval) == 0) {
|
||||
cachedGlobalMinSim = globalSimilarityQueue.offer(updatesQueue.getHeap());
|
||||
updatesQueue.clear();
|
||||
globalSimUpdated = true;
|
||||
|
@ -85,26 +116,18 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector {
|
|||
if (kResultsCollected == false) {
|
||||
return Float.NEGATIVE_INFINITY;
|
||||
}
|
||||
return Math.max(queue.topScore(), Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim));
|
||||
return Math.max(
|
||||
subCollector.minCompetitiveSimilarity(),
|
||||
Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim));
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs topDocs() {
|
||||
assert queue.size() <= k() : "Tried to collect more results than the maximum number allowed";
|
||||
ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
|
||||
for (int i = 1; i <= scoreDocs.length; i++) {
|
||||
scoreDocs[scoreDocs.length - i] = new ScoreDoc(queue.topNode(), queue.topScore());
|
||||
queue.pop();
|
||||
}
|
||||
TotalHits.Relation relation =
|
||||
earlyTerminated()
|
||||
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
|
||||
: TotalHits.Relation.EQUAL_TO;
|
||||
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
|
||||
return subCollector.topDocs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "MultiLeafTopKnnCollector[k=" + k() + ", size=" + queue.size() + "]";
|
||||
return "MultiLeafKnnCollector[subCollector=" + subCollector + "]";
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ package org.apache.lucene.search.knn;
|
|||
import java.io.IOException;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.TopKnnCollector;
|
||||
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
|
||||
|
||||
|
@ -48,12 +49,11 @@ public class TopKnnCollectorManager implements KnnCollectorManager {
|
|||
* @param context the leaf reader context
|
||||
*/
|
||||
@Override
|
||||
public TopKnnCollector newCollector(int visitedLimit, LeafReaderContext context)
|
||||
throws IOException {
|
||||
public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
|
||||
if (globalScoreQueue == null) {
|
||||
return new TopKnnCollector(k, visitedLimit);
|
||||
} else {
|
||||
return new MultiLeafTopKnnCollector(k, visitedLimit, globalScoreQueue);
|
||||
return new MultiLeafKnnCollector(k, globalScoreQueue, new TopKnnCollector(k, visitedLimit));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -126,7 +126,7 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
|
|||
|
||||
@Override
|
||||
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
|
||||
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter);
|
||||
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -136,12 +136,10 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
|
|||
int visitedLimit,
|
||||
KnnCollectorManager knnCollectorManager)
|
||||
throws IOException {
|
||||
BitSet parentBitSet = parentsFilter.getBitSet(context);
|
||||
if (parentBitSet == null) {
|
||||
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
|
||||
if (collector == null) {
|
||||
return NO_RESULTS;
|
||||
}
|
||||
KnnCollector collector =
|
||||
new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
|
||||
context.reader().searchNearestVectors(field, query, collector, acceptDocs);
|
||||
return collector.topDocs();
|
||||
}
|
||||
|
|
|
@ -126,7 +126,7 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
|
|||
|
||||
@Override
|
||||
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
|
||||
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter);
|
||||
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -96,6 +96,11 @@ class DiversifyingNearestChildrenKnnCollector extends AbstractKnnCollector {
|
|||
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numCollected() {
|
||||
return heap.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* This is a minimum binary heap, inspired by {@link org.apache.lucene.util.LongHeap}. But instead
|
||||
* of encoding and using `long` values. Node ids and scores are kept separate. Additionally, this
|
||||
|
|
|
@ -19,8 +19,12 @@ package org.apache.lucene.search.join;
|
|||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.search.knn.MultiLeafKnnCollector;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
|
||||
|
||||
/**
|
||||
* DiversifyingNearestChildrenKnnCollectorManager responsible for creating {@link
|
||||
|
@ -32,6 +36,7 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec
|
|||
private final int k;
|
||||
// filter identifying the parent documents.
|
||||
private final BitSetProducer parentsFilter;
|
||||
private final BlockingFloatHeap globalScoreQueue;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
|
@ -39,9 +44,12 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec
|
|||
* @param k - the number of top k vectors to collect
|
||||
* @param parentsFilter Filter identifying the parent documents.
|
||||
*/
|
||||
public DiversifyingNearestChildrenKnnCollectorManager(int k, BitSetProducer parentsFilter) {
|
||||
public DiversifyingNearestChildrenKnnCollectorManager(
|
||||
int k, BitSetProducer parentsFilter, IndexSearcher indexSearcher) {
|
||||
this.k = k;
|
||||
this.parentsFilter = parentsFilter;
|
||||
this.globalScoreQueue =
|
||||
indexSearcher.getIndexReader().leaves().size() > 1 ? new BlockingFloatHeap(k) : null;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -51,12 +59,18 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec
|
|||
* @param context the leaf reader context
|
||||
*/
|
||||
@Override
|
||||
public DiversifyingNearestChildrenKnnCollector newCollector(
|
||||
int visitedLimit, LeafReaderContext context) throws IOException {
|
||||
public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
|
||||
BitSet parentBitSet = parentsFilter.getBitSet(context);
|
||||
if (parentBitSet == null) {
|
||||
return null;
|
||||
}
|
||||
return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
|
||||
if (globalScoreQueue == null) {
|
||||
return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
|
||||
} else {
|
||||
return new MultiLeafKnnCollector(
|
||||
k,
|
||||
globalScoreQueue,
|
||||
new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue