Concurrent HNSW Merge (#12660)

Patrick Zhai 2023-10-28 11:03:22 -07:00 committed by GitHub
parent f5776c8844
commit a8c52e2e19
20 changed files with 904 additions and 161 deletions

View File

@ -169,6 +169,9 @@ New Features
* GITHUB#12582: Add int8 scalar quantization to the HNSW vector format. This optionally allows for more compact lossy
storage for the vectors, requiring about 75% memory for fast HNSW search. (Ben Trent)
* GITHUB#12660: HNSW graphs can now be merged with multiple threads. Configurable in Lucene99HnswVectorsFormat.
(Patrick Zhai)
Improvements
---------------------

View File

@ -476,9 +476,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
case FLOAT32 -> mergedVectorIterator =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}
HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
graph = hnswGraphBuilder.build(docsWithField.cardinality());
graph =
merger.merge(
mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality());
vectorIndexNodeOffsets = writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;

View File

@ -18,6 +18,7 @@
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
@ -151,6 +152,9 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
*/
public static final int DEFAULT_BEAM_WIDTH = 100;
/** Default to use a single thread for merging */
public static final int DEFAULT_NUM_MERGE_WORKER = 1;
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
/**
@ -169,20 +173,36 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
/** Should this codec scalar quantize float32 vectors and use this format */
private final Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat;
private final int numMergeWorkers;
private final ExecutorService mergeExec;
/** Constructs a format using default graph construction parameters */
public Lucene99HnswVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null);
}
public Lucene99HnswVectorsFormat(
int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
this(maxConn, beamWidth, scalarQuantize, DEFAULT_NUM_MERGE_WORKER, null);
}
/**
* Constructs a format using the given graph construction parameters.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
* @param scalarQuantize the scalar quantization format
* @param numMergeWorkers number of workers (threads) that will be used when merging. If
*     larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers
*     generated by this format to perform merges
*/
public Lucene99HnswVectorsFormat(
int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
int maxConn,
int beamWidth,
Lucene99ScalarQuantizedVectorsFormat scalarQuantize,
int numMergeWorkers,
ExecutorService mergeExec) {
super("Lucene99HnswVectorsFormat");
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
@ -198,14 +218,25 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
+ "; beamWidth="
+ beamWidth);
}
if (numMergeWorkers > 1 && mergeExec == null) {
throw new IllegalArgumentException(
"No executor service passed in when " + numMergeWorkers + " merge workers are requested");
}
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException(
"No executor service is needed as we'll use single thread to merge");
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.scalarQuantizedVectorsFormat = scalarQuantize;
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, scalarQuantizedVectorsFormat);
return new Lucene99HnswVectorsWriter(
state, maxConn, beamWidth, scalarQuantizedVectorsFormat, numMergeWorkers, mergeExec);
}
@Override

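A minimal usage sketch for the five-argument constructor above (an editorial addition, not part of this patch; the class name and pool size are illustrative). The writer created by the format calls shutdownNow() on the executor in close(), so the pool should be dedicated to this format instance:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.util.NamedThreadFactory;

public class ConcurrentHnswMergeFormatExample {
  public static Lucene99HnswVectorsFormat newConcurrentMergeFormat() {
    // Dedicated pool for HNSW merges; numMergeWorkers > 1 requires a non-null executor.
    ExecutorService mergeExec =
        Executors.newFixedThreadPool(4, new NamedThreadFactory("hnsw-merge"));
    return new Lucene99HnswVectorsFormat(
        Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
        Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
        null, // no scalar quantization
        4, // four merge workers
        mergeExec);
  }
}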
View File

@ -28,6 +28,7 @@ import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
@ -52,9 +53,11 @@ import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.ConcurrentHnswMerger;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphMerger;
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
@ -75,6 +78,8 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
private final int M;
private final int beamWidth;
private final Lucene99ScalarQuantizedVectorsWriter quantizedVectorsWriter;
private final int numMergeWorkers;
private final ExecutorService mergeExec;
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
@ -83,10 +88,14 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
SegmentWriteState state,
int M,
int beamWidth,
Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat)
Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat,
int numMergeWorkers,
ExecutorService mergeExec)
throws IOException {
this.M = M;
this.beamWidth = beamWidth;
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
@ -383,7 +392,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
int node = nodesOnLevel0.nextInt();
NeighborArray neighbors = graph.getNeighbors(0, newToOldMap[node]);
long offset = vectorIndex.getFilePointer();
reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxOrd);
reconstructAndWriteNeighbours(neighbors, oldToNewMap, maxOrd);
levelNodeOffsets[0][node] = Math.toIntExact(vectorIndex.getFilePointer() - offset);
}
@ -400,7 +409,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
for (int node : newNodes) {
NeighborArray neighbors = graph.getNeighbors(level, newToOldMap[node]);
long offset = vectorIndex.getFilePointer();
reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxOrd);
reconstructAndWriteNeighbours(neighbors, oldToNewMap, maxOrd);
levelNodeOffsets[level][nodeOffsetIndex++] =
Math.toIntExact(vectorIndex.getFilePointer() - offset);
}
@ -442,7 +451,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
};
}
private void reconstructAndWriteNeigbours(NeighborArray neighbors, int[] oldToNewMap, int maxOrd)
private void reconstructAndWriteNeighbours(NeighborArray neighbors, int[] oldToNewMap, int maxOrd)
throws IOException {
int size = neighbors.size();
vectorIndex.writeVInt(size);
@ -557,6 +566,12 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
IOUtils.close(finalVectorDataInput);
segmentWriteState.directory.deleteFile(tempFileName);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
// just return the inner supplier's copy, since only this outer supplier needs to be closed
return innerScoreSupplier.copy();
}
};
} else {
// No need to use temporary file as we don't have to re-open for reading
@ -579,8 +594,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
int[][] vectorIndexNodeOffsets = null;
if (docsWithField.cardinality() != 0) {
// build graph
IncrementalHnswGraphMerger merger =
new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
HnswGraphMerger merger = createGraphMerger(fieldInfo, scorerSupplier);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
merger.addReader(
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
@ -592,9 +606,9 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
case FLOAT32 -> mergedVectorIterator =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}
HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
graph = hnswGraphBuilder.build(docsWithField.cardinality());
graph =
merger.merge(
mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality());
vectorIndexNodeOffsets = writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@ -675,6 +689,15 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
return sortedNodes;
}
private HnswGraphMerger createGraphMerger(
FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) {
if (mergeExec != null) {
return new ConcurrentHnswMerger(
fieldInfo, scorerSupplier, M, beamWidth, mergeExec, numMergeWorkers);
}
return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
}
private void writeMeta(
boolean isQuantized,
FieldInfo field,
@ -819,6 +842,9 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
@Override
public void close() throws IOException {
IOUtils.close(meta, vectorData, vectorIndex, quantizedVectorData);
if (mergeExec != null) {
mergeExec.shutdownNow();
}
}
private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {

View File

@ -52,6 +52,7 @@ import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
/**
* Writes quantized vector values and metadata to index segments.
@ -761,6 +762,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
return supplier.scorer(ord);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return supplier.copy();
}
@Override
public void close() throws IOException {
onClose.close();

View File

@ -39,6 +39,12 @@ final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorSco
this.values = values;
}
private ScalarQuantizedRandomVectorScorerSupplier(
ScalarQuantizedVectorSimilarity similarity, RandomAccessQuantizedByteVectorValues values) {
this.similarity = similarity;
this.values = values;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();
@ -46,4 +52,9 @@ final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorSco
final float queryOffset = values.getScoreCorrectionConstant();
return new ScalarQuantizedRandomVectorScorer(similarity, vectorsCopy, queryVector, queryOffset);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new ScalarQuantizedRandomVectorScorerSupplier(similarity, values.copy());
}
}

View File

@ -22,6 +22,9 @@ import java.io.Closeable;
/**
* A supplier that creates {@link RandomVectorScorer} from an ordinal. Caller should be sure to
* close after use
*
* <p>NOTE: the {@link RandomVectorScorerSupplier} returned by {@link #copy()} is not necessarily
* closeable
*/
public interface CloseableRandomVectorScorerSupplier
extends Closeable, RandomVectorScorerSupplier {}

View File

@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;
/** This merger merges graphs concurrently, using {@link HnswConcurrentMergeBuilder} */
public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
private final ExecutorService exec;
private final int numWorker;
/**
* @param fieldInfo FieldInfo for the field being merged
*/
public ConcurrentHnswMerger(
FieldInfo fieldInfo,
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
ExecutorService exec,
int numWorker) {
super(fieldInfo, scorerSupplier, M, beamWidth);
this.exec = exec;
this.numWorker = numWorker;
}
@Override
protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
throws IOException {
if (initReader == null) {
return new HnswConcurrentMergeBuilder(
exec, numWorker, scorerSupplier, M, beamWidth, new OnHeapHnswGraph(M, maxOrd), null);
}
HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
BitSet initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
return new HnswConcurrentMergeBuilder(
exec,
numWorker,
scorerSupplier,
M,
beamWidth,
InitializedHnswGraphBuilder.initGraph(M, initializerGraph, oldToNewOrdinalMap, maxOrd),
initializedNodes);
}
}

View File

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.util.InfoStream;
/**
* Interface for builders that build an {@link OnHeapHnswGraph}
*
* @lucene.experimental
*/
public interface HnswBuilder {
/**
* Adds all nodes to the graph up to the provided {@code maxOrd}.
*
* @param maxOrd The maximum ordinal (exclusive) of the nodes to be added.
*/
OnHeapHnswGraph build(int maxOrd) throws IOException;
/** Inserts a doc with a vector value into the graph */
void addGraphNode(int node) throws IOException;
/** Set info-stream to output debugging information */
void setInfoStream(InfoStream infoStream);
OnHeapHnswGraph getGraph();
}

View File

@ -0,0 +1,248 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.HNSW_COMPONENT;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.ThreadInterruptedException;
/**
* A graph builder that manages multiple workers; it only supports adding the whole graph all at
* once. It will spawn a thread for each worker, and the workers will pick up work in batches.
*/
public class HnswConcurrentMergeBuilder implements HnswBuilder {
private static final int DEFAULT_BATCH_SIZE =
2048; // number of vectors a worker handles sequentially in one batch
private final ExecutorService exec;
private final ConcurrentMergeWorker[] workers;
private InfoStream infoStream = InfoStream.getDefault();
public HnswConcurrentMergeBuilder(
ExecutorService exec,
int numWorker,
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
OnHeapHnswGraph hnsw,
BitSet initializedNodes)
throws IOException {
this.exec = exec;
AtomicInteger workProgress = new AtomicInteger(0);
workers = new ConcurrentMergeWorker[numWorker];
for (int i = 0; i < numWorker; i++) {
workers[i] =
new ConcurrentMergeWorker(
scorerSupplier.copy(),
M,
beamWidth,
HnswGraphBuilder.randSeed,
hnsw,
initializedNodes,
workProgress);
}
}
@Override
public OnHeapHnswGraph build(int maxOrd) throws IOException {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(
HNSW_COMPONENT,
"build graph from " + maxOrd + " vectors, with " + workers.length + " workers");
}
List<Future<?>> futures = new ArrayList<>();
for (int i = 0; i < workers.length; i++) {
int finalI = i;
futures.add(
exec.submit(
() -> {
try {
workers[finalI].run(maxOrd);
} catch (IOException e) {
throw new RuntimeException(e);
}
}));
}
Throwable exc = null;
for (Future<?> future : futures) {
try {
future.get();
} catch (InterruptedException e) {
var newException = new ThreadInterruptedException(e);
if (exc == null) {
exc = newException;
} else {
exc.addSuppressed(newException);
}
} catch (ExecutionException e) {
if (exc == null) {
exc = e.getCause();
} else {
exc.addSuppressed(e.getCause());
}
}
}
if (exc != null) {
// The error handling was copied from TaskExecutor. Should we just use TaskExecutor instead?
throw IOUtils.rethrowAlways(exc);
}
return workers[0].getGraph();
}
@Override
public void addGraphNode(int node) throws IOException {
throw new UnsupportedOperationException("This builder is for merge only");
}
@Override
public void setInfoStream(InfoStream infoStream) {
this.infoStream = infoStream;
for (HnswBuilder worker : workers) {
worker.setInfoStream(infoStream);
}
}
@Override
public OnHeapHnswGraph getGraph() {
return workers[0].getGraph();
}
/* test only for now */
void setBatchSize(int newSize) {
for (ConcurrentMergeWorker worker : workers) {
worker.batchSize = newSize;
}
}
private static final class ConcurrentMergeWorker extends HnswGraphBuilder {
/**
* A common AtomicInteger shared among all workers, used to track the next vector to be added to
* the graph.
*/
private final AtomicInteger workProgress;
private final BitSet initializedNodes;
private int batchSize = DEFAULT_BATCH_SIZE;
private ConcurrentMergeWorker(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw,
BitSet initializedNodes,
AtomicInteger workProgress)
throws IOException {
super(
scorerSupplier,
M,
beamWidth,
seed,
hnsw,
new MergeSearcher(
new NeighborQueue(beamWidth, true), new FixedBitSet(hnsw.maxNodeId() + 1)));
this.workProgress = workProgress;
this.initializedNodes = initializedNodes;
}
/**
* This method first tries to "reserve" part of the work by calling {@link #getStartPos(int)} and
* then calls {@link #addVectors(int, int)} to actually add the nodes to the graph. By doing this
* we are able to dynamically allocate the work across multiple workers so that all of them finish
* around the same time.
*/
private void run(int maxOrd) throws IOException {
int start = getStartPos(maxOrd);
int end;
while (start != -1) {
end = Math.min(maxOrd, start + batchSize);
addVectors(start, end);
start = getStartPos(maxOrd);
}
}
/** Reserve work by atomically incrementing {@link #workProgress} */
private int getStartPos(int maxOrd) {
int start = workProgress.getAndAdd(batchSize);
if (start < maxOrd) {
return start;
} else {
return -1;
}
}
@Override
public void addGraphNode(int node) throws IOException {
if (initializedNodes != null && initializedNodes.get(node)) {
return;
}
super.addGraphNode(node);
}
}
/**
* This searcher obtains the read lock and makes a copy of the neighbor array when seeking the
* graph, so that concurrent modification of the graph will not impact the search
*/
private static class MergeSearcher extends HnswGraphSearcher {
private int[] nodeBuffer;
private int upto;
private int size;
private MergeSearcher(NeighborQueue candidates, BitSet visited) {
super(candidates, visited);
}
@Override
void graphSeek(HnswGraph graph, int level, int targetNode) {
NeighborArray neighborArray = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode);
neighborArray.rwlock.readLock().lock();
try {
if (nodeBuffer == null || nodeBuffer.length < neighborArray.size()) {
nodeBuffer = new int[neighborArray.size()];
}
size = neighborArray.size();
if (size >= 0) System.arraycopy(neighborArray.node, 0, nodeBuffer, 0, size);
} finally {
neighborArray.rwlock.readLock().unlock();
}
upto = -1;
}
@Override
int graphNextNeighbor(HnswGraph graph) {
if (++upto < size) {
return nodeBuffer[upto];
}
return NO_MORE_DOCS;
}
}
}

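The dynamic work allocation described in the javadoc above boils down to a shared counter advanced in batch-sized strides (see getStartPos). A standalone sketch of that reservation pattern, with illustrative names that are not part of the patch:

import java.util.concurrent.atomic.AtomicInteger;

final class BatchReservation {
  // Shared across all workers; each getAndAdd reserves the range [start, start + batchSize).
  private final AtomicInteger workProgress = new AtomicInteger(0);
  private final int batchSize;

  BatchReservation(int batchSize) {
    this.batchSize = batchSize;
  }

  /** Returns the first ordinal of the reserved batch, or -1 when no work is left. */
  int reserve(int maxOrd) {
    int start = workProgress.getAndAdd(batchSize);
    return start < maxOrd ? start : -1;
  }
}

A worker processes ordinals [start, min(maxOrd, start + batchSize)) and then reserves again, so faster workers naturally take more batches and all workers tend to finish around the same time.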
View File

@ -33,7 +33,7 @@ import org.apache.lucene.util.InfoStream;
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
* hyper-parameters.
*/
public class HnswGraphBuilder {
public class HnswGraphBuilder implements HnswBuilder {
/** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16;
@ -54,7 +54,6 @@ public class HnswGraphBuilder {
private final int M; // max number of connections on upper layers
private final double ml;
private final NeighborArray scratch;
private final SplittableRandom random;
private final RandomVectorScorerSupplier scorerSupplier;
@ -97,6 +96,22 @@ public class HnswGraphBuilder {
this(scorerSupplier, M, beamWidth, seed, new OnHeapHnswGraph(M, graphSize));
}
protected HnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw)
throws IOException {
this(
scorerSupplier,
M,
beamWidth,
seed,
hnsw,
new HnswGraphSearcher(new NeighborQueue(beamWidth, true), new FixedBitSet(hnsw.size())));
}
/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
* ordinals, using the given hyperparameter settings, and returns the resulting graph.
@ -114,7 +129,8 @@ public class HnswGraphBuilder {
int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw)
OnHeapHnswGraph hnsw,
HnswGraphSearcher graphSearcher)
throws IOException {
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
@ -129,20 +145,12 @@ public class HnswGraphBuilder {
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
this.hnsw = hnsw;
this.graphSearcher =
new HnswGraphSearcher(
new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
// in scratch we store candidates in reverse order: worse candidates are first
scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
this.graphSearcher = graphSearcher;
entryCandidates = new GraphBuilderKnnCollector(1);
beamCandidates = new GraphBuilderKnnCollector(beamWidth);
}
/**
* Adds all nodes to the graph up to the provided {@code maxOrd}.
*
* @param maxOrd The maximum ordinal of the nodes to be added.
*/
@Override
public OnHeapHnswGraph build(int maxOrd) throws IOException {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors");
@ -151,18 +159,23 @@ public class HnswGraphBuilder {
return hnsw;
}
/** Set info-stream to output debugging information * */
@Override
public void setInfoStream(InfoStream infoStream) {
this.infoStream = infoStream;
}
@Override
public OnHeapHnswGraph getGraph() {
return hnsw;
}
private void addVectors(int maxOrd) throws IOException {
/** add vectors in range [minOrd, maxOrd) */
protected void addVectors(int minOrd, int maxOrd) throws IOException {
long start = System.nanoTime(), t = start;
for (int node = 0; node < maxOrd; node++) {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")");
}
for (int node = minOrd; node < maxOrd; node++) {
addGraphNode(node);
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
t = printGraphBuildStatus(node, start, t);
@ -170,42 +183,98 @@ public class HnswGraphBuilder {
}
}
/** Inserts a doc with vector value to the graph */
private void addVectors(int maxOrd) throws IOException {
addVectors(0, maxOrd);
}
@Override
public void addGraphNode(int node) throws IOException {
/*
Note: this implementation is thread safe when the graph size is fixed (e.g. when merging).
The process of adding a node is roughly:
1. Add the node to all levels from top to bottom, but do not connect it to any other node,
nor try to promote it to the entry node before the connections are done. (Unless the graph is empty
and this is the first node, in which case we set the entry node and return.)
2. Do the search from top to bottom, remembering all the possible neighbours on each level the node
is on.
3. Add the neighbours to the node from the bottom level up. When adding a neighbour,
we always add all the outgoing links first before adding the incoming link, so that
when a search visits this node it can always find a way out.
4. If the node's level is less than or equal to the graph level, then we're done here.
If the node's level is larger than the graph level, then we need to promote the node
to be the entry node. If, while we added the node to the graph, the entry node changed
(which means the graph level changed as well), we need to reinsert the node
into the newly introduced levels (repeating steps 2 and 3 for the new levels) and again try to
promote the node to the entry node.
*/
RandomVectorScorer scorer = scorerSupplier.scorer(node);
final int nodeLevel = getRandomGraphLevel(ml, random);
int curMaxLevel = hnsw.numLevels() - 1;
// If entrynode is -1, then this should finish without adding neighbors
if (hnsw.entryNode() == -1) {
for (int level = nodeLevel; level >= 0; level--) {
hnsw.addNode(level, node);
}
// first add nodes to all levels
for (int level = nodeLevel; level >= 0; level--) {
hnsw.addNode(level, node);
}
// then promote itself as entry node if entry node is not set
if (hnsw.trySetNewEntryNode(node, nodeLevel)) {
return;
}
int[] eps = new int[] {hnsw.entryNode()};
// if the entry node is already set, then we have to do all connections first before we can
// promote ourselves as entry node
// if a node introduces new levels to the graph, add this new node on new levels
for (int level = nodeLevel; level > curMaxLevel; level--) {
hnsw.addNode(level, node);
}
int lowestUnsetLevel = 0;
int curMaxLevel;
do {
curMaxLevel = hnsw.numLevels() - 1;
// NOTE: the entry node and the max level may not be in sync, but because we read the level first
// we ensure that the entry node we read later will always exist on curMaxLevel
int[] eps = new int[] {hnsw.entryNode()};
// for levels > nodeLevel search with topk = 1
GraphBuilderKnnCollector candidates = entryCandidates;
for (int level = curMaxLevel; level > nodeLevel; level--) {
candidates.clear();
graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
eps = new int[] {candidates.popNode()};
}
// for levels <= nodeLevel search with topk = beamWidth, and add connections
candidates = beamCandidates;
for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
candidates.clear();
graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
eps = candidates.popUntilNearestKNodes();
hnsw.addNode(level, node);
addDiverseNeighbors(level, node, candidates);
}
// we first do the search from top to bottom
// for levels > nodeLevel search with topk = 1
GraphBuilderKnnCollector candidates = entryCandidates;
for (int level = curMaxLevel; level > nodeLevel; level--) {
candidates.clear();
graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
eps[0] = candidates.popNode();
}
// for levels <= nodeLevel search with topk = beamWidth, and add connections
candidates = beamCandidates;
NeighborArray[] scratchPerLevel =
new NeighborArray[Math.min(nodeLevel, curMaxLevel) - lowestUnsetLevel + 1];
for (int i = scratchPerLevel.length - 1; i >= 0; i--) {
int level = i + lowestUnsetLevel;
candidates.clear();
graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
eps = candidates.popUntilNearestKNodes();
scratchPerLevel[i] = new NeighborArray(Math.max(beamCandidates.k(), M + 1), false);
popToScratch(candidates, scratchPerLevel[i]);
}
// then do connections from bottom up
for (int i = 0; i < scratchPerLevel.length; i++) {
addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]);
}
lowestUnsetLevel += scratchPerLevel.length;
assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1;
if (lowestUnsetLevel > nodeLevel) {
return;
}
assert lowestUnsetLevel == curMaxLevel + 1 && nodeLevel > curMaxLevel;
if (hnsw.tryPromoteNewEntryNode(node, nodeLevel, curMaxLevel)) {
return;
}
if (hnsw.numLevels() == curMaxLevel + 1) {
// This should never happen if all the calculations are correct
throw new IllegalStateException(
"We're not able to promote node "
+ node
+ " at level "
+ nodeLevel
+ " as entry node. But the max graph level "
+ curMaxLevel
+ " has not changed while we are inserting the node.");
}
} while (true);
}
private long printGraphBuildStatus(int node, long start, long t) {
@ -221,7 +290,7 @@ public class HnswGraphBuilder {
return now;
}
private void addDiverseNeighbors(int level, int node, GraphBuilderKnnCollector candidates)
private void addDiverseNeighbors(int level, int node, NeighborArray candidates)
throws IOException {
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
@ -229,26 +298,40 @@ public class HnswGraphBuilder {
*/
NeighborArray neighbors = hnsw.getNeighbors(level, node);
assert neighbors.size() == 0; // new node
popToScratch(candidates);
int maxConnOnLevel = level == 0 ? M * 2 : M;
selectAndLinkDiverse(neighbors, scratch, maxConnOnLevel);
boolean[] mask = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel);
// Link the selected nodes to the new node, and the new node to the selected nodes (again
// applying diversity heuristic)
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
// NOTE: here we're using the local candidates and mask rather than the neighbour array because,
// once we have added an incoming link, this node can be discovered and its neighbour array
// modified concurrently. So using the local candidates and mask is a safer option.
for (int i = 0; i < candidates.size(); i++) {
if (mask[i] == false) {
continue;
}
int nbr = candidates.node[i];
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
if (nbrsOfNbr.size() > maxConnOnLevel) {
int indexToRemove = findWorstNonDiverse(nbrsOfNbr, nbr);
nbrsOfNbr.removeIndex(indexToRemove);
nbrsOfNbr.rwlock.writeLock().lock();
try {
nbrsOfNbr.addOutOfOrder(node, candidates.score[i]);
if (nbrsOfNbr.size() > maxConnOnLevel) {
int indexToRemove = findWorstNonDiverse(nbrsOfNbr, nbr);
nbrsOfNbr.removeIndex(indexToRemove);
}
} finally {
nbrsOfNbr.rwlock.writeLock().unlock();
}
}
}
private void selectAndLinkDiverse(
/**
* This method selects the neighbors to add and returns a mask telling the caller which candidates
* were selected
*/
private boolean[] selectAndLinkDiverse(
NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException {
boolean[] mask = new boolean[candidates.size()];
// Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic
for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
// compare each neighbor (in distance order) against the closer neighbors selected so far,
@ -257,12 +340,16 @@ public class HnswGraphBuilder {
float cScore = candidates.score[i];
assert cNode <= hnsw.maxNodeId();
if (diversityCheck(cNode, cScore, neighbors)) {
mask[i] = true;
// here we don't need to lock, because there's no incoming link yet, so no other thread can
// discover this node or modify this neighbor array
neighbors.addInOrder(cNode, cScore);
}
}
return mask;
}
private void popToScratch(GraphBuilderKnnCollector candidates) {
private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborArray scratch) {
scratch.clear();
int candidateCount = candidates.size();
// extract all the Neighbors from the queue into an array; these will now be

View File

@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.InfoStream;
/**
* Abstraction of merging multiple graphs into one on-heap graph
*
* @lucene.experimental
*/
public interface HnswGraphMerger {
/**
* Adds a reader to the graph merger to record the state
*
* @param reader KnnVectorsReader to add to the merger
* @param docMap MergeState.DocMap for the reader
* @param liveDocs Bits representing live docs, can be null
* @return this
* @throws IOException If an error occurs while reading from the merge state
*/
HnswGraphMerger addReader(KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs)
throws IOException;
/**
* Merge and produce the on heap graph
*
* @param mergedVectorIterator iterator over the vectors in the merged segment
* @param infoStream optional info stream to set to builder
* @param maxOrd max number of vectors that will be added to the graph
* @return merged graph
* @throws IOException during merge
*/
OnHeapHnswGraph merge(DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd)
throws IOException;
}

View File

@ -32,6 +32,7 @@ import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.CollectionUtil;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
/**
* This selects the biggest Hnsw graph from the provided merge state and initializes a new
@ -39,15 +40,16 @@ import org.apache.lucene.util.FixedBitSet;
*
* @lucene.experimental
*/
public class IncrementalHnswGraphMerger {
public class IncrementalHnswGraphMerger implements HnswGraphMerger {
private KnnVectorsReader initReader;
private MergeState.DocMap initDocMap;
private int initGraphSize;
private final FieldInfo fieldInfo;
private final RandomVectorScorerSupplier scorerSupplier;
private final int M;
private final int beamWidth;
protected final FieldInfo fieldInfo;
protected final RandomVectorScorerSupplier scorerSupplier;
protected final int M;
protected final int beamWidth;
protected KnnVectorsReader initReader;
protected MergeState.DocMap initDocMap;
protected int initGraphSize;
/**
* @param fieldInfo FieldInfo for the field being merged
@ -64,13 +66,8 @@ public class IncrementalHnswGraphMerger {
* Adds a reader to the graph merger if it meets the following criteria: 1. Does not contain any
* deleted docs 2. Is a HnswGraphProvider/PerFieldKnnVectorReader 3. Has the most docs of any
* previous reader that met the above criteria
*
* @param reader KnnVectorsReader to add to the merger
* @param docMap MergeState.DocMap for the reader
* @param liveDocs Bits representing live docs, can be null
* @return this
* @throws IOException If an error occurs while reading from the merge state
*/
@Override
public IncrementalHnswGraphMerger addReader(
KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs) throws IOException {
KnnVectorsReader currKnnVectorsReader = reader;
@ -113,18 +110,20 @@ public class IncrementalHnswGraphMerger {
* If no valid readers were added to the merge state, a new graph is created.
*
* @param mergedVectorIterator iterator over the vectors in the merged segment
* @param maxOrd max number of vectors that will be merged into the graph
* @return HnswGraphBuilder
* @throws IOException If an error occurs while reading from the merge state
*/
public HnswGraphBuilder createBuilder(DocIdSetIterator mergedVectorIterator) throws IOException {
protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
throws IOException {
if (initReader == null) {
return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
return HnswGraphBuilder.create(
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, maxOrd);
}
HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
final int numVectors = Math.toIntExact(mergedVectorIterator.cost());
BitSet initializedNodes = new FixedBitSet(numVectors + 1);
BitSet initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
return InitializedHnswGraphBuilder.fromGraph(
scorerSupplier,
@ -134,7 +133,15 @@ public class IncrementalHnswGraphMerger {
initializerGraph,
oldToNewOrdinalMap,
initializedNodes,
numVectors);
maxOrd);
}
@Override
public OnHeapHnswGraph merge(
DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd) throws IOException {
HnswBuilder builder = createBuilder(mergedVectorIterator, maxOrd);
builder.setInfoStream(infoStream);
return builder.build(maxOrd);
}
/**
@ -146,8 +153,8 @@ public class IncrementalHnswGraphMerger {
* @return the mapping from old ordinals to new ordinals
* @throws IOException If an error occurs while reading from the merge state
*/
private int[] getNewOrdMapping(DocIdSetIterator mergedVectorIterator, BitSet initializedNodes)
throws IOException {
protected final int[] getNewOrdMapping(
DocIdSetIterator mergedVectorIterator, BitSet initializedNodes) throws IOException {
DocIdSetIterator initializerIterator = null;
switch (fieldInfo.getVectorEncoding()) {

View File

@ -55,6 +55,18 @@ public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {
BitSet initializedNodes,
int totalNumberOfVectors)
throws IOException {
return new InitializedHnswGraphBuilder(
scorerSupplier,
M,
beamWidth,
seed,
initGraph(M, initializerGraph, newOrdMap, totalNumberOfVectors),
initializedNodes);
}
public static OnHeapHnswGraph initGraph(
int M, HnswGraph initializerGraph, int[] newOrdMap, int totalNumberOfVectors)
throws IOException {
OnHeapHnswGraph hnsw = new OnHeapHnswGraph(M, totalNumberOfVectors);
for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
@ -62,6 +74,7 @@ public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {
int oldOrd = it.nextInt();
int newOrd = newOrdMap[oldOrd];
hnsw.addNode(level, newOrd);
hnsw.trySetNewEntryNode(newOrd, level);
NeighborArray newNeighbors = hnsw.getNeighbors(level, newOrd);
initializerGraph.seek(level, oldOrd);
for (int oldNeighbor = initializerGraph.nextNeighbor();
@ -73,8 +86,7 @@ public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {
}
}
}
return new InitializedHnswGraphBuilder(
scorerSupplier, M, beamWidth, seed, hnsw, initializedNodes);
return hnsw;
}
private final BitSet initializedNodes;

View File

@ -19,6 +19,8 @@ package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.lucene.util.ArrayUtil;
/**
@ -35,6 +37,7 @@ public class NeighborArray {
float[] score;
int[] node;
private int sortedNodeSize;
public final ReadWriteLock rwlock = new ReentrantReadWriteLock(true);
public NeighborArray(int maxSize, boolean descOrder) {
node = new int[maxSize];

View File

@ -21,6 +21,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.RamUsageEstimator;
@ -33,8 +35,7 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
private static final int INIT_SIZE = 128;
private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level. -1 if not set
private final AtomicReference<EntryNode> entryNode;
// the internal graph representation where the first dimension is node id and second dimension is
// level
@ -47,11 +48,13 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
private int
lastFreezeSize; // remember the size we are at last time to freeze the graph and generate
// levelToNodes
private int size; // graph size, which is number of nodes in level 0
private int
nonZeroLevelSize; // total number of NeighborArrays created that is not on level 0, for now it
private final AtomicInteger size =
new AtomicInteger(0); // graph size, which is number of nodes in level 0
private final AtomicInteger nonZeroLevelSize =
new AtomicInteger(
0); // total number of NeighborArrays created that are not on level 0; for now it
// is only used to account for memory usage
private int maxNodeId;
private final AtomicInteger maxNodeId = new AtomicInteger(-1);
private final int nsize; // neighbour array size at non-zero level
private final int nsize0; // neighbour array size at zero level
private final boolean
@ -69,11 +72,9 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
* growing itself (you cannot add a node with has id >= numNodes)
*/
OnHeapHnswGraph(int M, int numNodes) {
this.numLevels = 1; // Implicitly start the graph with a single level
this.entryNode = -1; // Entry node should be negative until a node is added
this.entryNode = new AtomicReference<>(new EntryNode(-1, 1));
// Neighbours' size on upper levels (nsize) and level 0 (nsize0)
// We allocate extra space for neighbours, but then prune them to keep allowed maximum
this.maxNodeId = -1;
this.nsize = M + 1;
this.nsize0 = (M * 2 + 1);
noGrowth = numNodes != -1;
@ -96,7 +97,7 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
@Override
public int size() {
return size;
return size.get();
}
/**
@ -107,7 +108,16 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
*/
@Override
public int maxNodeId() {
return maxNodeId;
if (noGrowth) {
// we know the eventual graph size and the graph can possibly
// be concurrently modified
return graph.length - 1;
} else {
// The graph cannot be concurrently modified (and searched) if
// we don't know the size beforehand, so it's safe to return the
// actual maxNodeId
return maxNodeId.get();
}
}
/**
@ -120,9 +130,6 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
* @param node the node to add, represented as an ordinal on the level 0.
*/
public void addNode(int level, int node) {
if (entryNode == -1) {
entryNode = node;
}
if (node >= graph.length) {
if (noGrowth) {
@ -132,25 +139,20 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
graph = ArrayUtil.grow(graph, node + 1);
}
if (level >= numLevels) {
numLevels = level + 1;
entryNode = node;
}
assert graph[node] == null || graph[node].length > level
: "node must be inserted from the top level";
if (graph[node] == null) {
graph[node] =
new NeighborArray[level + 1]; // assumption: we always call this function from top level
size++;
size.incrementAndGet();
}
if (level == 0) {
graph[node][level] = new NeighborArray(nsize0, true);
} else {
graph[node][level] = new NeighborArray(nsize, true);
nonZeroLevelSize++;
nonZeroLevelSize.incrementAndGet();
}
maxNodeId = Math.max(maxNodeId, node);
maxNodeId.accumulateAndGet(node, Math::max);
}
@Override
@ -174,7 +176,7 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
*/
@Override
public int numLevels() {
return numLevels;
return entryNode.get().level + 1;
}
/**
@ -185,7 +187,41 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
*/
@Override
public int entryNode() {
return entryNode;
return entryNode.get().node;
}
/**
* Try to set the entry node if the graph does not have one
*
* @return True if the entry node is set to the provided node. False if the entry node already
* exists
*/
public boolean trySetNewEntryNode(int node, int level) {
EntryNode current = entryNode.get();
if (current.node == -1) {
return entryNode.compareAndSet(current, new EntryNode(node, level));
}
return false;
}
/**
* Try to promote the provided node to the entry node
*
* @param level should be larger than expectOldLevel
* @param expectOldLevel the entry node level the caller expects the graph to be at; the actual
*     graph level can be different due to concurrent modification
* @return True if the entry node is set to the provided node. False if expectOldLevel is not the
* same as the current entry node level. Even if the provided node's level is still higher
* than the current entry node level, the new entry node will not be set and false will be
* returned.
*/
public boolean tryPromoteNewEntryNode(int node, int level, int expectOldLevel) {
assert level > expectOldLevel;
EntryNode currentEntry = entryNode.get();
if (currentEntry.level == expectOldLevel) {
return entryNode.compareAndSet(currentEntry, new EntryNode(node, level));
}
return false;
}
/**
@ -212,12 +248,12 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
@SuppressWarnings({"unchecked", "rawtypes"})
private void generateLevelToNodes() {
if (lastFreezeSize == size) {
if (lastFreezeSize == size()) {
return;
}
levelToNodes = new List[numLevels];
for (int i = 1; i < numLevels; i++) {
int maxLevels = numLevels();
levelToNodes = new List[maxLevels];
for (int i = 1; i < maxLevels; i++) {
levelToNodes[i] = new ArrayList<>();
}
int nonNullNode = 0;
@ -230,38 +266,44 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
for (int i = 1; i < graph[node].length; i++) {
levelToNodes[i].add(node);
}
if (nonNullNode == size) {
if (nonNullNode == size()) {
break;
}
}
lastFreezeSize = size;
lastFreezeSize = size();
}
@Override
public long ramBytesUsed() {
long neighborArrayBytes0 =
(long) nsize0 * (Integer.BYTES + Float.BYTES)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2L
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
+ Integer.BYTES * 3;
long neighborArrayBytes =
(long) nsize * (Integer.BYTES + Float.BYTES)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2L
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
+ Integer.BYTES * 3;
long total = 0;
total +=
size * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
size() * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // for graph and level 0;
total += nonZeroLevelSize * neighborArrayBytes; // for non-zero level
total += 8 * Integer.BYTES; // all int fields
total += nonZeroLevelSize.get() * neighborArrayBytes; // for non-zero level
total += 4 * Integer.BYTES; // all int fields
total += 1; // field: noGrowth
total +=
RamUsageEstimator.NUM_BYTES_OBJECT_REF
+ RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ 2 * Integer.BYTES; // field: entryNode
total += 3L * (Integer.BYTES + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER); // 3 AtomicInteger
total += RamUsageEstimator.NUM_BYTES_OBJECT_REF; // field: cur
total += RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // field: levelToNodes
if (levelToNodes != null) {
total +=
(long) (numLevels - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0
(long) (numLevels() - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0
total +=
(long) nonZeroLevelSize
(long) nonZeroLevelSize.get()
* (RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ Integer.BYTES);
@ -274,9 +316,11 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
return "OnHeapHnswGraph(size="
+ size()
+ ", numLevels="
+ numLevels
+ numLevels()
+ ", entryNode="
+ entryNode
+ entryNode()
+ ")";
}
private record EntryNode(int node, int level) {}
}

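A hedged sketch of the retry pattern the two new CAS methods enable (the class and method names below are illustrative; the real builder also re-adds the node on any newly introduced levels before retrying, which is omitted here):

import org.apache.lucene.util.hnsw.OnHeapHnswGraph;

final class EntryNodePromotionSketch {
  static void promoteIfTaller(OnHeapHnswGraph graph, int node, int nodeLevel) {
    if (graph.trySetNewEntryNode(node, nodeLevel)) {
      return; // the graph had no entry node yet
    }
    while (true) {
      int curMaxLevel = graph.numLevels() - 1;
      if (nodeLevel <= curMaxLevel) {
        return; // another thread already raised the graph to at least this level
      }
      if (graph.tryPromoteNewEntryNode(node, nodeLevel, curMaxLevel)) {
        return; // CAS succeeded: this node is now the entry node
      }
      // the entry node (and graph level) changed concurrently; re-read and retry
    }
  }
}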
View File

@ -32,12 +32,14 @@ public interface RandomVectorScorerSupplier {
RandomVectorScorer scorer(int ord) throws IOException;
/**
* Creates a {@link RandomVectorScorerSupplier} to compare float vectors.
*
* <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
* using it after calling this function. If you plan to use it again outside the returned {@link
* RandomVectorScorer}, think about passing a copied version ({@link
* RandomAccessVectorValues#copy}).
* Make a copy of the supplier, which will copy the underlying vectorValues so the copy is safe to
* be used in other threads.
*/
RandomVectorScorerSupplier copy() throws IOException;
/**
* Creates a {@link RandomVectorScorerSupplier} to compare float vectors. The vectorValues passed
* in will be copied and the original copy will not be used.
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to score vectors
@ -48,21 +50,12 @@ public interface RandomVectorScorerSupplier {
throws IOException {
// We copy the provided random accessor just once during the supplier's initialization
// and then reuse it consistently across all scorers for conducting vector comparisons.
final RandomAccessVectorValues<float[]> vectorsCopy = vectors.copy();
return queryOrd ->
(RandomVectorScorer)
cand ->
similarityFunction.compare(
vectors.vectorValue(queryOrd), vectorsCopy.vectorValue(cand));
return new FloatScoringSupplier(vectors, similarityFunction);
}
/**
* Creates a {@link RandomVectorScorerSupplier} to compare byte vectors.
*
* <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
* using it after calling this function. If you plan to use it again outside the returned {@link
* RandomVectorScorer}, think about passing a copied version ({@link
* RandomAccessVectorValues#copy}).
* Creates a {@link RandomVectorScorerSupplier} to compare byte vectors. The vectorValues passed
* in will be copied and the original copy will not be used.
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to score vectors
@ -71,13 +64,64 @@ public interface RandomVectorScorerSupplier {
final RandomAccessVectorValues<byte[]> vectors,
final VectorSimilarityFunction similarityFunction)
throws IOException {
// We copy the provided random accessor just once during the supplier's initialization
// We copy the provided random accessor only during the supplier's initialization
// and then reuse it consistently across all scorers for conducting vector comparisons.
final RandomAccessVectorValues<byte[]> vectorsCopy = vectors.copy();
return queryOrd ->
(RandomVectorScorer)
cand ->
similarityFunction.compare(
vectors.vectorValue(queryOrd), vectorsCopy.vectorValue(cand));
return new ByteScoringSupplier(vectors, similarityFunction);
}
/** RandomVectorScorerSupplier for byte vectors */
final class ByteScoringSupplier implements RandomVectorScorerSupplier {
private final RandomAccessVectorValues<byte[]> vectors;
private final RandomAccessVectorValues<byte[]> vectors1;
private final RandomAccessVectorValues<byte[]> vectors2;
private final VectorSimilarityFunction similarityFunction;
private ByteScoringSupplier(
RandomAccessVectorValues<byte[]> vectors, VectorSimilarityFunction similarityFunction)
throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
this.similarityFunction = similarityFunction;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return cand ->
similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new ByteScoringSupplier(vectors, similarityFunction);
}
}
/** RandomVectorScorerSupplier for float vectors */
final class FloatScoringSupplier implements RandomVectorScorerSupplier {
private final RandomAccessVectorValues<float[]> vectors;
private final RandomAccessVectorValues<float[]> vectors1;
private final RandomAccessVectorValues<float[]> vectors2;
private final VectorSimilarityFunction similarityFunction;
private FloatScoringSupplier(
RandomAccessVectorValues<float[]> vectors, VectorSimilarityFunction similarityFunction)
throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
this.similarityFunction = similarityFunction;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return cand ->
similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new FloatScoringSupplier(vectors, similarityFunction);
}
}
}

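A small sketch of the copy() contract as the concurrent merger uses it: each worker thread takes its own copy once, so the stateful vector readers underneath are never shared across threads (class and method names are illustrative, not part of the patch):

import java.io.IOException;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

final class PerWorkerScoring {
  private final RandomVectorScorerSupplier mySupplier;

  // Construct once per worker thread, before any scoring happens on that thread.
  PerWorkerScoring(RandomVectorScorerSupplier shared) throws IOException {
    this.mySupplier = shared.copy(); // thread-private copy of the underlying vector values
  }

  float score(int queryOrd, int candidateOrd) throws IOException {
    RandomVectorScorer scorer = mySupplier.scorer(queryOrd);
    return scorer.score(candidateOrd);
  }
}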
View File

@ -565,6 +565,14 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
// we cannot call getNodesOnLevel before the graph reaches the size it claimed, so here we
// create another graph to do the assertion
OnHeapHnswGraph graphAfterInit =
InitializedHnswGraphBuilder.initGraph(
10, initializerGraph, initializerOrdMap, initializerGraph.size());
HnswGraphBuilder finalBuilder =
InitializedHnswGraphBuilder.fromGraph(
finalscorerSupplier,
@ -578,7 +586,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
totalSize);
// When offset is 0, the graphs should be identical before vectors are added
assertGraphEqual(initializerGraph, finalBuilder.getGraph());
assertGraphEqual(initializerGraph, graphAfterInit);
OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.size());
assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
@ -989,6 +997,33 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
}
/*
* A very basic test to ensure the concurrent merge does not throw exceptions; it by no means
* guarantees the true correctness of the concurrent merge, which must be checked manually by
* running a KNN benchmark and comparing the recall
*/
public void testConcurrentMergeBuilder() throws IOException {
int size = atLeast(1000);
int dim = atLeast(10);
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge"));
HnswGraphBuilder.randSeed = random().nextLong();
HnswConcurrentMergeBuilder builder =
new HnswConcurrentMergeBuilder(
exec, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null);
builder.setBatchSize(100);
builder.build(size);
exec.shutdownNow();
OnHeapHnswGraph graph = builder.getGraph();
assertTrue(graph.entryNode() != -1);
assertEquals(size, graph.size());
assertEquals(size - 1, graph.maxNodeId());
for (int l = 0; l < graph.numLevels(); l++) {
assertNotNull(graph.getNodesOnLevel(l));
}
}
private int computeOverlap(int[] a, int[] b) {
Arrays.sort(a);
Arrays.sort(b);

View File

@ -43,7 +43,7 @@ public class TestOnHeapHnswGraph extends LuceneTestCase {
/* assert exception will be thrown when we call getNodeOnLevel for an incomplete graph */
public void testIncompleteGraphThrow() {
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 10);
OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
graph.addNode(1, 0);
graph.addNode(0, 0);
assertEquals(1, graph.getNodesOnLevel(1).size());
@ -62,6 +62,10 @@ public class TestOnHeapHnswGraph extends LuceneTestCase {
int level = random().nextInt(maxLevel);
for (int l = level; l >= 0; l--) {
graph.addNode(l, i);
graph.trySetNewEntryNode(i, l);
if (l > graph.numLevels() - 1) {
graph.tryPromoteNewEntryNode(i, l, graph.numLevels() - 1);
}
levelToNodes.get(l).add(i);
}
}
@ -93,6 +97,10 @@ public class TestOnHeapHnswGraph extends LuceneTestCase {
int level = random().nextInt(maxLevel);
for (int l = level; l >= 0; l--) {
graph.addNode(l, i);
graph.trySetNewEntryNode(i, l);
if (l > graph.numLevels() - 1) {
graph.tryPromoteNewEntryNode(i, l, graph.numLevels() - 1);
}
levelToNodes.get(l).add(i);
}
}

View File

@ -39,6 +39,8 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Predicate;
import java.util.function.ToLongFunction;
import java.util.stream.Collectors;
@ -159,7 +161,11 @@ public final class RamUsageTester {
// Ignore JDK objects we can't access or handle properly.
Predicate<Object> isIgnorable =
(clazz) -> (clazz instanceof CharsetEncoder) || (clazz instanceof CharsetDecoder);
(clazz) ->
(clazz instanceof CharsetEncoder)
|| (clazz instanceof CharsetDecoder)
|| (clazz instanceof ReentrantReadWriteLock)
|| (clazz instanceof AtomicReference<?>);
if (isIgnorable.test(ob)) {
return accumulator.accumulateObject(ob, 0, Collections.emptyMap(), stack);
}