mirror of https://github.com/apache/lucene.git
Add HnswGraphBuilder.getCompletedGraph() and record completed state (#13561)
This commit is contained in:
parent
cc854f408f
commit
8d1e624a67
|
@ -615,7 +615,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
OnHeapHnswGraph getGraph() {
|
OnHeapHnswGraph getGraph() {
|
||||||
assert flatFieldVectorsWriter.isFinished();
|
assert flatFieldVectorsWriter.isFinished();
|
||||||
if (node > 0) {
|
if (node > 0) {
|
||||||
return hnswGraphBuilder.getGraph();
|
return hnswGraphBuilder.getCompletedGraph();
|
||||||
} else {
|
} else {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,4 +41,12 @@ public interface HnswBuilder {
|
||||||
void setInfoStream(InfoStream infoStream);
|
void setInfoStream(InfoStream infoStream);
|
||||||
|
|
||||||
OnHeapHnswGraph getGraph();
|
OnHeapHnswGraph getGraph();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Once this method is called no further updates to the graph are accepted (addGraphNode will
|
||||||
|
* throw IllegalStateException). Final modifications to the graph (eg patching up disconnected
|
||||||
|
* components, re-ordering node ids for better delta compression) may be triggered, so callers
|
||||||
|
* should expect this call to take some time.
|
||||||
|
*/
|
||||||
|
OnHeapHnswGraph getCompletedGraph();
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,7 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
|
||||||
private final TaskExecutor taskExecutor;
|
private final TaskExecutor taskExecutor;
|
||||||
private final ConcurrentMergeWorker[] workers;
|
private final ConcurrentMergeWorker[] workers;
|
||||||
private InfoStream infoStream = InfoStream.getDefault();
|
private InfoStream infoStream = InfoStream.getDefault();
|
||||||
|
private boolean frozen;
|
||||||
|
|
||||||
public HnswConcurrentMergeBuilder(
|
public HnswConcurrentMergeBuilder(
|
||||||
TaskExecutor taskExecutor,
|
TaskExecutor taskExecutor,
|
||||||
|
@ -69,6 +70,9 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OnHeapHnswGraph build(int maxOrd) throws IOException {
|
public OnHeapHnswGraph build(int maxOrd) throws IOException {
|
||||||
|
if (frozen) {
|
||||||
|
throw new IllegalStateException("graph has already been built");
|
||||||
|
}
|
||||||
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||||
infoStream.message(
|
infoStream.message(
|
||||||
HNSW_COMPONENT,
|
HNSW_COMPONENT,
|
||||||
|
@ -84,7 +88,8 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
taskExecutor.invokeAll(futures);
|
taskExecutor.invokeAll(futures);
|
||||||
return workers[0].getGraph();
|
frozen = true;
|
||||||
|
return workers[0].getCompletedGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -100,6 +105,12 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OnHeapHnswGraph getCompletedGraph() {
|
||||||
|
frozen = true;
|
||||||
|
return getGraph();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OnHeapHnswGraph getGraph() {
|
public OnHeapHnswGraph getGraph() {
|
||||||
return workers[0].getGraph();
|
return workers[0].getGraph();
|
||||||
|
|
|
@ -65,6 +65,7 @@ public class HnswGraphBuilder implements HnswBuilder {
|
||||||
protected final OnHeapHnswGraph hnsw;
|
protected final OnHeapHnswGraph hnsw;
|
||||||
|
|
||||||
private InfoStream infoStream = InfoStream.getDefault();
|
private InfoStream infoStream = InfoStream.getDefault();
|
||||||
|
private boolean frozen;
|
||||||
|
|
||||||
public static HnswGraphBuilder create(
|
public static HnswGraphBuilder create(
|
||||||
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
|
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
|
||||||
|
@ -152,11 +153,14 @@ public class HnswGraphBuilder implements HnswBuilder {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OnHeapHnswGraph build(int maxOrd) throws IOException {
|
public OnHeapHnswGraph build(int maxOrd) throws IOException {
|
||||||
|
if (frozen) {
|
||||||
|
throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated");
|
||||||
|
}
|
||||||
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||||
infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors");
|
infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors");
|
||||||
}
|
}
|
||||||
addVectors(maxOrd);
|
addVectors(maxOrd);
|
||||||
return hnsw;
|
return getCompletedGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -164,6 +168,12 @@ public class HnswGraphBuilder implements HnswBuilder {
|
||||||
this.infoStream = infoStream;
|
this.infoStream = infoStream;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OnHeapHnswGraph getCompletedGraph() {
|
||||||
|
frozen = true;
|
||||||
|
return getGraph();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OnHeapHnswGraph getGraph() {
|
public OnHeapHnswGraph getGraph() {
|
||||||
return hnsw;
|
return hnsw;
|
||||||
|
@ -171,6 +181,9 @@ public class HnswGraphBuilder implements HnswBuilder {
|
||||||
|
|
||||||
/** add vectors in range [minOrd, maxOrd) */
|
/** add vectors in range [minOrd, maxOrd) */
|
||||||
protected void addVectors(int minOrd, int maxOrd) throws IOException {
|
protected void addVectors(int minOrd, int maxOrd) throws IOException {
|
||||||
|
if (frozen) {
|
||||||
|
throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated");
|
||||||
|
}
|
||||||
long start = System.nanoTime(), t = start;
|
long start = System.nanoTime(), t = start;
|
||||||
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||||
infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")");
|
infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")");
|
||||||
|
@ -207,6 +220,9 @@ public class HnswGraphBuilder implements HnswBuilder {
|
||||||
to the newly introduced levels (repeating step 2,3 for new levels) and again try to
|
to the newly introduced levels (repeating step 2,3 for new levels) and again try to
|
||||||
promote the node to entry node.
|
promote the node to entry node.
|
||||||
*/
|
*/
|
||||||
|
if (frozen) {
|
||||||
|
throw new IllegalStateException("Graph builder is already frozen");
|
||||||
|
}
|
||||||
RandomVectorScorer scorer = scorerSupplier.scorer(node);
|
RandomVectorScorer scorer = scorerSupplier.scorer(node);
|
||||||
final int nodeLevel = getRandomGraphLevel(ml, random);
|
final int nodeLevel = getRandomGraphLevel(ml, random);
|
||||||
// first add nodes to all levels
|
// first add nodes to all levels
|
||||||
|
|
|
@ -218,6 +218,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
||||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed);
|
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed);
|
||||||
HnswGraph hnsw = builder.build(vectors.size());
|
HnswGraph hnsw = builder.build(vectors.size());
|
||||||
|
expectThrows(IllegalStateException.class, () -> builder.addGraphNode(0));
|
||||||
|
|
||||||
// Recreate the graph while indexing with the same random seed and write it out
|
// Recreate the graph while indexing with the same random seed and write it out
|
||||||
HnswGraphBuilder.randSeed = seed;
|
HnswGraphBuilder.randSeed = seed;
|
||||||
|
@ -1014,13 +1015,15 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
||||||
builder.setBatchSize(100);
|
builder.setBatchSize(100);
|
||||||
builder.build(size);
|
builder.build(size);
|
||||||
exec.shutdownNow();
|
exec.shutdownNow();
|
||||||
OnHeapHnswGraph graph = builder.getGraph();
|
OnHeapHnswGraph graph = builder.getCompletedGraph();
|
||||||
assertTrue(graph.entryNode() != -1);
|
assertTrue(graph.entryNode() != -1);
|
||||||
assertEquals(size, graph.size());
|
assertEquals(size, graph.size());
|
||||||
assertEquals(size - 1, graph.maxNodeId());
|
assertEquals(size - 1, graph.maxNodeId());
|
||||||
for (int l = 0; l < graph.numLevels(); l++) {
|
for (int l = 0; l < graph.numLevels(); l++) {
|
||||||
assertNotNull(graph.getNodesOnLevel(l));
|
assertNotNull(graph.getNodesOnLevel(l));
|
||||||
}
|
}
|
||||||
|
// cannot build twice
|
||||||
|
expectThrows(IllegalStateException.class, () -> builder.build(size));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testAllNodesVisitedInSingleLevel() throws IOException {
|
public void testAllNodesVisitedInSingleLevel() throws IOException {
|
||||||
|
|
Loading…
Reference in New Issue