Add HnswGraphBuilder.getCompletedGraph() and record completed state (#13561)

This commit is contained in:
Michael Sokolov 2024-07-11 11:44:49 -04:00 committed by GitHub
parent cc854f408f
commit 8d1e624a67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 42 additions and 4 deletions

View File

@ -615,7 +615,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
OnHeapHnswGraph getGraph() { OnHeapHnswGraph getGraph() {
assert flatFieldVectorsWriter.isFinished(); assert flatFieldVectorsWriter.isFinished();
if (node > 0) { if (node > 0) {
return hnswGraphBuilder.getGraph(); return hnswGraphBuilder.getCompletedGraph();
} else { } else {
return null; return null;
} }

View File

@ -41,4 +41,12 @@ public interface HnswBuilder {
void setInfoStream(InfoStream infoStream); void setInfoStream(InfoStream infoStream);
OnHeapHnswGraph getGraph(); OnHeapHnswGraph getGraph();
/**
* Once this method is called, no further updates to the graph are accepted (addGraphNode will
* throw IllegalStateException). Final modifications to the graph (e.g. patching up disconnected
* components, re-ordering node ids for better delta compression) may be triggered, so callers
* should expect this call to take some time.
*/
OnHeapHnswGraph getCompletedGraph();
} }

View File

@ -41,6 +41,7 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
private final TaskExecutor taskExecutor; private final TaskExecutor taskExecutor;
private final ConcurrentMergeWorker[] workers; private final ConcurrentMergeWorker[] workers;
private InfoStream infoStream = InfoStream.getDefault(); private InfoStream infoStream = InfoStream.getDefault();
private boolean frozen;
public HnswConcurrentMergeBuilder( public HnswConcurrentMergeBuilder(
TaskExecutor taskExecutor, TaskExecutor taskExecutor,
@ -69,6 +70,9 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
@Override @Override
public OnHeapHnswGraph build(int maxOrd) throws IOException { public OnHeapHnswGraph build(int maxOrd) throws IOException {
if (frozen) {
throw new IllegalStateException("graph has already been built");
}
if (infoStream.isEnabled(HNSW_COMPONENT)) { if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message( infoStream.message(
HNSW_COMPONENT, HNSW_COMPONENT,
@ -84,7 +88,8 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
}); });
} }
taskExecutor.invokeAll(futures); taskExecutor.invokeAll(futures);
return workers[0].getGraph(); frozen = true;
return workers[0].getCompletedGraph();
} }
@Override @Override
@ -100,6 +105,12 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
} }
} }
@Override
public OnHeapHnswGraph getCompletedGraph() {
frozen = true;
return getGraph();
}
@Override @Override
public OnHeapHnswGraph getGraph() { public OnHeapHnswGraph getGraph() {
return workers[0].getGraph(); return workers[0].getGraph();

View File

@ -65,6 +65,7 @@ public class HnswGraphBuilder implements HnswBuilder {
protected final OnHeapHnswGraph hnsw; protected final OnHeapHnswGraph hnsw;
private InfoStream infoStream = InfoStream.getDefault(); private InfoStream infoStream = InfoStream.getDefault();
private boolean frozen;
public static HnswGraphBuilder create( public static HnswGraphBuilder create(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
@ -152,11 +153,14 @@ public class HnswGraphBuilder implements HnswBuilder {
@Override @Override
public OnHeapHnswGraph build(int maxOrd) throws IOException { public OnHeapHnswGraph build(int maxOrd) throws IOException {
if (frozen) {
throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated");
}
if (infoStream.isEnabled(HNSW_COMPONENT)) { if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors"); infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors");
} }
addVectors(maxOrd); addVectors(maxOrd);
return hnsw; return getCompletedGraph();
} }
@Override @Override
@ -164,6 +168,12 @@ public class HnswGraphBuilder implements HnswBuilder {
this.infoStream = infoStream; this.infoStream = infoStream;
} }
@Override
public OnHeapHnswGraph getCompletedGraph() {
frozen = true;
return getGraph();
}
@Override @Override
public OnHeapHnswGraph getGraph() { public OnHeapHnswGraph getGraph() {
return hnsw; return hnsw;
@ -171,6 +181,9 @@ public class HnswGraphBuilder implements HnswBuilder {
/** add vectors in range [minOrd, maxOrd) */ /** add vectors in range [minOrd, maxOrd) */
protected void addVectors(int minOrd, int maxOrd) throws IOException { protected void addVectors(int minOrd, int maxOrd) throws IOException {
if (frozen) {
throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated");
}
long start = System.nanoTime(), t = start; long start = System.nanoTime(), t = start;
if (infoStream.isEnabled(HNSW_COMPONENT)) { if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")"); infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")");
@ -207,6 +220,9 @@ public class HnswGraphBuilder implements HnswBuilder {
to the newly introduced levels (repeating step 2,3 for new levels) and again try to to the newly introduced levels (repeating step 2,3 for new levels) and again try to
promote the node to entry node. promote the node to entry node.
*/ */
if (frozen) {
throw new IllegalStateException("Graph builder is already frozen");
}
RandomVectorScorer scorer = scorerSupplier.scorer(node); RandomVectorScorer scorer = scorerSupplier.scorer(node);
final int nodeLevel = getRandomGraphLevel(ml, random); final int nodeLevel = getRandomGraphLevel(ml, random);
// first add nodes to all levels // first add nodes to all levels

View File

@ -218,6 +218,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors.size()); HnswGraph hnsw = builder.build(vectors.size());
expectThrows(IllegalStateException.class, () -> builder.addGraphNode(0));
// Recreate the graph while indexing with the same random seed and write it out // Recreate the graph while indexing with the same random seed and write it out
HnswGraphBuilder.randSeed = seed; HnswGraphBuilder.randSeed = seed;
@ -1014,13 +1015,15 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
builder.setBatchSize(100); builder.setBatchSize(100);
builder.build(size); builder.build(size);
exec.shutdownNow(); exec.shutdownNow();
OnHeapHnswGraph graph = builder.getGraph(); OnHeapHnswGraph graph = builder.getCompletedGraph();
assertTrue(graph.entryNode() != -1); assertTrue(graph.entryNode() != -1);
assertEquals(size, graph.size()); assertEquals(size, graph.size());
assertEquals(size - 1, graph.maxNodeId()); assertEquals(size - 1, graph.maxNodeId());
for (int l = 0; l < graph.numLevels(); l++) { for (int l = 0; l < graph.numLevels(); l++) {
assertNotNull(graph.getNodesOnLevel(l)); assertNotNull(graph.getNodesOnLevel(l));
} }
// cannot build twice
expectThrows(IllegalStateException.class, () -> builder.build(size));
} }
public void testAllNodesVisitedInSingleLevel() throws IOException { public void testAllNodesVisitedInSingleLevel() throws IOException {