Add HnswGraphBuilder.getCompletedGraph() and record completed state (#13561)

This commit is contained in:
Michael Sokolov 2024-07-11 11:44:49 -04:00 committed by GitHub
parent cc854f408f
commit 8d1e624a67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 42 additions and 4 deletions

View File

@ -615,7 +615,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
OnHeapHnswGraph getGraph() {
assert flatFieldVectorsWriter.isFinished();
if (node > 0) {
return hnswGraphBuilder.getGraph();
return hnswGraphBuilder.getCompletedGraph();
} else {
return null;
}

View File

@ -41,4 +41,12 @@ public interface HnswBuilder {
void setInfoStream(InfoStream infoStream);
OnHeapHnswGraph getGraph();
/**
* Once this method is called, no further updates to the graph are accepted (addGraphNode will
* throw IllegalStateException). Final modifications to the graph (e.g. patching up disconnected
* components, re-ordering node ids for better delta compression) may be triggered, so callers
* should expect this call to take some time.
*/
OnHeapHnswGraph getCompletedGraph();
}

View File

@ -41,6 +41,7 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
private final TaskExecutor taskExecutor;
private final ConcurrentMergeWorker[] workers;
private InfoStream infoStream = InfoStream.getDefault();
private boolean frozen;
public HnswConcurrentMergeBuilder(
TaskExecutor taskExecutor,
@ -69,6 +70,9 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
@Override
public OnHeapHnswGraph build(int maxOrd) throws IOException {
if (frozen) {
throw new IllegalStateException("graph has already been built");
}
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(
HNSW_COMPONENT,
@ -84,7 +88,8 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
});
}
taskExecutor.invokeAll(futures);
return workers[0].getGraph();
frozen = true;
return workers[0].getCompletedGraph();
}
@Override
@ -100,6 +105,12 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
}
}
/**
 * Marks this builder as frozen and hands back the graph exposed by {@link #getGraph()}.
 *
 * <p>After this call, {@code build} on this builder throws {@link IllegalStateException}.
 *
 * @return the graph returned by {@link #getGraph()}
 */
@Override
public OnHeapHnswGraph getCompletedGraph() {
  // Flip the flag first, then fetch the graph, in the same order as before.
  this.frozen = true;
  final OnHeapHnswGraph completed = this.getGraph();
  return completed;
}
@Override
public OnHeapHnswGraph getGraph() {
return workers[0].getGraph();

View File

@ -65,6 +65,7 @@ public class HnswGraphBuilder implements HnswBuilder {
protected final OnHeapHnswGraph hnsw;
private InfoStream infoStream = InfoStream.getDefault();
private boolean frozen;
public static HnswGraphBuilder create(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
@ -152,11 +153,14 @@ public class HnswGraphBuilder implements HnswBuilder {
/**
 * Adds vectors to the graph via {@code addVectors(maxOrd)} and returns the resulting graph.
 *
 * <p>Refuses to run when this builder has already been frozen. When the HNSW info stream
 * component is enabled, a message with the vector count is logged before any vector is added.
 *
 * @param maxOrd bound forwarded to {@code addVectors}; presumably the exclusive upper ordinal
 *     (TODO confirm against the single-argument overload, which is not visible here)
 * @return the graph held by this builder
 * @throws IllegalStateException if this builder is frozen
 * @throws IOException declared by this method; presumably propagated from addVectors
 */
@Override
public OnHeapHnswGraph build(int maxOrd) throws IOException {
// Guard: a frozen builder must not be reused for another build.
if (frozen) {
throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated");
}
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors");
}
addVectors(maxOrd);
// NOTE(review): this view looks like a diff rendered without +/- markers; the first return
// below appears to be the removed line and the second its replacement, which also freezes the
// builder. Confirm against the committed file.
return hnsw;
return getCompletedGraph();
}
@Override
@ -164,6 +168,12 @@ public class HnswGraphBuilder implements HnswBuilder {
this.infoStream = infoStream;
}
/**
 * Freezes this builder and returns the graph exposed by {@link #getGraph()}.
 *
 * <p>Once frozen, {@code build} and {@code addVectors} on this builder throw {@link
 * IllegalStateException}.
 *
 * @return the graph returned by {@link #getGraph()}
 */
@Override
public OnHeapHnswGraph getCompletedGraph() {
  // The flag is set before the graph is fetched, preserving the original statement order.
  this.frozen = true;
  final OnHeapHnswGraph result = this.getGraph();
  return result;
}
@Override
public OnHeapHnswGraph getGraph() {
return hnsw;
@ -171,6 +181,9 @@ public class HnswGraphBuilder implements HnswBuilder {
/** add vectors in range [minOrd, maxOrd) */
protected void addVectors(int minOrd, int maxOrd) throws IOException {
if (frozen) {
throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated");
}
long start = System.nanoTime(), t = start;
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")");
@ -207,6 +220,9 @@ public class HnswGraphBuilder implements HnswBuilder {
to the newly introduced levels (repeating step 2,3 for new levels) and again try to
promote the node to entry node.
*/
if (frozen) {
throw new IllegalStateException("Graph builder is already frozen");
}
RandomVectorScorer scorer = scorerSupplier.scorer(node);
final int nodeLevel = getRandomGraphLevel(ml, random);
// first add nodes to all levels

View File

@ -218,6 +218,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors.size());
expectThrows(IllegalStateException.class, () -> builder.addGraphNode(0));
// Recreate the graph while indexing with the same random seed and write it out
HnswGraphBuilder.randSeed = seed;
@ -1014,13 +1015,15 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
builder.setBatchSize(100);
builder.build(size);
exec.shutdownNow();
OnHeapHnswGraph graph = builder.getGraph();
OnHeapHnswGraph graph = builder.getCompletedGraph();
assertTrue(graph.entryNode() != -1);
assertEquals(size, graph.size());
assertEquals(size - 1, graph.maxNodeId());
for (int l = 0; l < graph.numLevels(); l++) {
assertNotNull(graph.getNodesOnLevel(l));
}
// cannot build twice
expectThrows(IllegalStateException.class, () -> builder.build(size));
}
public void testAllNodesVisitedInSingleLevel() throws IOException {