diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java index c23f56bcdc6..2c800994be6 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java @@ -16,7 +16,6 @@ */ package org.apache.lucene.util.hnsw; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.util.hnsw.HnswGraphBuilder.HNSW_COMPONENT; import java.io.IOException; @@ -24,10 +23,8 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.Lock; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.InfoStream; /** @@ -157,15 +154,7 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder { BitSet initializedNodes, AtomicInteger workProgress) throws IOException { - super( - scorerSupplier, - M, - beamWidth, - seed, - hnsw, - hnswLock, - new MergeSearcher( - new NeighborQueue(beamWidth, true), hnswLock, new FixedBitSet(hnsw.maxNodeId() + 1))); + super(scorerSupplier, M, beamWidth, seed, hnsw, hnswLock); this.workProgress = workProgress; this.initializedNodes = initializedNodes; } @@ -204,44 +193,4 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder { super.addGraphNode(node); } } - - /** - * This searcher will obtain the lock and make a copy of neighborArray when seeking the graph such - * that concurrent modification of the graph will not impact the search - */ - private static class MergeSearcher extends HnswGraphSearcher { - private final HnswLock hnswLock; - private int[] nodeBuffer; - private int upto; - private int size; - - private MergeSearcher(NeighborQueue candidates, HnswLock hnswLock, BitSet visited) { - super(candidates, visited); - this.hnswLock = hnswLock; - } - - @Override - void graphSeek(HnswGraph graph, int level, int targetNode) { - Lock lock = hnswLock.read(level, targetNode); - try { - NeighborArray neighborArray = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode); - if (nodeBuffer == null || nodeBuffer.length < neighborArray.size()) { - nodeBuffer = new int[neighborArray.size()]; - } - size = neighborArray.size(); - System.arraycopy(neighborArray.nodes(), 0, nodeBuffer, 0, size); - } finally { - lock.unlock(); - } - upto = -1; - } - - @Override - int graphNextNeighbor(HnswGraph graph) { - if (++upto < size) { - return nodeBuffer[upto]; - } - return NO_MORE_DOCS; - } - } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 57e7e43d3d7..c8fd3f5d0a1 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -30,6 +30,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Lock; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.util.BitSet; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.hnsw.HnswUtil.Component; @@ -110,14 +111,7 @@ public class HnswGraphBuilder implements HnswBuilder { long seed, OnHeapHnswGraph hnsw) throws IOException { - this( - scorerSupplier, - M, - beamWidth, - seed, - hnsw, - null, - new HnswGraphSearcher(new NeighborQueue(beamWidth, true), new FixedBitSet(hnsw.size()))); + this(scorerSupplier, M, beamWidth, seed, hnsw, null); } /** @@ -138,8 +132,7 @@ public class HnswGraphBuilder implements HnswBuilder { int beamWidth, long seed, OnHeapHnswGraph hnsw, - HnswLock hnswLock, - HnswGraphSearcher graphSearcher) + HnswLock hnswLock) throws IOException { if (M <= 0) { throw new IllegalArgumentException("M (max connections) must be positive"); @@ -155,7 +148,12 @@ public class HnswGraphBuilder implements HnswBuilder { this.random = new SplittableRandom(seed); this.hnsw = hnsw; this.hnswLock = hnswLock; - this.graphSearcher = graphSearcher; + NeighborQueue neighborQueue = new NeighborQueue(beamWidth, true); + this.graphSearcher = + hnswLock != null + ? new ConcurrentSearcher(neighborQueue, hnswLock, new FixedBitSet(hnsw.maxNodeId() + 1)) + : new HnswGraphSearcher(neighborQueue, new FixedBitSet(hnsw.size())); + ; entryCandidates = new GraphBuilderKnnCollector(1); beamCandidates = new GraphBuilderKnnCollector(beamWidth); } @@ -608,4 +606,44 @@ public class HnswGraphBuilder implements HnswBuilder { throw new IllegalArgumentException(); } } + + /** + * This searcher will obtain the lock and make a copy of neighborArray when searching the graph so + * that concurrent modification of the graph will not impact the search + */ + private static class ConcurrentSearcher extends HnswGraphSearcher { + private final HnswLock hnswLock; + private int[] nodeBuffer; + private int upto; + private int size; + + private ConcurrentSearcher(NeighborQueue candidates, HnswLock hnswLock, BitSet visited) { + super(candidates, visited); + this.hnswLock = hnswLock; + } + + @Override + void graphSeek(HnswGraph graph, int level, int targetNode) { + Lock lock = hnswLock.read(level, targetNode); + try { + NeighborArray neighborArray = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode); + if (nodeBuffer == null || nodeBuffer.length < neighborArray.size()) { + nodeBuffer = new int[neighborArray.size()]; + } + size = neighborArray.size(); + System.arraycopy(neighborArray.nodes(), 0, nodeBuffer, 0, size); + } finally { + lock.unlock(); + } + upto = -1; + } + + @Override + int graphNextNeighbor(HnswGraph graph) { + if (++upto < size) { + return nodeBuffer[upto]; + } + return NO_MORE_DOCS; + } + } }