mirror of https://github.com/apache/lucene.git
Make TaskExecutor constructor public and use TaskExecutor for concurrent HNSW graph build (#12799)
Make the TaskExecutor constructor public; it is currently package-private. At indexing time we concurrently build the HNSW graph (Concurrent HNSW Merge #12660), and the existing TaskExecutor implementation can do this for us: use TaskExecutor#invokeAll in HnswConcurrentMergeBuilder#build to run the workers concurrently.
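For context, a minimal sketch of the resulting calling pattern (not part of the commit): TaskExecutor and its invokeAll method are the real org.apache.lucene.search.TaskExecutor API after this change, while the class name, pool size, and the worker body below are made up for illustration.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.lucene.search.TaskExecutor;

public class TaskExecutorSketch {
  public static void main(String[] args) throws Exception {
    // Caller-owned pool; TaskExecutor only wraps it, so the caller still shuts it down.
    ExecutorService pool = Executors.newFixedThreadPool(4);
    TaskExecutor taskExecutor = new TaskExecutor(pool); // constructor is public after this change

    // One Callable per worker, mirroring HnswConcurrentMergeBuilder#build.
    List<Callable<Void>> tasks = new ArrayList<>();
    for (int i = 0; i < 4; i++) {
      int worker = i;
      tasks.add(
          () -> {
            // stand-in for workers[worker].run(maxOrd)
            System.out.println("worker " + worker + " builds its slice of the HNSW graph");
            return null;
          });
    }
    // Runs all tasks on the wrapped executor and waits for completion,
    // so the caller no longer hand-rolls Future and exception handling.
    taskExecutor.invokeAll(tasks);
    pool.shutdownNow();
  }
}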
This commit is contained in:
parent 4309917215
commit 4628327af0
CHANGES.txt
@@ -168,6 +168,9 @@ API Changes
 
 * GITHUB#12735: Remove FSTCompiler#getTermCount() and FSTCompiler.UnCompiledNode#inputCount (Anh Dung Bui)
 
+* GITHUB#12799: Make TaskExecutor constructor public and use TaskExecutor for concurrent
+  HNSW graph build. (Shubham Chaudhary)
+
 New Features
 ---------------------
 
Lucene99HnswScalarQuantizedVectorsFormat.java
@@ -31,6 +31,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.search.TaskExecutor;
 import org.apache.lucene.util.hnsw.HnswGraph;
 
 /**
@@ -60,7 +61,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
   private final FlatVectorsFormat flatVectorsFormat;
 
   private final int numMergeWorkers;
-  private final ExecutorService mergeExec;
+  private final TaskExecutor mergeExec;
 
   /** Constructs a format using default graph construction parameters */
   public Lucene99HnswScalarQuantizedVectorsFormat() {
@@ -121,7 +122,11 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
           "No executor service is needed as we'll use single thread to merge");
     }
     this.numMergeWorkers = numMergeWorkers;
-    this.mergeExec = mergeExec;
+    if (mergeExec != null) {
+      this.mergeExec = new TaskExecutor(mergeExec);
+    } else {
+      this.mergeExec = null;
+    }
     this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval);
   }
 
Lucene99HnswVectorsFormat.java
@@ -27,6 +27,7 @@ import org.apache.lucene.codecs.lucene90.IndexedDISI;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.TaskExecutor;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.hnsw.HnswGraph;
 
@@ -137,7 +138,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
   private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat();
 
   private final int numMergeWorkers;
-  private final ExecutorService mergeExec;
+  private final TaskExecutor mergeExec;
 
   /** Constructs a format using default graph construction parameters */
   public Lucene99HnswVectorsFormat() {
@@ -192,7 +193,11 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
           "No executor service is needed as we'll use single thread to merge");
     }
     this.numMergeWorkers = numMergeWorkers;
-    this.mergeExec = mergeExec;
+    if (mergeExec != null) {
+      this.mergeExec = new TaskExecutor(mergeExec);
+    } else {
+      this.mergeExec = null;
+    }
   }
 
   @Override
Lucene99HnswVectorsWriter.java
@@ -23,7 +23,6 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
-import java.util.concurrent.ExecutorService;
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.FlatVectorsWriter;
 import org.apache.lucene.codecs.KnnFieldVectorsWriter;
@@ -35,6 +34,7 @@ import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.Sorter;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.TaskExecutor;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.InfoStream;
@@ -67,7 +67,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
   private final int beamWidth;
   private final FlatVectorsWriter flatVectorWriter;
   private final int numMergeWorkers;
-  private final ExecutorService mergeExec;
+  private final TaskExecutor mergeExec;
 
   private final List<FieldWriter<?>> fields = new ArrayList<>();
   private boolean finished;
@@ -78,7 +78,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
       int beamWidth,
       FlatVectorsWriter flatVectorWriter,
       int numMergeWorkers,
-      ExecutorService mergeExec)
+      TaskExecutor mergeExec)
       throws IOException {
     this.M = M;
     this.flatVectorWriter = flatVectorWriter;
TaskExecutor.java
@@ -53,7 +53,12 @@ public final class TaskExecutor {
 
   private final Executor executor;
 
-  TaskExecutor(Executor executor) {
+  /**
+   * Creates a TaskExecutor instance
+   *
+   * @param executor the executor to be used for running tasks concurrently
+   */
+  public TaskExecutor(Executor executor) {
     this.executor = Objects.requireNonNull(executor, "Executor is null");
   }
 
ConcurrentHnswMerger.java
@@ -17,17 +17,17 @@
 package org.apache.lucene.util.hnsw;
 
 import java.io.IOException;
-import java.util.concurrent.ExecutorService;
 import org.apache.lucene.codecs.HnswGraphProvider;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.TaskExecutor;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.FixedBitSet;
 
 /** This merger merges graph in a concurrent manner, by using {@link HnswConcurrentMergeBuilder} */
 public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
 
-  private final ExecutorService exec;
+  private final TaskExecutor taskExecutor;
   private final int numWorker;
 
   /**
@@ -38,10 +38,10 @@ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
       RandomVectorScorerSupplier scorerSupplier,
       int M,
       int beamWidth,
-      ExecutorService exec,
+      TaskExecutor taskExecutor,
       int numWorker) {
     super(fieldInfo, scorerSupplier, M, beamWidth);
-    this.exec = exec;
+    this.taskExecutor = taskExecutor;
     this.numWorker = numWorker;
   }
 
@@ -50,7 +50,13 @@ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
       throws IOException {
     if (initReader == null) {
       return new HnswConcurrentMergeBuilder(
-          exec, numWorker, scorerSupplier, M, beamWidth, new OnHeapHnswGraph(M, maxOrd), null);
+          taskExecutor,
+          numWorker,
+          scorerSupplier,
+          M,
+          beamWidth,
+          new OnHeapHnswGraph(M, maxOrd),
+          null);
     }
 
     HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
@@ -58,7 +64,7 @@ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
     int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
 
     return new HnswConcurrentMergeBuilder(
-        exec,
+        taskExecutor,
         numWorker,
         scorerSupplier,
         M,
HnswConcurrentMergeBuilder.java
@@ -22,15 +22,12 @@ import static org.apache.lucene.util.hnsw.HnswGraphBuilder.HNSW_COMPONENT;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
+import java.util.concurrent.Callable;
 import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.lucene.search.TaskExecutor;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.FixedBitSet;
-import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.InfoStream;
-import org.apache.lucene.util.ThreadInterruptedException;
 
 /**
  * A graph builder that manages multiple workers, it only supports adding the whole graph all at
@@ -41,12 +38,12 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
   private static final int DEFAULT_BATCH_SIZE =
       2048; // number of vectors the worker handles sequentially at one batch
 
-  private final ExecutorService exec;
+  private final TaskExecutor taskExecutor;
   private final ConcurrentMergeWorker[] workers;
   private InfoStream infoStream = InfoStream.getDefault();
 
   public HnswConcurrentMergeBuilder(
-      ExecutorService exec,
+      TaskExecutor taskExecutor,
       int numWorker,
       RandomVectorScorerSupplier scorerSupplier,
       int M,
@@ -54,7 +51,7 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
       OnHeapHnswGraph hnsw,
       BitSet initializedNodes)
       throws IOException {
-    this.exec = exec;
+    this.taskExecutor = taskExecutor;
     AtomicInteger workProgress = new AtomicInteger(0);
     workers = new ConcurrentMergeWorker[numWorker];
     for (int i = 0; i < numWorker; i++) {
@@ -77,42 +74,16 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
           HNSW_COMPONENT,
           "build graph from " + maxOrd + " vectors, with " + workers.length + " workers");
     }
-    List<Future<?>> futures = new ArrayList<>();
+    List<Callable<Void>> futures = new ArrayList<>();
     for (int i = 0; i < workers.length; i++) {
       int finalI = i;
       futures.add(
-          exec.submit(
-              () -> {
-                try {
-                  workers[finalI].run(maxOrd);
-                } catch (IOException e) {
-                  throw new RuntimeException(e);
-                }
-              }));
-    }
-    Throwable exc = null;
-    for (Future<?> future : futures) {
-      try {
-        future.get();
-      } catch (InterruptedException e) {
-        var newException = new ThreadInterruptedException(e);
-        if (exc == null) {
-          exc = newException;
-        } else {
-          exc.addSuppressed(newException);
-        }
-      } catch (ExecutionException e) {
-        if (exc == null) {
-          exc = e.getCause();
-        } else {
-          exc.addSuppressed(e.getCause());
-        }
-      }
-    }
-    if (exc != null) {
-      // The error handling was copied from TaskExecutor. should we just use TaskExecutor instead?
-      throw IOUtils.rethrowAlways(exc);
+          () -> {
+            workers[finalI].run(maxOrd);
+            return null;
+          });
     }
+    taskExecutor.invokeAll(futures);
     return workers[0].getGraph();
   }
 
HnswGraphTestCase.java
@@ -68,6 +68,7 @@ import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
+import org.apache.lucene.search.TaskExecutor;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.util.LuceneTestCase;
@@ -1008,10 +1009,11 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
     RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
     ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge"));
+    TaskExecutor taskExecutor = new TaskExecutor(exec);
     HnswGraphBuilder.randSeed = random().nextLong();
     HnswConcurrentMergeBuilder builder =
         new HnswConcurrentMergeBuilder(
-            exec, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null);
+            taskExecutor, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null);
     builder.setBatchSize(100);
     builder.build(size);
     exec.shutdownNow();