Make TaskExecutor constructor public and use TaskExecutor for concurrent HNSW graph build (#12799)

Make the TaskExecutor constructor public; it is currently pkg-private. At indexing time we concurrently create the HNSW graph (Concurrent HNSW Merge #12660), and we could use the TaskExecutor implementation to do this for us.
Use TaskExecutor#invokeAll in HnswConcurrentMergeBuilder#build to run the workers concurrently.
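For context, here is a minimal standalone sketch of that pattern (the class name, pool size and printed message are made up for illustration; TaskExecutor and invokeAll are the API used in the hunks below): each worker's run is wrapped in a Callable<Void>, and the whole list is handed to TaskExecutor#invokeAll, which runs the tasks on the wrapped executor and waits for all of them.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.lucene.search.TaskExecutor;

public class InvokeAllSketch {
  public static void main(String[] args) throws Exception {
    ExecutorService exec = Executors.newFixedThreadPool(4);
    try {
      // After this change the constructor is public, so callers can wrap any Executor.
      TaskExecutor taskExecutor = new TaskExecutor(exec);
      List<Callable<Void>> tasks = new ArrayList<>();
      for (int i = 0; i < 4; i++) {
        int finalI = i;
        tasks.add(
            () -> {
              // Stand-in for the real per-worker call, workers[finalI].run(maxOrd)
              System.out.println("worker " + finalI + " builds its share of the graph");
              return null;
            });
      }
      // invokeAll runs every task on the executor and waits for all of them to finish.
      taskExecutor.invokeAll(tasks);
    } finally {
      exec.shutdown();
    }
  }
}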
Shubham Chaudhary 2023-11-21 21:54:01 +05:30 committed by GitHub
parent 4309917215
commit 4628327af0
8 changed files with 52 additions and 55 deletions

View File

@ -168,6 +168,9 @@ API Changes
* GITHUB#12735: Remove FSTCompiler#getTermCount() and FSTCompiler.UnCompiledNode#inputCount (Anh Dung Bui)
* GITHUB#12799: Make TaskExecutor constructor public and use TaskExecutor for concurrent
HNSW graph build. (Shubham Chaudhary)
New Features
---------------------

View File

@ -31,6 +31,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.hnsw.HnswGraph;
/**
@ -60,7 +61,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
private final FlatVectorsFormat flatVectorsFormat;
private final int numMergeWorkers;
private final ExecutorService mergeExec;
private final TaskExecutor mergeExec;
/** Constructs a format using default graph construction parameters */
public Lucene99HnswScalarQuantizedVectorsFormat() {
@ -121,7 +122,11 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
"No executor service is needed as we'll use single thread to merge");
}
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}
this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval);
}
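Note the guard here: this format (and Lucene99HnswVectorsFormat in the next file) wraps the supplied ExecutorService in a TaskExecutor only when one was actually provided, because the now-public TaskExecutor constructor (further down in this diff) rejects a null executor via Objects.requireNonNull. With no executor, mergeExec stays null and the merge runs on a single thread, as the exception message above says.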

View File

@ -27,6 +27,7 @@ import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.hnsw.HnswGraph;
@ -137,7 +138,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat();
private final int numMergeWorkers;
private final ExecutorService mergeExec;
private final TaskExecutor mergeExec;
/** Constructs a format using default graph construction parameters */
public Lucene99HnswVectorsFormat() {
@ -192,7 +193,11 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
"No executor service is needed as we'll use single thread to merge");
}
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}
}
@Override

View File

@ -23,7 +23,6 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
@ -35,6 +34,7 @@ import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
@ -67,7 +67,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
private final int beamWidth;
private final FlatVectorsWriter flatVectorWriter;
private final int numMergeWorkers;
private final ExecutorService mergeExec;
private final TaskExecutor mergeExec;
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
@ -78,7 +78,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
int beamWidth,
FlatVectorsWriter flatVectorWriter,
int numMergeWorkers,
ExecutorService mergeExec)
TaskExecutor mergeExec)
throws IOException {
this.M = M;
this.flatVectorWriter = flatVectorWriter;

View File

@ -53,7 +53,12 @@ public final class TaskExecutor {
private final Executor executor;
TaskExecutor(Executor executor) {
/**
* Creates a TaskExecutor instance
*
* @param executor the executor to be used for running tasks concurrently
*/
public TaskExecutor(Executor executor) {
this.executor = Objects.requireNonNull(executor, "Executor is null");
}
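With the constructor public, code outside org.apache.lucene.search can build its own instance. A small sketch of that (the class name and toy callables are invented; it assumes invokeAll returns one result per submitted callable, as its generic signature suggests):

import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.lucene.search.TaskExecutor;

public class TaskExecutorDemo {
  public static void main(String[] args) throws Exception {
    ExecutorService pool = Executors.newFixedThreadPool(2);
    try {
      // Now possible outside org.apache.lucene.search:
      TaskExecutor taskExecutor = new TaskExecutor(pool);
      List<Callable<Integer>> tasks = List.of(() -> 1 + 1, () -> 2 + 2, () -> 3 + 3);
      // One result per submitted callable.
      List<Integer> results = taskExecutor.invokeAll(tasks);
      System.out.println(results);
    } finally {
      pool.shutdown();
    }
  }
}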

View File

@ -17,17 +17,17 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;
/** This merger merges graph in a concurrent manner, by using {@link HnswConcurrentMergeBuilder} */
public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
private final ExecutorService exec;
private final TaskExecutor taskExecutor;
private final int numWorker;
/**
@ -38,10 +38,10 @@ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
ExecutorService exec,
TaskExecutor taskExecutor,
int numWorker) {
super(fieldInfo, scorerSupplier, M, beamWidth);
this.exec = exec;
this.taskExecutor = taskExecutor;
this.numWorker = numWorker;
}
@ -50,7 +50,13 @@ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
throws IOException {
if (initReader == null) {
return new HnswConcurrentMergeBuilder(
exec, numWorker, scorerSupplier, M, beamWidth, new OnHeapHnswGraph(M, maxOrd), null);
taskExecutor,
numWorker,
scorerSupplier,
M,
beamWidth,
new OnHeapHnswGraph(M, maxOrd),
null);
}
HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
@ -58,7 +64,7 @@ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
return new HnswConcurrentMergeBuilder(
exec,
taskExecutor,
numWorker,
scorerSupplier,
M,

View File

@ -22,15 +22,12 @@ import static org.apache.lucene.util.hnsw.HnswGraphBuilder.HNSW_COMPONENT;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.ThreadInterruptedException;
/**
* A graph builder that manages multiple workers, it only supports adding the whole graph all at
@ -41,12 +38,12 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
private static final int DEFAULT_BATCH_SIZE =
2048; // number of vectors the worker handles sequentially at one batch
private final ExecutorService exec;
private final TaskExecutor taskExecutor;
private final ConcurrentMergeWorker[] workers;
private InfoStream infoStream = InfoStream.getDefault();
public HnswConcurrentMergeBuilder(
ExecutorService exec,
TaskExecutor taskExecutor,
int numWorker,
RandomVectorScorerSupplier scorerSupplier,
int M,
@ -54,7 +51,7 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
OnHeapHnswGraph hnsw,
BitSet initializedNodes)
throws IOException {
this.exec = exec;
this.taskExecutor = taskExecutor;
AtomicInteger workProgress = new AtomicInteger(0);
workers = new ConcurrentMergeWorker[numWorker];
for (int i = 0; i < numWorker; i++) {
@ -77,42 +74,16 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
HNSW_COMPONENT,
"build graph from " + maxOrd + " vectors, with " + workers.length + " workers");
}
List<Future<?>> futures = new ArrayList<>();
List<Callable<Void>> futures = new ArrayList<>();
for (int i = 0; i < workers.length; i++) {
int finalI = i;
futures.add(
exec.submit(
() -> {
try {
workers[finalI].run(maxOrd);
} catch (IOException e) {
throw new RuntimeException(e);
}
}));
}
Throwable exc = null;
for (Future<?> future : futures) {
try {
future.get();
} catch (InterruptedException e) {
var newException = new ThreadInterruptedException(e);
if (exc == null) {
exc = newException;
} else {
exc.addSuppressed(newException);
}
} catch (ExecutionException e) {
if (exc == null) {
exc = e.getCause();
} else {
exc.addSuppressed(e.getCause());
}
}
}
if (exc != null) {
// The error handling was copied from TaskExecutor. should we just use TaskExecutor instead?
throw IOUtils.rethrowAlways(exc);
() -> {
workers[finalI].run(maxOrd);
return null;
});
}
taskExecutor.invokeAll(futures);
return workers[0].getGraph();
}
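The hand-rolled InterruptedException/ExecutionException handling removed above was, per its own comment, copied from TaskExecutor, so delegating to invokeAll should keep the behaviour: every task is waited on, the first failure is rethrown afterwards, and later failures are attached as suppressed exceptions. A sketch of that, under the assumption that invokeAll surfaces an IOException thrown by a task as-is (class name and message are made up):

import java.io.IOException;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.lucene.search.TaskExecutor;

public class InvokeAllFailureSketch {
  public static void main(String[] args) throws Exception {
    ExecutorService pool = Executors.newFixedThreadPool(2);
    try {
      TaskExecutor taskExecutor = new TaskExecutor(pool);
      List<Callable<Void>> tasks =
          List.of(
              () -> {
                throw new IOException("merge worker failed");
              },
              () -> null);
      try {
        taskExecutor.invokeAll(tasks);
      } catch (IOException e) {
        // The task's failure surfaces here once all tasks have been waited on.
        System.out.println("caught: " + e.getMessage());
      }
    } finally {
      pool.shutdown();
    }
  }
}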

View File

@ -68,6 +68,7 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
@ -1008,10 +1009,11 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge"));
TaskExecutor taskExecutor = new TaskExecutor(exec);
HnswGraphBuilder.randSeed = random().nextLong();
HnswConcurrentMergeBuilder builder =
new HnswConcurrentMergeBuilder(
exec, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null);
taskExecutor, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null);
builder.setBatchSize(100);
builder.build(size);
exec.shutdownNow();