mirror of https://github.com/apache/lucene.git

Concurrent HNSW Merge (#12660)
parent f5776c8844
commit a8c52e2e19
@@ -169,6 +169,9 @@ New Features

* GITHUB#12582: Add int8 scalar quantization to the HNSW vector format. This optionally allows for more compact lossy
  storage for the vectors, requiring about 75% less memory for fast HNSW search. (Ben Trent)

* GITHUB#12660: HNSW graphs can now be merged with multiple threads. Configurable in Lucene99HnswVectorsFormat.
  (Patrick Zhai)

Improvements
---------------------
@@ -476,9 +476,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
case FLOAT32 -> mergedVectorIterator =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}
HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
graph = hnswGraphBuilder.build(docsWithField.cardinality());
graph =
merger.merge(
mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality());
vectorIndexNodeOffsets = writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@@ -18,6 +18,7 @@
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;

@@ -151,6 +152,9 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
*/
public static final int DEFAULT_BEAM_WIDTH = 100;

/** Default to use single thread merge */
public static final int DEFAULT_NUM_MERGE_WORKER = 1;

static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;

/**

@@ -169,20 +173,36 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
/** Should this codec scalar quantize float32 vectors and use this format */
private final Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat;

private final int numMergeWorkers;
private final ExecutorService mergeExec;

/** Constructs a format using default graph construction parameters */
public Lucene99HnswVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null);
}

public Lucene99HnswVectorsFormat(
int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
this(maxConn, beamWidth, scalarQuantize, DEFAULT_NUM_MERGE_WORKER, null);
}

/**
* Constructs a format using the given graph construction parameters.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
* @param scalarQuantize the scalar quantization format
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
* generated by this format to do the merge
*/
public Lucene99HnswVectorsFormat(
int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
int maxConn,
int beamWidth,
Lucene99ScalarQuantizedVectorsFormat scalarQuantize,
int numMergeWorkers,
ExecutorService mergeExec) {
super("Lucene99HnswVectorsFormat");
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(

@@ -198,14 +218,25 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
+ "; beamWidth="
+ beamWidth);
}
if (numMergeWorkers > 1 && mergeExec == null) {
throw new IllegalArgumentException(
"No executor service passed in when " + numMergeWorkers + " merge workers are requested");
}
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException(
"No executor service is needed as we'll use single thread to merge");
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.scalarQuantizedVectorsFormat = scalarQuantize;
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
}

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, scalarQuantizedVectorsFormat);
return new Lucene99HnswVectorsWriter(
state, maxConn, beamWidth, scalarQuantizedVectorsFormat, numMergeWorkers, mergeExec);
}

@Override
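A usage sketch (not part of the commit) of wiring the new five-argument constructor shown above. The variable names and the pool size 4 are illustrative assumptions; per the validation above, a non-null executor is required exactly when numMergeWorkers > 1.

// Fragment, illustrative only; assumes java.util.concurrent.ExecutorService/Executors are imported.
ExecutorService mergeExec = Executors.newFixedThreadPool(4);
KnnVectorsFormat format =
    new Lucene99HnswVectorsFormat(
        Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,   // maxConn
        Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, // beamWidth
        null,                                         // no scalar quantization
        4,                                            // numMergeWorkers
        mergeExec);                                   // required because numMergeWorkers > 1

// Single-threaded merge keeps the previous behaviour:
KnnVectorsFormat sequential =
    new Lucene99HnswVectorsFormat(
        Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
        Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
        null,
        Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER,
        null);

Note that the writer change later in this diff calls shutdownNow() on the executor in close(), so an executor handed to this format should be dedicated to it rather than shared with unrelated work.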
@@ -28,6 +28,7 @@ import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;

@@ -52,9 +53,11 @@ import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.ConcurrentHnswMerger;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphMerger;
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;

@@ -75,6 +78,8 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
private final int M;
private final int beamWidth;
private final Lucene99ScalarQuantizedVectorsWriter quantizedVectorsWriter;
private final int numMergeWorkers;
private final ExecutorService mergeExec;

private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;

@@ -83,10 +88,14 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
SegmentWriteState state,
int M,
int beamWidth,
Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat)
Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat,
int numMergeWorkers,
ExecutorService mergeExec)
throws IOException {
this.M = M;
this.beamWidth = beamWidth;
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(

@@ -383,7 +392,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
int node = nodesOnLevel0.nextInt();
NeighborArray neighbors = graph.getNeighbors(0, newToOldMap[node]);
long offset = vectorIndex.getFilePointer();
reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxOrd);
reconstructAndWriteNeighbours(neighbors, oldToNewMap, maxOrd);
levelNodeOffsets[0][node] = Math.toIntExact(vectorIndex.getFilePointer() - offset);
}

@@ -400,7 +409,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
for (int node : newNodes) {
NeighborArray neighbors = graph.getNeighbors(level, newToOldMap[node]);
long offset = vectorIndex.getFilePointer();
reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxOrd);
reconstructAndWriteNeighbours(neighbors, oldToNewMap, maxOrd);
levelNodeOffsets[level][nodeOffsetIndex++] =
Math.toIntExact(vectorIndex.getFilePointer() - offset);
}

@@ -442,7 +451,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
};
}

private void reconstructAndWriteNeigbours(NeighborArray neighbors, int[] oldToNewMap, int maxOrd)
private void reconstructAndWriteNeighbours(NeighborArray neighbors, int[] oldToNewMap, int maxOrd)
throws IOException {
int size = neighbors.size();
vectorIndex.writeVInt(size);

@@ -557,6 +566,12 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
IOUtils.close(finalVectorDataInput);
segmentWriteState.directory.deleteFile(tempFileName);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
// here we just return the inner out since we only need to close this outside copy
return innerScoreSupplier.copy();
}
};
} else {
// No need to use temporary file as we don't have to re-open for reading

@@ -579,8 +594,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
int[][] vectorIndexNodeOffsets = null;
if (docsWithField.cardinality() != 0) {
// build graph
IncrementalHnswGraphMerger merger =
new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
HnswGraphMerger merger = createGraphMerger(fieldInfo, scorerSupplier);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
merger.addReader(
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);

@@ -592,9 +606,9 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
case FLOAT32 -> mergedVectorIterator =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}
HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
graph = hnswGraphBuilder.build(docsWithField.cardinality());
graph =
merger.merge(
mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality());
vectorIndexNodeOffsets = writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;

@@ -675,6 +689,15 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
return sortedNodes;
}

private HnswGraphMerger createGraphMerger(
FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) {
if (mergeExec != null) {
return new ConcurrentHnswMerger(
fieldInfo, scorerSupplier, M, beamWidth, mergeExec, numMergeWorkers);
}
return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
}

private void writeMeta(
boolean isQuantized,
FieldInfo field,

@@ -819,6 +842,9 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
@Override
public void close() throws IOException {
IOUtils.close(meta, vectorData, vectorIndex, quantizedVectorData);
if (mergeExec != null) {
mergeExec.shutdownNow();
}
}

private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
@@ -52,6 +52,7 @@ import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

/**
* Writes quantized vector values and metadata to index segments.

@@ -761,6 +762,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
return supplier.scorer(ord);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return supplier.copy();
}

@Override
public void close() throws IOException {
onClose.close();
@@ -39,6 +39,12 @@ final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorSco
this.values = values;
}

private ScalarQuantizedRandomVectorScorerSupplier(
ScalarQuantizedVectorSimilarity similarity, RandomAccessQuantizedByteVectorValues values) {
this.similarity = similarity;
this.values = values;
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();

@@ -46,4 +52,9 @@ final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorSco
final float queryOffset = values.getScoreCorrectionConstant();
return new ScalarQuantizedRandomVectorScorer(similarity, vectorsCopy, queryVector, queryOffset);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new ScalarQuantizedRandomVectorScorerSupplier(similarity, values.copy());
}
}
@@ -22,6 +22,9 @@ import java.io.Closeable;
/**
* A supplier that creates {@link RandomVectorScorer} from an ordinal. Caller should be sure to
* close after use
*
* <p>NOTE: the {@link #copy()} returned {@link RandomVectorScorerSupplier} is not necessarily
* closeable
*/
public interface CloseableRandomVectorScorerSupplier
extends Closeable, RandomVectorScorerSupplier {}
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;

/** This merger merges graph in a concurrent manner, by using {@link HnswConcurrentMergeBuilder} */
public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {

private final ExecutorService exec;
private final int numWorker;

/**
* @param fieldInfo FieldInfo for the field being merged
*/
public ConcurrentHnswMerger(
FieldInfo fieldInfo,
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
ExecutorService exec,
int numWorker) {
super(fieldInfo, scorerSupplier, M, beamWidth);
this.exec = exec;
this.numWorker = numWorker;
}

@Override
protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
throws IOException {
if (initReader == null) {
return new HnswConcurrentMergeBuilder(
exec, numWorker, scorerSupplier, M, beamWidth, new OnHeapHnswGraph(M, maxOrd), null);
}

HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
BitSet initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);

return new HnswConcurrentMergeBuilder(
exec,
numWorker,
scorerSupplier,
M,
beamWidth,
InitializedHnswGraphBuilder.initGraph(M, initializerGraph, oldToNewOrdinalMap, maxOrd),
initializedNodes);
}
}
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.util.hnsw;

import java.io.IOException;
import org.apache.lucene.util.InfoStream;

/**
* Interface for builder building the {@link OnHeapHnswGraph}
*
* @lucene.experimental
*/
public interface HnswBuilder {

/**
* Adds all nodes to the graph up to the provided {@code maxOrd}.
*
* @param maxOrd The maximum ordinal (excluded) of the nodes to be added.
*/
OnHeapHnswGraph build(int maxOrd) throws IOException;

/** Inserts a doc with vector value to the graph */
void addGraphNode(int node) throws IOException;

/** Set info-stream to output debugging information */
void setInfoStream(InfoStream infoStream);

OnHeapHnswGraph getGraph();
}
@@ -0,0 +1,248 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.HNSW_COMPONENT;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.ThreadInterruptedException;

/**
* A graph builder that manages multiple workers, it only supports adding the whole graph all at
* once. It will spawn a thread for each worker and the workers will pick the work in batches.
*/
public class HnswConcurrentMergeBuilder implements HnswBuilder {

private static final int DEFAULT_BATCH_SIZE =
2048; // number of vectors the worker handles sequentially at one batch

private final ExecutorService exec;
private final ConcurrentMergeWorker[] workers;
private InfoStream infoStream = InfoStream.getDefault();

public HnswConcurrentMergeBuilder(
ExecutorService exec,
int numWorker,
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
OnHeapHnswGraph hnsw,
BitSet initializedNodes)
throws IOException {
this.exec = exec;
AtomicInteger workProgress = new AtomicInteger(0);
workers = new ConcurrentMergeWorker[numWorker];
for (int i = 0; i < numWorker; i++) {
workers[i] =
new ConcurrentMergeWorker(
scorerSupplier.copy(),
M,
beamWidth,
HnswGraphBuilder.randSeed,
hnsw,
initializedNodes,
workProgress);
}
}

@Override
public OnHeapHnswGraph build(int maxOrd) throws IOException {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(
HNSW_COMPONENT,
"build graph from " + maxOrd + " vectors, with " + workers.length + " workers");
}
List<Future<?>> futures = new ArrayList<>();
for (int i = 0; i < workers.length; i++) {
int finalI = i;
futures.add(
exec.submit(
() -> {
try {
workers[finalI].run(maxOrd);
} catch (IOException e) {
throw new RuntimeException(e);
}
}));
}
Throwable exc = null;
for (Future<?> future : futures) {
try {
future.get();
} catch (InterruptedException e) {
var newException = new ThreadInterruptedException(e);
if (exc == null) {
exc = newException;
} else {
exc.addSuppressed(newException);
}
} catch (ExecutionException e) {
if (exc == null) {
exc = e.getCause();
} else {
exc.addSuppressed(e.getCause());
}
}
}
if (exc != null) {
// The error handling was copied from TaskExecutor. should we just use TaskExecutor instead?
throw IOUtils.rethrowAlways(exc);
}
return workers[0].getGraph();
}

@Override
public void addGraphNode(int node) throws IOException {
throw new UnsupportedOperationException("This builder is for merge only");
}

@Override
public void setInfoStream(InfoStream infoStream) {
this.infoStream = infoStream;
for (HnswBuilder worker : workers) {
worker.setInfoStream(infoStream);
}
}

@Override
public OnHeapHnswGraph getGraph() {
return workers[0].getGraph();
}

/* test only for now */
void setBatchSize(int newSize) {
for (ConcurrentMergeWorker worker : workers) {
worker.batchSize = newSize;
}
}

private static final class ConcurrentMergeWorker extends HnswGraphBuilder {

/**
* A common AtomicInteger shared among all workers, used for tracking what's the next vector to
* be added to the graph.
*/
private final AtomicInteger workProgress;

private final BitSet initializedNodes;
private int batchSize = DEFAULT_BATCH_SIZE;

private ConcurrentMergeWorker(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw,
BitSet initializedNodes,
AtomicInteger workProgress)
throws IOException {
super(
scorerSupplier,
M,
beamWidth,
seed,
hnsw,
new MergeSearcher(
new NeighborQueue(beamWidth, true), new FixedBitSet(hnsw.maxNodeId() + 1)));
this.workProgress = workProgress;
this.initializedNodes = initializedNodes;
}

/**
* This method first try to "reserve" part of work by calling {@link #getStartPos(int)} and then
* calling {@link #addVectors(int, int)} to actually add the nodes to the graph. By doing this
* we are able to dynamically allocate the work to multiple workers and try to make all of them
* finishing around the same time.
*/
private void run(int maxOrd) throws IOException {
int start = getStartPos(maxOrd);
int end;
while (start != -1) {
end = Math.min(maxOrd, start + batchSize);
addVectors(start, end);
start = getStartPos(maxOrd);
}
}

/** Reserve the work by atomically increment the {@link #workProgress} */
private int getStartPos(int maxOrd) {
int start = workProgress.getAndAdd(batchSize);
if (start < maxOrd) {
return start;
} else {
return -1;
}
}

@Override
public void addGraphNode(int node) throws IOException {
if (initializedNodes != null && initializedNodes.get(node)) {
return;
}
super.addGraphNode(node);
}
}

/**
* This searcher will obtain the lock and make a copy of neighborArray when seeking the graph such
* that concurrent modification of the graph will not impact the search
*/
private static class MergeSearcher extends HnswGraphSearcher {
private int[] nodeBuffer;
private int upto;
private int size;

private MergeSearcher(NeighborQueue candidates, BitSet visited) {
super(candidates, visited);
}

@Override
void graphSeek(HnswGraph graph, int level, int targetNode) {
NeighborArray neighborArray = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode);
neighborArray.rwlock.readLock().lock();
try {
if (nodeBuffer == null || nodeBuffer.length < neighborArray.size()) {
nodeBuffer = new int[neighborArray.size()];
}
size = neighborArray.size();
if (size >= 0) System.arraycopy(neighborArray.node, 0, nodeBuffer, 0, size);
} finally {
neighborArray.rwlock.readLock().unlock();
}
upto = -1;
}

@Override
int graphNextNeighbor(HnswGraph graph) {
if (++upto < size) {
return nodeBuffer[upto];
}
return NO_MORE_DOCS;
}
}
}
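A minimal, self-contained sketch (not from the commit) of the work-reservation pattern the ConcurrentMergeWorker above uses: every worker repeatedly claims the next batch of ordinals from a shared AtomicInteger and stops once the claimed start passes maxOrd, so no two workers ever process the same node. Class and variable names here are illustrative only.

import java.util.concurrent.atomic.AtomicInteger;

class BatchClaimDemo {
  static final int BATCH_SIZE = 2048; // same default batch size as above

  public static void main(String[] args) throws InterruptedException {
    int maxOrd = 10_000; // total number of vectors to add
    AtomicInteger workProgress = new AtomicInteger(0); // shared across all workers
    Runnable worker =
        () -> {
          int start;
          // getAndAdd reserves the range [start, start + BATCH_SIZE) atomically
          while ((start = workProgress.getAndAdd(BATCH_SIZE)) < maxOrd) {
            int end = Math.min(maxOrd, start + BATCH_SIZE);
            for (int node = start; node < end; node++) {
              // addGraphNode(node) would go here
            }
          }
        };
    Thread t1 = new Thread(worker);
    Thread t2 = new Thread(worker);
    t1.start();
    t2.start();
    t1.join();
    t2.join();
  }
}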
@@ -33,7 +33,7 @@ import org.apache.lucene.util.InfoStream;
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
* hyper-parameters.
*/
public class HnswGraphBuilder {
public class HnswGraphBuilder implements HnswBuilder {

/** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16;

@@ -54,7 +54,6 @@ public class HnswGraphBuilder {

private final int M; // max number of connections on upper layers
private final double ml;
private final NeighborArray scratch;

private final SplittableRandom random;
private final RandomVectorScorerSupplier scorerSupplier;

@@ -97,6 +96,22 @@ public class HnswGraphBuilder {
this(scorerSupplier, M, beamWidth, seed, new OnHeapHnswGraph(M, graphSize));
}

protected HnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw)
throws IOException {
this(
scorerSupplier,
M,
beamWidth,
seed,
hnsw,
new HnswGraphSearcher(new NeighborQueue(beamWidth, true), new FixedBitSet(hnsw.size())));
}

/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
* ordinals, using the given hyperparameter settings, and returns the resulting graph.

@@ -114,7 +129,8 @@ public class HnswGraphBuilder {
int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw)
OnHeapHnswGraph hnsw,
HnswGraphSearcher graphSearcher)
throws IOException {
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");

@@ -129,20 +145,12 @@ public class HnswGraphBuilder {
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
this.hnsw = hnsw;
this.graphSearcher =
new HnswGraphSearcher(
new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
// in scratch we store candidates in reverse order: worse candidates are first
scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
this.graphSearcher = graphSearcher;
entryCandidates = new GraphBuilderKnnCollector(1);
beamCandidates = new GraphBuilderKnnCollector(beamWidth);
}

/**
* Adds all nodes to the graph up to the provided {@code maxOrd}.
*
* @param maxOrd The maximum ordinal of the nodes to be added.
*/
@Override
public OnHeapHnswGraph build(int maxOrd) throws IOException {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors");

@@ -151,18 +159,23 @@ public class HnswGraphBuilder {
return hnsw;
}

/** Set info-stream to output debugging information * */
@Override
public void setInfoStream(InfoStream infoStream) {
this.infoStream = infoStream;
}

@Override
public OnHeapHnswGraph getGraph() {
return hnsw;
}

private void addVectors(int maxOrd) throws IOException {
/** add vectors in range [minOrd, maxOrd) */
protected void addVectors(int minOrd, int maxOrd) throws IOException {
long start = System.nanoTime(), t = start;
for (int node = 0; node < maxOrd; node++) {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")");
}
for (int node = minOrd; node < maxOrd; node++) {
addGraphNode(node);
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
t = printGraphBuildStatus(node, start, t);

@@ -170,42 +183,98 @@ public class HnswGraphBuilder {
}
}

/** Inserts a doc with vector value to the graph */
private void addVectors(int maxOrd) throws IOException {
addVectors(0, maxOrd);
}

@Override
public void addGraphNode(int node) throws IOException {
/*
Note: this implementation is thread safe when graph size is fixed (e.g. when merging)
The process of adding a node is roughly:
1. Add the node to all level from top to the bottom, but do not connect it to any other node,
nor try to promote itself to an entry node before the connection is done. (Unless the graph is empty
and this is the first node, in that case we set the entry node and return)
2. Do the search from top to bottom, remember all the possible neighbours on each level the node
is on.
3. Add the neighbor to the node from bottom to top level, when adding the neighbour,
we always add all the outgoing links first before adding incoming link such that
when a search visits this node, it can always find a way out
4. If the node has level that is less or equal to graph level, then we're done here.
If the node has level larger than graph level, then we need to promote the node
as the entry node. If, while we add the node to the graph, the entry node has changed
(which means the graph level has changed as well), we need to reinsert the node
to the newly introduced levels (repeating step 2,3 for new levels) and again try to
promote the node to entry node.
*/
RandomVectorScorer scorer = scorerSupplier.scorer(node);
final int nodeLevel = getRandomGraphLevel(ml, random);
int curMaxLevel = hnsw.numLevels() - 1;

// If entrynode is -1, then this should finish without adding neighbors
if (hnsw.entryNode() == -1) {
for (int level = nodeLevel; level >= 0; level--) {
hnsw.addNode(level, node);
}
// first add nodes to all levels
for (int level = nodeLevel; level >= 0; level--) {
hnsw.addNode(level, node);
}
// then promote itself as entry node if entry node is not set
if (hnsw.trySetNewEntryNode(node, nodeLevel)) {
return;
}
int[] eps = new int[] {hnsw.entryNode()};
// if the entry node is already set, then we have to do all connections first before we can
// promote ourselves as entry node

// if a node introduces new levels to the graph, add this new node on new levels
for (int level = nodeLevel; level > curMaxLevel; level--) {
hnsw.addNode(level, node);
}
int lowestUnsetLevel = 0;
int curMaxLevel;
do {
curMaxLevel = hnsw.numLevels() - 1;
// NOTE: the entry node and max level may not be paired, but because we get the level first
// we ensure that the entry node we get later will always exist on the curMaxLevel
int[] eps = new int[] {hnsw.entryNode()};

// for levels > nodeLevel search with topk = 1
GraphBuilderKnnCollector candidates = entryCandidates;
for (int level = curMaxLevel; level > nodeLevel; level--) {
candidates.clear();
graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
eps = new int[] {candidates.popNode()};
}
// for levels <= nodeLevel search with topk = beamWidth, and add connections
candidates = beamCandidates;
for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
candidates.clear();
graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
eps = candidates.popUntilNearestKNodes();
hnsw.addNode(level, node);
addDiverseNeighbors(level, node, candidates);
}
// we first do the search from top to bottom
// for levels > nodeLevel search with topk = 1
GraphBuilderKnnCollector candidates = entryCandidates;
for (int level = curMaxLevel; level > nodeLevel; level--) {
candidates.clear();
graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
eps[0] = candidates.popNode();
}

// for levels <= nodeLevel search with topk = beamWidth, and add connections
candidates = beamCandidates;
NeighborArray[] scratchPerLevel =
new NeighborArray[Math.min(nodeLevel, curMaxLevel) - lowestUnsetLevel + 1];
for (int i = scratchPerLevel.length - 1; i >= 0; i--) {
int level = i + lowestUnsetLevel;
candidates.clear();
graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
eps = candidates.popUntilNearestKNodes();
scratchPerLevel[i] = new NeighborArray(Math.max(beamCandidates.k(), M + 1), false);
popToScratch(candidates, scratchPerLevel[i]);
}

// then do connections from bottom up
for (int i = 0; i < scratchPerLevel.length; i++) {
addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]);
}
lowestUnsetLevel += scratchPerLevel.length;
assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1;
if (lowestUnsetLevel > nodeLevel) {
return;
}
assert lowestUnsetLevel == curMaxLevel + 1 && nodeLevel > curMaxLevel;
if (hnsw.tryPromoteNewEntryNode(node, nodeLevel, curMaxLevel)) {
return;
}
if (hnsw.numLevels() == curMaxLevel + 1) {
// This should never happen if all the calculations are correct
throw new IllegalStateException(
"We're not able to promote node "
+ node
+ " at level "
+ nodeLevel
+ " as entry node. But the max graph level "
+ curMaxLevel
+ " has not changed while we are inserting the node.");
}
} while (true);
}

private long printGraphBuildStatus(int node, long start, long t) {

@@ -221,7 +290,7 @@ public class HnswGraphBuilder {
return now;
}

private void addDiverseNeighbors(int level, int node, GraphBuilderKnnCollector candidates)
private void addDiverseNeighbors(int level, int node, NeighborArray candidates)
throws IOException {
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,

@@ -229,26 +298,40 @@ public class HnswGraphBuilder {
*/
NeighborArray neighbors = hnsw.getNeighbors(level, node);
assert neighbors.size() == 0; // new node
popToScratch(candidates);
int maxConnOnLevel = level == 0 ? M * 2 : M;
selectAndLinkDiverse(neighbors, scratch, maxConnOnLevel);
boolean[] mask = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel);

// Link the selected nodes to the new node, and the new node to the selected nodes (again
// applying diversity heuristic)
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
// NOTE: here we're using candidates and mask but not the neighbour array because once we have
// added incoming link there will be possibilities of this node being discovered and neighbour
// array being modified. So using local candidates and mask is a safer option.
for (int i = 0; i < candidates.size(); i++) {
if (mask[i] == false) {
continue;
}
int nbr = candidates.node[i];
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
if (nbrsOfNbr.size() > maxConnOnLevel) {
int indexToRemove = findWorstNonDiverse(nbrsOfNbr, nbr);
nbrsOfNbr.removeIndex(indexToRemove);
nbrsOfNbr.rwlock.writeLock().lock();
try {
nbrsOfNbr.addOutOfOrder(node, candidates.score[i]);
if (nbrsOfNbr.size() > maxConnOnLevel) {
int indexToRemove = findWorstNonDiverse(nbrsOfNbr, nbr);
nbrsOfNbr.removeIndex(indexToRemove);
}
} finally {
nbrsOfNbr.rwlock.writeLock().unlock();
}
}
}

private void selectAndLinkDiverse(
/**
* This method will select neighbors to add and return a mask telling the caller which candidates
* are selected
*/
private boolean[] selectAndLinkDiverse(
NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException {
boolean[] mask = new boolean[candidates.size()];
// Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic
for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
// compare each neighbor (in distance order) against the closer neighbors selected so far,

@@ -257,12 +340,16 @@ public class HnswGraphBuilder {
float cScore = candidates.score[i];
assert cNode <= hnsw.maxNodeId();
if (diversityCheck(cNode, cScore, neighbors)) {
mask[i] = true;
// here we don't need to lock, because there's no incoming link so no others is able to
// discover this node such that no others will modify this neighbor array as well
neighbors.addInOrder(cNode, cScore);
}
}
return mask;
}

private void popToScratch(GraphBuilderKnnCollector candidates) {
private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborArray scratch) {
scratch.clear();
int candidateCount = candidates.size();
// extract all the Neighbors from the queue into an array; these will now be
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;

import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.InfoStream;

/**
* Abstraction of merging multiple graphs into one on-heap graph
*
* @lucene.experimental
*/
public interface HnswGraphMerger {

/**
* Adds a reader to the graph merger to record the state
*
* @param reader KnnVectorsReader to add to the merger
* @param docMap MergeState.DocMap for the reader
* @param liveDocs Bits representing live docs, can be null
* @return this
* @throws IOException If an error occurs while reading from the merge state
*/
HnswGraphMerger addReader(KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs)
throws IOException;

/**
* Merge and produce the on heap graph
*
* @param mergedVectorIterator iterator over the vectors in the merged segment
* @param infoStream optional info stream to set to builder
* @param maxOrd max number of vectors that will be added to the graph
* @return merged graph
* @throws IOException during merge
*/
OnHeapHnswGraph merge(DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd)
throws IOException;
}
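A hedged sketch (not part of the commit) of how an HnswGraphMerger is driven, mirroring the mergeOneField change earlier in this diff: readers are registered one by one, then merge() builds the combined on-heap graph. The variables fieldInfo, scorerSupplier, M, beamWidth, mergeState, mergedVectorIterator, infoStream and maxOrd stand in for the writer's state.

// Fragment, illustrative only; the concurrent variant would substitute ConcurrentHnswMerger here.
HnswGraphMerger merger =
    new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
  merger.addReader(
      mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
}
// maxOrd = number of vectors in the merged segment
OnHeapHnswGraph graph = merger.merge(mergedVectorIterator, infoStream, maxOrd);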
@@ -32,6 +32,7 @@ import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.CollectionUtil;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;

/**
* This selects the biggest Hnsw graph from the provided merge state and initializes a new

@@ -39,15 +40,16 @@ import org.apache.lucene.util.FixedBitSet;
*
* @lucene.experimental
*/
public class IncrementalHnswGraphMerger {
public class IncrementalHnswGraphMerger implements HnswGraphMerger {

private KnnVectorsReader initReader;
private MergeState.DocMap initDocMap;
private int initGraphSize;
private final FieldInfo fieldInfo;
private final RandomVectorScorerSupplier scorerSupplier;
private final int M;
private final int beamWidth;
protected final FieldInfo fieldInfo;
protected final RandomVectorScorerSupplier scorerSupplier;
protected final int M;
protected final int beamWidth;

protected KnnVectorsReader initReader;
protected MergeState.DocMap initDocMap;
protected int initGraphSize;

/**
* @param fieldInfo FieldInfo for the field being merged

@@ -64,13 +66,8 @@ public class IncrementalHnswGraphMerger {
* Adds a reader to the graph merger if it meets the following criteria: 1. Does not contain any
* deleted docs 2. Is a HnswGraphProvider/PerFieldKnnVectorReader 3. Has the most docs of any
* previous reader that met the above criteria
*
* @param reader KnnVectorsReader to add to the merger
* @param docMap MergeState.DocMap for the reader
* @param liveDocs Bits representing live docs, can be null
* @return this
* @throws IOException If an error occurs while reading from the merge state
*/
@Override
public IncrementalHnswGraphMerger addReader(
KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs) throws IOException {
KnnVectorsReader currKnnVectorsReader = reader;

@@ -113,18 +110,20 @@ public class IncrementalHnswGraphMerger {
* If no valid readers were added to the merge state, a new graph is created.
*
* @param mergedVectorIterator iterator over the vectors in the merged segment
* @param maxOrd max num of vectors that will be merged into the graph
* @return HnswGraphBuilder
* @throws IOException If an error occurs while reading from the merge state
*/
public HnswGraphBuilder createBuilder(DocIdSetIterator mergedVectorIterator) throws IOException {
protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
throws IOException {
if (initReader == null) {
return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
return HnswGraphBuilder.create(
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, maxOrd);
}

HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
final int numVectors = Math.toIntExact(mergedVectorIterator.cost());

BitSet initializedNodes = new FixedBitSet(numVectors + 1);
BitSet initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
return InitializedHnswGraphBuilder.fromGraph(
scorerSupplier,

@@ -134,7 +133,15 @@ public class IncrementalHnswGraphMerger {
initializerGraph,
oldToNewOrdinalMap,
initializedNodes,
numVectors);
maxOrd);
}

@Override
public OnHeapHnswGraph merge(
DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd) throws IOException {
HnswBuilder builder = createBuilder(mergedVectorIterator, maxOrd);
builder.setInfoStream(infoStream);
return builder.build(maxOrd);
}

/**

@@ -146,8 +153,8 @@ public class IncrementalHnswGraphMerger {
* @return the mapping from old ordinals to new ordinals
* @throws IOException If an error occurs while reading from the merge state
*/
private int[] getNewOrdMapping(DocIdSetIterator mergedVectorIterator, BitSet initializedNodes)
throws IOException {
protected final int[] getNewOrdMapping(
DocIdSetIterator mergedVectorIterator, BitSet initializedNodes) throws IOException {
DocIdSetIterator initializerIterator = null;

switch (fieldInfo.getVectorEncoding()) {
@@ -55,6 +55,18 @@ public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {
BitSet initializedNodes,
int totalNumberOfVectors)
throws IOException {
return new InitializedHnswGraphBuilder(
scorerSupplier,
M,
beamWidth,
seed,
initGraph(M, initializerGraph, newOrdMap, totalNumberOfVectors),
initializedNodes);
}

public static OnHeapHnswGraph initGraph(
int M, HnswGraph initializerGraph, int[] newOrdMap, int totalNumberOfVectors)
throws IOException {
OnHeapHnswGraph hnsw = new OnHeapHnswGraph(M, totalNumberOfVectors);
for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);

@@ -62,6 +74,7 @@ public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {
int oldOrd = it.nextInt();
int newOrd = newOrdMap[oldOrd];
hnsw.addNode(level, newOrd);
hnsw.trySetNewEntryNode(newOrd, level);
NeighborArray newNeighbors = hnsw.getNeighbors(level, newOrd);
initializerGraph.seek(level, oldOrd);
for (int oldNeighbor = initializerGraph.nextNeighbor();

@@ -73,8 +86,7 @@ public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {
}
}
}
return new InitializedHnswGraphBuilder(
scorerSupplier, M, beamWidth, seed, hnsw, initializedNodes);
return hnsw;
}

private final BitSet initializedNodes;
@@ -19,6 +19,8 @@ package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.lucene.util.ArrayUtil;

/**

@@ -35,6 +37,7 @@ public class NeighborArray {
float[] score;
int[] node;
private int sortedNodeSize;
public final ReadWriteLock rwlock = new ReentrantReadWriteLock(true);

public NeighborArray(int maxSize, boolean descOrder) {
node = new int[maxSize];
@ -21,6 +21,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
@ -33,8 +35,7 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
|
||||
private static final int INIT_SIZE = 128;
|
||||
|
||||
private int numLevels; // the current number of levels in the graph
|
||||
private int entryNode; // the current graph entry node on the top level. -1 if not set
|
||||
private final AtomicReference<EntryNode> entryNode;
|
||||
|
||||
// the internal graph representation where the first dimension is node id and second dimension is
|
||||
// level
|
||||
|
@ -47,11 +48,13 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
private int
|
||||
lastFreezeSize; // remember the size we are at last time to freeze the graph and generate
|
||||
// levelToNodes
|
||||
private int size; // graph size, which is number of nodes in level 0
|
||||
private int
|
||||
nonZeroLevelSize; // total number of NeighborArrays created that is not on level 0, for now it
|
||||
private final AtomicInteger size =
|
||||
new AtomicInteger(0); // graph size, which is number of nodes in level 0
|
||||
private final AtomicInteger nonZeroLevelSize =
|
||||
new AtomicInteger(
|
||||
0); // total number of NeighborArrays created that is not on level 0, for now it
|
||||
// is only used to account memory usage
|
||||
private int maxNodeId;
|
||||
private final AtomicInteger maxNodeId = new AtomicInteger(-1);
|
||||
private final int nsize; // neighbour array size at non-zero level
|
||||
private final int nsize0; // neighbour array size at zero level
|
||||
private final boolean
|
||||
|
@ -69,11 +72,9 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
* growing itself (you cannot add a node with has id >= numNodes)
|
||||
*/
|
||||
OnHeapHnswGraph(int M, int numNodes) {
|
||||
this.numLevels = 1; // Implicitly start the graph with a single level
|
||||
this.entryNode = -1; // Entry node should be negative until a node is added
|
||||
this.entryNode = new AtomicReference<>(new EntryNode(-1, 1));
|
||||
// Neighbours' size on upper levels (nsize) and level 0 (nsize0)
|
||||
// We allocate extra space for neighbours, but then prune them to keep allowed maximum
|
||||
this.maxNodeId = -1;
|
||||
this.nsize = M + 1;
|
||||
this.nsize0 = (M * 2 + 1);
|
||||
noGrowth = numNodes != -1;
|
||||
|
@ -96,7 +97,7 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
return size.get();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -107,7 +108,16 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
*/
|
||||
@Override
|
||||
public int maxNodeId() {
|
||||
return maxNodeId;
|
||||
if (noGrowth) {
|
||||
// we know the eventual graph size and the graph can possibly
|
||||
// being concurrently modified
|
||||
return graph.length - 1;
|
||||
} else {
|
||||
// The graph cannot be concurrently modified (and searched) if
|
||||
// we don't know the size beforehand, so it's safe to return the
|
||||
// actual maxNodeId
|
||||
return maxNodeId.get();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -120,9 +130,6 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
* @param node the node to add, represented as an ordinal on the level 0.
|
||||
*/
|
||||
public void addNode(int level, int node) {
|
||||
if (entryNode == -1) {
|
||||
entryNode = node;
|
||||
}
|
||||
|
||||
if (node >= graph.length) {
|
||||
if (noGrowth) {
|
||||
|
@ -132,25 +139,20 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
graph = ArrayUtil.grow(graph, node + 1);
|
||||
}
|
||||
|
||||
if (level >= numLevels) {
|
||||
numLevels = level + 1;
|
||||
entryNode = node;
|
||||
}
|
||||
|
||||
assert graph[node] == null || graph[node].length > level
|
||||
: "node must be inserted from the top level";
|
||||
if (graph[node] == null) {
|
||||
graph[node] =
|
||||
new NeighborArray[level + 1]; // assumption: we always call this function from top level
|
||||
size++;
|
||||
size.incrementAndGet();
|
||||
}
|
||||
if (level == 0) {
|
||||
graph[node][level] = new NeighborArray(nsize0, true);
|
||||
} else {
|
||||
graph[node][level] = new NeighborArray(nsize, true);
|
||||
nonZeroLevelSize++;
|
||||
nonZeroLevelSize.incrementAndGet();
|
||||
}
|
||||
maxNodeId = Math.max(maxNodeId, node);
|
||||
maxNodeId.accumulateAndGet(node, Math::max);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -174,7 +176,7 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
|
|||
*/
|
||||
@Override
|
||||
public int numLevels() {
|
||||
return numLevels;
|
||||
return entryNode.get().level + 1;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -185,7 +187,41 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
   */
  @Override
  public int entryNode() {
    return entryNode;
    return entryNode.get().node;
  }

  /**
   * Try to set the entry node if the graph does not have one
   *
   * @return True if the entry node is set to the provided node. False if the entry node already
   *     exists
   */
  public boolean trySetNewEntryNode(int node, int level) {
    EntryNode current = entryNode.get();
    if (current.node == -1) {
      return entryNode.compareAndSet(current, new EntryNode(node, level));
    }
    return false;
  }

  /**
   * Try to promote the provided node to the entry node
   *
   * @param level should be larger than expectOldLevel
   * @param expectOldLevel the entry node level the caller expects; the actual graph level can
   *     differ due to concurrent modification
   * @return True if the entry node is set to the provided node. False if expectOldLevel is not the
   *     same as the current entry node level. Even if the provided node's level is still higher
   *     than the current entry node level, the new entry node will not be set and false will be
   *     returned.
   */
  public boolean tryPromoteNewEntryNode(int node, int level, int expectOldLevel) {
    assert level > expectOldLevel;
    EntryNode currentEntry = entryNode.get();
    if (currentEntry.level == expectOldLevel) {
      return entryNode.compareAndSet(currentEntry, new EntryNode(node, level));
    }
    return false;
  }

  /**
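trySetNewEntryNode and tryPromoteNewEntryNode are the two CAS entry points a concurrent writer uses after inserting a node whose level may exceed the current entry level; the tests further down use exactly this call pattern. A hedged sketch of a caller, with a hypothetical wrapper class and method:

import org.apache.lucene.util.hnsw.OnHeapHnswGraph;

final class EntryNodeUpdate {
  // Called by a writer after it has added `node` on levels 0..nodeLevel.
  static void maybePromote(OnHeapHnswGraph graph, int node, int nodeLevel) {
    if (graph.trySetNewEntryNode(node, nodeLevel)) {
      return; // the graph had no entry node yet
    }
    int curLevel = graph.numLevels() - 1;
    while (nodeLevel > curLevel) {
      if (graph.tryPromoteNewEntryNode(node, nodeLevel, curLevel)) {
        return; // won the CAS race, node is now the entry point
      }
      curLevel = graph.numLevels() - 1; // re-read after a concurrent change and retry
    }
  }
}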
@ -212,12 +248,12 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {

  @SuppressWarnings({"unchecked", "rawtypes"})
  private void generateLevelToNodes() {
    if (lastFreezeSize == size) {
    if (lastFreezeSize == size()) {
      return;
    }

    levelToNodes = new List[numLevels];
    for (int i = 1; i < numLevels; i++) {
    int maxLevels = numLevels();
    levelToNodes = new List[maxLevels];
    for (int i = 1; i < maxLevels; i++) {
      levelToNodes[i] = new ArrayList<>();
    }
    int nonNullNode = 0;
@ -230,38 +266,44 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
      for (int i = 1; i < graph[node].length; i++) {
        levelToNodes[i].add(node);
      }
      if (nonNullNode == size) {
      if (nonNullNode == size()) {
        break;
      }
    }
    lastFreezeSize = size;
    lastFreezeSize = size();
  }

  @Override
  public long ramBytesUsed() {
    long neighborArrayBytes0 =
        (long) nsize0 * (Integer.BYTES + Float.BYTES)
            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2L
            + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
            + Integer.BYTES * 3;
    long neighborArrayBytes =
        (long) nsize * (Integer.BYTES + Float.BYTES)
            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2L
            + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
            + Integer.BYTES * 3;
    long total = 0;
    total +=
        size * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
        size() * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // for graph and level 0;
    total += nonZeroLevelSize * neighborArrayBytes; // for non-zero level
    total += 8 * Integer.BYTES; // all int fields
    total += nonZeroLevelSize.get() * neighborArrayBytes; // for non-zero level
    total += 4 * Integer.BYTES; // all int fields
    total += 1; // field: noGrowth
    total +=
        RamUsageEstimator.NUM_BYTES_OBJECT_REF
            + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
            + 2 * Integer.BYTES; // field: entryNode
    total += 3L * (Integer.BYTES + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER); // 3 AtomicInteger
    total += RamUsageEstimator.NUM_BYTES_OBJECT_REF; // field: cur
    total += RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // field: levelToNodes
    if (levelToNodes != null) {
      total +=
          (long) (numLevels - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0
          (long) (numLevels() - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0
      total +=
          (long) nonZeroLevelSize
          (long) nonZeroLevelSize.get()
              * (RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
                  + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
                  + Integer.BYTES);
@ -274,9 +316,11 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
    return "OnHeapHnswGraph(size="
        + size()
        + ", numLevels="
        + numLevels
        + numLevels()
        + ", entryNode="
        + entryNode
        + entryNode()
        + ")";
  }

  private record EntryNode(int node, int level) {}
}

@ -32,12 +32,14 @@ public interface RandomVectorScorerSupplier {
  RandomVectorScorer scorer(int ord) throws IOException;

  /**
   * Creates a {@link RandomVectorScorerSupplier} to compare float vectors.
   *
   * <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
   * using it after calling this function. If you plan to use it again outside the returned {@link
   * RandomVectorScorer}, think about passing a copied version ({@link
   * RandomAccessVectorValues#copy}).
   * Make a copy of the supplier, which will copy the underlying vectorValues so the copy is safe to
   * use in other threads.
   */
  RandomVectorScorerSupplier copy() throws IOException;

  /**
   * Creates a {@link RandomVectorScorerSupplier} to compare float vectors. The vectorValues passed
   * in will be copied and the original will not be used.
   *
   * @param vectors the underlying storage for vectors
   * @param similarityFunction the similarity function to score vectors
@ -48,21 +50,12 @@ public interface RandomVectorScorerSupplier {
      throws IOException {
    // We copy the provided random accessor just once during the supplier's initialization
    // and then reuse it consistently across all scorers for conducting vector comparisons.
    final RandomAccessVectorValues<float[]> vectorsCopy = vectors.copy();
    return queryOrd ->
        (RandomVectorScorer)
            cand ->
                similarityFunction.compare(
                    vectors.vectorValue(queryOrd), vectorsCopy.vectorValue(cand));
    return new FloatScoringSupplier(vectors, similarityFunction);
  }

  /**
   * Creates a {@link RandomVectorScorerSupplier} to compare byte vectors.
   *
   * <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
   * using it after calling this function. If you plan to use it again outside the returned {@link
   * RandomVectorScorer}, think about passing a copied version ({@link
   * RandomAccessVectorValues#copy}).
   * Creates a {@link RandomVectorScorerSupplier} to compare byte vectors. The vectorValues passed
   * in will be copied and the original will not be used.
   *
   * @param vectors the underlying storage for vectors
   * @param similarityFunction the similarity function to score vectors
@ -71,13 +64,64 @@ public interface RandomVectorScorerSupplier {
      final RandomAccessVectorValues<byte[]> vectors,
      final VectorSimilarityFunction similarityFunction)
      throws IOException {
    // We copy the provided random accessor just once during the supplier's initialization
    // We copy the provided random accessor only during the supplier's initialization
    // and then reuse it consistently across all scorers for conducting vector comparisons.
    final RandomAccessVectorValues<byte[]> vectorsCopy = vectors.copy();
    return queryOrd ->
        (RandomVectorScorer)
            cand ->
                similarityFunction.compare(
                    vectors.vectorValue(queryOrd), vectorsCopy.vectorValue(cand));
    return new ByteScoringSupplier(vectors, similarityFunction);
  }

  /** RandomVectorScorerSupplier for bytes vector */
  final class ByteScoringSupplier implements RandomVectorScorerSupplier {
    private final RandomAccessVectorValues<byte[]> vectors;
    private final RandomAccessVectorValues<byte[]> vectors1;
    private final RandomAccessVectorValues<byte[]> vectors2;
    private final VectorSimilarityFunction similarityFunction;

    private ByteScoringSupplier(
        RandomAccessVectorValues<byte[]> vectors, VectorSimilarityFunction similarityFunction)
        throws IOException {
      this.vectors = vectors;
      vectors1 = vectors.copy();
      vectors2 = vectors.copy();
      this.similarityFunction = similarityFunction;
    }

    @Override
    public RandomVectorScorer scorer(int ord) throws IOException {
      return cand ->
          similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
    }

    @Override
    public RandomVectorScorerSupplier copy() throws IOException {
      return new ByteScoringSupplier(vectors, similarityFunction);
    }
  }

  /** RandomVectorScorerSupplier for Float vector */
  final class FloatScoringSupplier implements RandomVectorScorerSupplier {
    private final RandomAccessVectorValues<float[]> vectors;
    private final RandomAccessVectorValues<float[]> vectors1;
    private final RandomAccessVectorValues<float[]> vectors2;
    private final VectorSimilarityFunction similarityFunction;

    private FloatScoringSupplier(
        RandomAccessVectorValues<float[]> vectors, VectorSimilarityFunction similarityFunction)
        throws IOException {
      this.vectors = vectors;
      vectors1 = vectors.copy();
      vectors2 = vectors.copy();
      this.similarityFunction = similarityFunction;
    }

    @Override
    public RandomVectorScorer scorer(int ord) throws IOException {
      return cand ->
          similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
    }

    @Override
    public RandomVectorScorerSupplier copy() throws IOException {
      return new FloatScoringSupplier(vectors, similarityFunction);
    }
  }
}
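The two suppliers above hold thread-private copies (vectors1, vectors2) of the vector reader, and copy() hands each merge worker its own supplier so the stateful per-reader buffers are never shared across threads. A minimal usage sketch, not from the patch (the wrapper class and method names are hypothetical):

import java.io.IOException;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

final class WorkerLocalScoring {
  // Each merge worker takes its own copy of the shared supplier; copy()
  // clones the underlying vector values, so the scorers it produces can be
  // used on this thread without any locking.
  static RandomVectorScorer scorerFor(RandomVectorScorerSupplier shared, int queryOrd)
      throws IOException {
    RandomVectorScorerSupplier local = shared.copy();
    return local.scorer(queryOrd);
  }
}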
@ -565,6 +565,14 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
        createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);

    RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);

    // we cannot call getNodesOnLevel before the graph reaches the size it claimed, so here we
    // create another graph to do the assertion
    OnHeapHnswGraph graphAfterInit =
        InitializedHnswGraphBuilder.initGraph(
            10, initializerGraph, initializerOrdMap, initializerGraph.size());

    HnswGraphBuilder finalBuilder =
        InitializedHnswGraphBuilder.fromGraph(
            finalscorerSupplier,
@ -578,7 +586,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
            totalSize);

    // When offset is 0, the graphs should be identical before vectors are added
    assertGraphEqual(initializerGraph, finalBuilder.getGraph());
    assertGraphEqual(initializerGraph, graphAfterInit);

    OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.size());
    assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
@ -989,6 +997,33 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
    }
  }

  /*
   * A very basic test to ensure the concurrent merge does not throw exceptions; it by no means guarantees the
   * true correctness of the concurrent merge, which must be checked manually by running a KNN benchmark
   * and comparing the recall
   */
  public void testConcurrentMergeBuilder() throws IOException {
    int size = atLeast(1000);
    int dim = atLeast(10);
    AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
    RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
    ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge"));
    HnswGraphBuilder.randSeed = random().nextLong();
    HnswConcurrentMergeBuilder builder =
        new HnswConcurrentMergeBuilder(
            exec, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null);
    builder.setBatchSize(100);
    builder.build(size);
    exec.shutdownNow();
    OnHeapHnswGraph graph = builder.getGraph();
    assertTrue(graph.entryNode() != -1);
    assertEquals(size, graph.size());
    assertEquals(size - 1, graph.maxNodeId());
    for (int l = 0; l < graph.numLevels(); l++) {
      assertNotNull(graph.getNodesOnLevel(l));
    }
  }

  private int computeOverlap(int[] a, int[] b) {
    Arrays.sort(a);
    Arrays.sort(b);
@ -43,7 +43,7 @@ public class TestOnHeapHnswGraph extends LuceneTestCase {

  /* assert exception will be thrown when we call getNodeOnLevel for an incomplete graph */
  public void testIncompleteGraphThrow() {
    OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 10);
    OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
    graph.addNode(1, 0);
    graph.addNode(0, 0);
    assertEquals(1, graph.getNodesOnLevel(1).size());
@ -62,6 +62,10 @@ public class TestOnHeapHnswGraph extends LuceneTestCase {
      int level = random().nextInt(maxLevel);
      for (int l = level; l >= 0; l--) {
        graph.addNode(l, i);
        graph.trySetNewEntryNode(i, l);
        if (l > graph.numLevels() - 1) {
          graph.tryPromoteNewEntryNode(i, l, graph.numLevels() - 1);
        }
        levelToNodes.get(l).add(i);
      }
    }
@ -93,6 +97,10 @@ public class TestOnHeapHnswGraph extends LuceneTestCase {
      int level = random().nextInt(maxLevel);
      for (int l = level; l >= 0; l--) {
        graph.addNode(l, i);
        graph.trySetNewEntryNode(i, l);
        if (l > graph.numLevels() - 1) {
          graph.tryPromoteNewEntryNode(i, l, graph.numLevels() - 1);
        }
        levelToNodes.get(l).add(i);
      }
    }

@ -39,6 +39,8 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Predicate;
import java.util.function.ToLongFunction;
import java.util.stream.Collectors;
@ -159,7 +161,11 @@ public final class RamUsageTester {

    // Ignore JDK objects we can't access or handle properly.
    Predicate<Object> isIgnorable =
        (clazz) -> (clazz instanceof CharsetEncoder) || (clazz instanceof CharsetDecoder);
        (clazz) ->
            (clazz instanceof CharsetEncoder)
                || (clazz instanceof CharsetDecoder)
                || (clazz instanceof ReentrantReadWriteLock)
                || (clazz instanceof AtomicReference<?>);
    if (isIgnorable.test(ob)) {
      return accumulator.accumulateObject(ob, 0, Collections.emptyMap(), stack);
    }