From 012b959b052e97aecdd0bec3898f63521dc6b444 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 5 Mar 2024 09:02:49 -0500 Subject: [PATCH] Add mult-leaf optimizations for diversify children collector (#13121) This adds multi-leaf optimizations for diversified children collector. This means as children vectors are collected within a block join, we can share information between leaves to speed up vector search. To make this happen, I refactored the multi-leaf collector slightly. Now, instead of inheriting from TopKnnCollector, we inject a inner collector. --- lucene/CHANGES.txt | 3 + .../lucene/search/AbstractKnnCollector.java | 2 + .../apache/lucene/search/TopKnnCollector.java | 5 ++ .../search/VectorSimilarityCollector.java | 5 ++ ...lector.java => MultiLeafKnnCollector.java} | 71 ++++++++++++------- .../search/knn/TopKnnCollectorManager.java | 6 +- ...iversifyingChildrenByteKnnVectorQuery.java | 8 +-- ...versifyingChildrenFloatKnnVectorQuery.java | 2 +- ...versifyingNearestChildrenKnnCollector.java | 5 ++ ...ingNearestChildrenKnnCollectorManager.java | 22 ++++-- 10 files changed, 92 insertions(+), 37 deletions(-) rename lucene/core/src/java/org/apache/lucene/search/knn/{MultiLeafTopKnnCollector.java => MultiLeafKnnCollector.java} (64%) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 578fe04456f..2d346c4167d 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -215,6 +215,9 @@ Optimizations * GITHUB#13085: Remove unnecessary toString() / substring() calls to save some String allocations (Dmitry Cherniachenko) +* GITHUB#13121: Speedup multi-segment HNSW graph search for diversifying child kNN queries. Builds on GITHUB#12962. + (Ben Trent) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java index a75465679f4..0b1b18d1dbd 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java @@ -61,6 +61,8 @@ public abstract class AbstractKnnCollector implements KnnCollector { @Override public abstract boolean collect(int docId, float similarity); + public abstract int numCollected(); + @Override public abstract float minCompetitiveSimilarity(); diff --git a/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java index 59d2b2abe39..cef8eb18024 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java @@ -63,6 +63,11 @@ public class TopKnnCollector extends AbstractKnnCollector { return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs); } + @Override + public int numCollected() { + return queue.size(); + } + @Override public String toString() { return "TopKnnCollector[k=" + k() + ", size=" + queue.size() + "]"; diff --git a/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityCollector.java b/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityCollector.java index 6005f3ebef5..e7e10c67365 100644 --- a/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityCollector.java @@ -75,4 +75,9 @@ class VectorSimilarityCollector extends AbstractKnnCollector { return new TopDocs( new TotalHits(visitedCount(), relation), scoreDocList.toArray(ScoreDoc[]::new)); } + + @Override + public int numCollected() { + return scoreDocList.size(); + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/MultiLeafTopKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/knn/MultiLeafKnnCollector.java similarity index 64% rename from lucene/core/src/java/org/apache/lucene/search/knn/MultiLeafTopKnnCollector.java rename to lucene/core/src/java/org/apache/lucene/search/knn/MultiLeafKnnCollector.java index 782a4059b20..5f7b26f95d4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/knn/MultiLeafTopKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/knn/MultiLeafKnnCollector.java @@ -17,20 +17,19 @@ package org.apache.lucene.search.knn; -import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.AbstractKnnCollector; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TopKnnCollector; -import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.hnsw.BlockingFloatHeap; import org.apache.lucene.util.hnsw.FloatHeap; /** - * MultiLeafTopKnnCollector is a specific KnnCollector that can exchange the top collected results + * MultiLeafKnnCollector is a specific KnnCollector that can exchange the top collected results * across segments through a shared global queue. * * @lucene.experimental */ -public final class MultiLeafTopKnnCollector extends TopKnnCollector { +public final class MultiLeafKnnCollector implements KnnCollector { // greediness of globally non-competitive search: (0,1] private static final float DEFAULT_GREEDINESS = 0.9f; @@ -46,23 +45,55 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector { private final int interval = 0xff; // 255 private boolean kResultsCollected = false; private float cachedGlobalMinSim = Float.NEGATIVE_INFINITY; + private final AbstractKnnCollector subCollector; /** + * Create a new MultiLeafKnnCollector. + * * @param k the number of neighbors to collect - * @param visitLimit how many vector nodes the results are allowed to visit + * @param globalSimilarityQueue the global queue of the highest similarities collected so far + * across all segments + * @param subCollector the local collector */ - public MultiLeafTopKnnCollector(int k, int visitLimit, BlockingFloatHeap globalSimilarityQueue) { - super(k, visitLimit); + public MultiLeafKnnCollector( + int k, BlockingFloatHeap globalSimilarityQueue, AbstractKnnCollector subCollector) { this.greediness = DEFAULT_GREEDINESS; + this.subCollector = subCollector; this.globalSimilarityQueue = globalSimilarityQueue; this.nonCompetitiveQueue = new FloatHeap(Math.max(1, Math.round((1 - greediness) * k))); this.updatesQueue = new FloatHeap(k); } + @Override + public boolean earlyTerminated() { + return subCollector.earlyTerminated(); + } + + @Override + public void incVisitedCount(int count) { + subCollector.incVisitedCount(count); + } + + @Override + public long visitedCount() { + return subCollector.visitedCount(); + } + + @Override + public long visitLimit() { + return subCollector.visitLimit(); + } + + @Override + public int k() { + return subCollector.k(); + } + @Override public boolean collect(int docId, float similarity) { - boolean localSimUpdated = queue.insertWithOverflow(docId, similarity); - boolean firstKResultsCollected = (kResultsCollected == false && queue.size() == k()); + boolean localSimUpdated = subCollector.collect(docId, similarity); + boolean firstKResultsCollected = + (kResultsCollected == false && subCollector.numCollected() == k()); if (firstKResultsCollected) { kResultsCollected = true; } @@ -71,7 +102,7 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector { if (kResultsCollected) { // as we've collected k results, we can start do periodic updates with the global queue - if (firstKResultsCollected || (visitedCount & interval) == 0) { + if (firstKResultsCollected || (subCollector.visitedCount() & interval) == 0) { cachedGlobalMinSim = globalSimilarityQueue.offer(updatesQueue.getHeap()); updatesQueue.clear(); globalSimUpdated = true; @@ -85,26 +116,18 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector { if (kResultsCollected == false) { return Float.NEGATIVE_INFINITY; } - return Math.max(queue.topScore(), Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim)); + return Math.max( + subCollector.minCompetitiveSimilarity(), + Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim)); } @Override public TopDocs topDocs() { - assert queue.size() <= k() : "Tried to collect more results than the maximum number allowed"; - ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()]; - for (int i = 1; i <= scoreDocs.length; i++) { - scoreDocs[scoreDocs.length - i] = new ScoreDoc(queue.topNode(), queue.topScore()); - queue.pop(); - } - TotalHits.Relation relation = - earlyTerminated() - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs); + return subCollector.topDocs(); } @Override public String toString() { - return "MultiLeafTopKnnCollector[k=" + k() + ", size=" + queue.size() + "]"; + return "MultiLeafKnnCollector[subCollector=" + subCollector + "]"; } } diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/TopKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/knn/TopKnnCollectorManager.java index df4431df5fc..a522d568078 100644 --- a/lucene/core/src/java/org/apache/lucene/search/knn/TopKnnCollectorManager.java +++ b/lucene/core/src/java/org/apache/lucene/search/knn/TopKnnCollectorManager.java @@ -20,6 +20,7 @@ package org.apache.lucene.search.knn; import java.io.IOException; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.util.hnsw.BlockingFloatHeap; @@ -48,12 +49,11 @@ public class TopKnnCollectorManager implements KnnCollectorManager { * @param context the leaf reader context */ @Override - public TopKnnCollector newCollector(int visitedLimit, LeafReaderContext context) - throws IOException { + public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException { if (globalScoreQueue == null) { return new TopKnnCollector(k, visitedLimit); } else { - return new MultiLeafTopKnnCollector(k, visitedLimit, globalScoreQueue); + return new MultiLeafKnnCollector(k, globalScoreQueue, new TopKnnCollector(k, visitedLimit)); } } } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index dff70f31bac..d3e9deac7a6 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -126,7 +126,7 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery { @Override protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { - return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter); + return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher); } @Override @@ -136,12 +136,10 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery { int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { - BitSet parentBitSet = parentsFilter.getBitSet(context); - if (parentBitSet == null) { + KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context); + if (collector == null) { return NO_RESULTS; } - KnnCollector collector = - new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet); context.reader().searchNearestVectors(field, query, collector, acceptDocs); return collector.topDocs(); } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index a84d809ac6c..0520f180025 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -126,7 +126,7 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery @Override protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { - return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter); + return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher); } @Override diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollector.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollector.java index 9a5882d3a2d..085c163847a 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollector.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollector.java @@ -96,6 +96,11 @@ class DiversifyingNearestChildrenKnnCollector extends AbstractKnnCollector { return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs); } + @Override + public int numCollected() { + return heap.size(); + } + /** * This is a minimum binary heap, inspired by {@link org.apache.lucene.util.LongHeap}. But instead * of encoding and using `long` values. Node ids and scores are kept separate. Additionally, this diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java index 8e8a54eedfc..6450af97d0d 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java @@ -19,8 +19,12 @@ package org.apache.lucene.search.join; import java.io.IOException; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.MultiLeafKnnCollector; import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.hnsw.BlockingFloatHeap; /** * DiversifyingNearestChildrenKnnCollectorManager responsible for creating {@link @@ -32,6 +36,7 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec private final int k; // filter identifying the parent documents. private final BitSetProducer parentsFilter; + private final BlockingFloatHeap globalScoreQueue; /** * Constructor @@ -39,9 +44,12 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec * @param k - the number of top k vectors to collect * @param parentsFilter Filter identifying the parent documents. */ - public DiversifyingNearestChildrenKnnCollectorManager(int k, BitSetProducer parentsFilter) { + public DiversifyingNearestChildrenKnnCollectorManager( + int k, BitSetProducer parentsFilter, IndexSearcher indexSearcher) { this.k = k; this.parentsFilter = parentsFilter; + this.globalScoreQueue = + indexSearcher.getIndexReader().leaves().size() > 1 ? new BlockingFloatHeap(k) : null; } /** @@ -51,12 +59,18 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec * @param context the leaf reader context */ @Override - public DiversifyingNearestChildrenKnnCollector newCollector( - int visitedLimit, LeafReaderContext context) throws IOException { + public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException { BitSet parentBitSet = parentsFilter.getBitSet(context); if (parentBitSet == null) { return null; } - return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet); + if (globalScoreQueue == null) { + return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet); + } else { + return new MultiLeafKnnCollector( + k, + globalScoreQueue, + new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet)); + } } }