mirror of https://github.com/apache/lucene.git
Reuse HNSW graph for initialization during merge (#12050)
* Remove implicit addition of vector 0

  Removes the logic that added vector 0 implicitly. This is in preparation for adding nodes from other graphs to initialize a new graph; having the implicit addition of node 0 complicates this logic.

* Enable out-of-order insertion of nodes in hnsw

  Enables nodes to be added into OnHeapHnswGraph out of order. To do so, additional operations have to be taken to re-sort the nodesByLevel array. Optimizations have been made to avoid sorting whenever possible.

* Add ability to initialize from graph

  Adds a method to initialize an HnswGraphBuilder from another HnswGraph. Initialization can only happen when the builder's graph is empty.

* Utilize merge with graph init in HnswWriter

  Uses HnswGraphBuilder's initialization-from-graph functionality in Lucene95HnswVectorsWriter. Selects the largest graph to initialize the new graph produced by the HnswGraphBuilder for merge.

* Minor modifications to Lucene95HnswVectorsWriter

* Use TreeMap for graph structure for levels > 0

  Refactors OnHeapHnswGraph to use a TreeMap to represent the graph structure of levels greater than 0, and refactors NodesIterator to support a set representation of nodes.

* Refactor initializer to be in static create method

  Refactors initialization from graph to be accessible via a static create method in HnswGraphBuilder.

* Address review comments

* Add change log entry

* Remove empty iterator for NeighborQueue

---------

Signed-off-by: John Mazanec <jmazane@amazon.com>
This commit is contained in:
parent ab074d5483
commit 776149f0f6
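In outline, the merge optimization works like this: the writer picks the largest segment graph that has no deletions, builds a map from that segment's vector ordinals to the merged segment's ordinals, seeds a new HnswGraphBuilder with that graph, and then inserts only the vectors that were not copied over. Below is a minimal sketch of that flow against the new create overload shown in the diff; the reader, fieldInfo, vectorValues, M, and beamWidth variables are illustrative stand-ins, not code from this commit, and the ordinal map is elided:

    // Hedged sketch: seed a merge-time builder from an existing segment's graph.
    HnswGraph initializerGraph = reader.getGraph(fieldInfo.name); // graph of the chosen segment
    Map<Integer, Integer> oldToNewOrdinalMap = ...; // old segment ordinal -> merged ordinal
    HnswGraphBuilder<float[]> hnswGraphBuilder =
        HnswGraphBuilder.create(
            vectorValues, // vectors of the merged segment
            fieldInfo.getVectorEncoding(),
            fieldInfo.getVectorSimilarityFunction(),
            M,
            beamWidth,
            HnswGraphBuilder.randSeed,
            initializerGraph,
            oldToNewOrdinalMap);
    // build() skips ordinals already transferred from initializerGraph and
    // inserts only the remaining vectors.
    OnHeapHnswGraph merged = hnswGraphBuilder.build(vectorValues.copy());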
@@ -133,6 +133,8 @@ Optimizations
 
 * GITHUB#12128, GITHUB#12133: Speed up docvalues set query by making use of sortedness. (Robert Muir, Uwe Schindler)
 
+* GITHUB#12050: Reuse HNSW graph for initialization during merge (Jack Mazanec)
+
 Bug Fixes
 ---------------------
 (No changes)
@@ -561,9 +561,9 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
     @Override
     public NodesIterator getNodesOnLevel(int level) {
       if (level == 0) {
-        return new NodesIterator(size());
+        return new ArrayNodesIterator(size());
       } else {
-        return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
+        return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
       }
     }
   }
@@ -163,9 +163,9 @@ public final class Lucene91OnHeapHnswGraph extends HnswGraph {
   @Override
   public NodesIterator getNodesOnLevel(int level) {
     if (level == 0) {
-      return new NodesIterator(size());
+      return new ArrayNodesIterator(size());
     } else {
-      return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
+      return new ArrayNodesIterator(nodesByLevel.get(level), graph.get(level).size());
     }
   }
 }
@@ -457,9 +457,9 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
     @Override
     public NodesIterator getNodesOnLevel(int level) {
       if (level == 0) {
-        return new NodesIterator(size());
+        return new ArrayNodesIterator(size());
       } else {
-        return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
+        return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
       }
     }
   }
@@ -533,9 +533,9 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
     @Override
     public NodesIterator getNodesOnLevel(int level) {
       if (level == 0) {
-        return new NodesIterator(size());
+        return new ArrayNodesIterator(size());
       } else {
-        return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
+        return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
       }
     }
   }
@@ -345,7 +345,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
           if (level == 0) {
             return graph.getNodesOnLevel(0);
           } else {
-            return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
+            return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
           }
         }
       };
@@ -687,10 +687,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
       assert docID > lastDocID;
       docsWithField.add(docID);
       vectors.add(copyValue(vectorValue));
-      if (node > 0) {
-        // start at node 1! node 0 is added implicitly, in the constructor
-        hnswGraphBuilder.addGraphNode(node, vectorValue);
-      }
+      hnswGraphBuilder.addGraphNode(node, vectorValue);
       node++;
       lastDocID = docID;
     }
@@ -573,9 +573,9 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
     @Override
     public NodesIterator getNodesOnLevel(int level) {
       if (level == 0) {
-        return new NodesIterator(size());
+        return new ArrayNodesIterator(size());
       } else {
-        return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
+        return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
       }
     }
   }
@@ -25,11 +25,16 @@ import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.KnnFieldVectorsWriter;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.codecs.lucene90.IndexedDISI;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.index.*;
 import org.apache.lucene.index.Sorter;
+import org.apache.lucene.search.DocIdSetIterator;
@@ -357,7 +362,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
           if (level == 0) {
             return graph.getNodesOnLevel(0);
           } else {
-            return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
+            return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
           }
         }
       };
@@ -424,6 +429,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
     int[][] vectorIndexNodeOffsets = null;
     if (docsWithField.cardinality() != 0) {
       // build graph
+      int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo);
       graph =
           switch (fieldInfo.getVectorEncoding()) {
             case BYTE -> {
@@ -434,13 +440,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
                       vectorDataInput,
                       byteSize);
               HnswGraphBuilder<byte[]> hnswGraphBuilder =
-                  HnswGraphBuilder.create(
-                      vectorValues,
-                      fieldInfo.getVectorEncoding(),
-                      fieldInfo.getVectorSimilarityFunction(),
-                      M,
-                      beamWidth,
-                      HnswGraphBuilder.randSeed);
+                  createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex);
               hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
               yield hnswGraphBuilder.build(vectorValues.copy());
             }
@@ -452,13 +452,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
                       vectorDataInput,
                       byteSize);
               HnswGraphBuilder<float[]> hnswGraphBuilder =
-                  HnswGraphBuilder.create(
-                      vectorValues,
-                      fieldInfo.getVectorEncoding(),
-                      fieldInfo.getVectorSimilarityFunction(),
-                      M,
-                      beamWidth,
-                      HnswGraphBuilder.randSeed);
+                  createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex);
               hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
               yield hnswGraphBuilder.build(vectorValues.copy());
             }
@@ -489,6 +483,189 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
     }
   }
 
+  private <T> HnswGraphBuilder<T> createHnswGraphBuilder(
+      MergeState mergeState,
+      FieldInfo fieldInfo,
+      RandomAccessVectorValues<T> floatVectorValues,
+      int initializerIndex)
+      throws IOException {
+    if (initializerIndex == -1) {
+      return HnswGraphBuilder.create(
+          floatVectorValues,
+          fieldInfo.getVectorEncoding(),
+          fieldInfo.getVectorSimilarityFunction(),
+          M,
+          beamWidth,
+          HnswGraphBuilder.randSeed);
+    }
+
+    HnswGraph initializerGraph =
+        getHnswGraphFromReader(fieldInfo.name, mergeState.knnVectorsReaders[initializerIndex]);
+    Map<Integer, Integer> ordinalMapper =
+        getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
+    return HnswGraphBuilder.create(
+        floatVectorValues,
+        fieldInfo.getVectorEncoding(),
+        fieldInfo.getVectorSimilarityFunction(),
+        M,
+        beamWidth,
+        HnswGraphBuilder.randSeed,
+        initializerGraph,
+        ordinalMapper);
+  }
+
+  private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)
+      throws IOException {
+    // Find the KnnVectorReader with the most docs that meets the following criteria:
+    //  1. Does not contain any deleted docs
+    //  2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader
+    // If no readers exist that meet this criteria, return -1. If they do, return their index in
+    // merge state
+    int maxCandidateVectorCount = 0;
+    int initializerIndex = -1;
+
+    for (int i = 0; i < mergeState.liveDocs.length; i++) {
+      KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i];
+      if (mergeState.knnVectorsReaders[i]
+          instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
+        currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
+      }
+
+      if (!allMatch(mergeState.liveDocs[i])
+          || !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader candidateReader)) {
+        continue;
+      }
+
+      int candidateVectorCount = 0;
+      switch (fieldInfo.getVectorEncoding()) {
+        case BYTE -> {
+          ByteVectorValues byteVectorValues = candidateReader.getByteVectorValues(fieldInfo.name);
+          if (byteVectorValues == null) {
+            continue;
+          }
+          candidateVectorCount = byteVectorValues.size();
+        }
+        case FLOAT32 -> {
+          FloatVectorValues vectorValues = candidateReader.getFloatVectorValues(fieldInfo.name);
+          if (vectorValues == null) {
+            continue;
+          }
+          candidateVectorCount = vectorValues.size();
+        }
+      }
+
+      if (candidateVectorCount > maxCandidateVectorCount) {
+        maxCandidateVectorCount = candidateVectorCount;
+        initializerIndex = i;
+      }
+    }
+    return initializerIndex;
+  }
+
+  private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader)
+      throws IOException {
+    if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader
+        && perFieldReader.getFieldReader(fieldName)
+            instanceof Lucene95HnswVectorsReader fieldReader) {
+      return fieldReader.getGraph(fieldName);
+    }
+
+    if (knnVectorsReader instanceof Lucene95HnswVectorsReader) {
+      return ((Lucene95HnswVectorsReader) knnVectorsReader).getGraph(fieldName);
+    }
+
+    // We should not reach here because knnVectorsReader's type is checked in
+    // selectGraphForInitialization
+    throw new IllegalArgumentException(
+        "Invalid KnnVectorsReader type for field: "
+            + fieldName
+            + ". Must be Lucene95HnswVectorsReader or newer");
+  }
+
+  private Map<Integer, Integer> getOldToNewOrdinalMap(
+      MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws IOException {
+
+    DocIdSetIterator initializerIterator = null;
+
+    switch (fieldInfo.getVectorEncoding()) {
+      case BYTE -> initializerIterator =
+          mergeState.knnVectorsReaders[initializerIndex].getByteVectorValues(fieldInfo.name);
+      case FLOAT32 -> initializerIterator =
+          mergeState.knnVectorsReaders[initializerIndex].getFloatVectorValues(fieldInfo.name);
+    }
+
+    MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex];
+
+    Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>();
+    int oldOrd = 0;
+    int maxNewDocID = -1;
+    for (int oldId = initializerIterator.nextDoc();
+        oldId != NO_MORE_DOCS;
+        oldId = initializerIterator.nextDoc()) {
+      if (isCurrentVectorNull(initializerIterator)) {
+        continue;
+      }
+      int newId = initializerDocMap.get(oldId);
+      maxNewDocID = Math.max(newId, maxNewDocID);
+      newIdToOldOrdinal.put(newId, oldOrd);
+      oldOrd++;
+    }
+
+    if (maxNewDocID == -1) {
+      return Collections.emptyMap();
+    }
+
+    Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>();
+
+    DocIdSetIterator vectorIterator = null;
+    switch (fieldInfo.getVectorEncoding()) {
+      case BYTE -> vectorIterator = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
+      case FLOAT32 -> vectorIterator =
+          MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+    }
+
+    int newOrd = 0;
+    for (int newDocId = vectorIterator.nextDoc();
+        newDocId <= maxNewDocID;
+        newDocId = vectorIterator.nextDoc()) {
+      if (isCurrentVectorNull(vectorIterator)) {
+        continue;
+      }
+
+      if (newIdToOldOrdinal.containsKey(newDocId)) {
+        oldToNewOrdinalMap.put(newIdToOldOrdinal.get(newDocId), newOrd);
+      }
+      newOrd++;
+    }
+
+    return oldToNewOrdinalMap;
+  }
+
+  private boolean isCurrentVectorNull(DocIdSetIterator docIdSetIterator) throws IOException {
+    if (docIdSetIterator instanceof FloatVectorValues) {
+      return ((FloatVectorValues) docIdSetIterator).vectorValue() == null;
+    }
+
+    if (docIdSetIterator instanceof ByteVectorValues) {
+      return ((ByteVectorValues) docIdSetIterator).vectorValue() == null;
+    }
+
+    return true;
+  }
+
+  private boolean allMatch(Bits bits) {
+    if (bits == null) {
+      return true;
+    }
+
+    for (int i = 0; i < bits.length(); i++) {
+      if (!bits.get(i)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   /**
    * @param graph Write the graph in a compressed format
    * @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets.
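To make getOldToNewOrdinalMap concrete, here is a small hypothetical worked example (all numbers invented for illustration): say the initializer segment's vectors sit at old ordinals 0, 1, 2 on its docs 0, 2, 3, and the merge doc map sends those docs to merged doc IDs 5, 7, 8. The first pass yields newIdToOldOrdinal = {5: 0, 7: 1, 8: 2} and maxNewDocID = 8. The second pass walks the merged vector iterator in doc order; if merged docs 0 through 8 all carry vectors, the vector at doc 5 is assigned new ordinal 5, doc 7 gets 7, and doc 8 gets 8, so the method returns {0: 5, 1: 7, 2: 8}. The two counts diverge whenever some merged docs lack vectors, which is why the second pass counts actual vectors rather than doc IDs.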
@@ -735,10 +912,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
       assert docID > lastDocID;
       docsWithField.add(docID);
       vectors.add(copyValue(vectorValue));
-      if (node > 0) {
-        // start at node 1! node 0 is added implicitly, in the constructor
-        hnswGraphBuilder.addGraphNode(node, vectorValue);
-      }
+      hnswGraphBuilder.addGraphNode(node, vectorValue);
       node++;
       lastDocID = docID;
     }
@@ -20,6 +20,8 @@ package org.apache.lucene.util.hnsw;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.io.IOException;
+import java.util.Collection;
+import java.util.Iterator;
 import java.util.NoSuchElementException;
 import java.util.PrimitiveIterator;
 import org.apache.lucene.index.FloatVectorValues;
@@ -115,7 +117,7 @@ public abstract class HnswGraph {
 
         @Override
         public NodesIterator getNodesOnLevel(int level) {
-          return NodesIterator.EMPTY;
+          return ArrayNodesIterator.EMPTY;
         }
       };
 
@@ -123,33 +125,50 @@ public abstract class HnswGraph {
    * Iterator over the graph nodes on a certain level, Iterator also provides the size – the total
    * number of nodes to be iterated over.
    */
-  public static final class NodesIterator implements PrimitiveIterator.OfInt {
-    static NodesIterator EMPTY = new NodesIterator(0);
-
-    private final int[] nodes;
-    private final int size;
-    int cur = 0;
-
-    /** Constructor for iterator based on the nodes array up to the size */
-    public NodesIterator(int[] nodes, int size) {
-      assert nodes != null;
-      assert size <= nodes.length;
-      this.nodes = nodes;
-      this.size = size;
-    }
+  public abstract static class NodesIterator implements PrimitiveIterator.OfInt {
+    protected final int size;
 
     /** Constructor for iterator based on the size */
     public NodesIterator(int size) {
-      this.nodes = null;
       this.size = size;
     }
 
+    /** The number of elements in this iterator * */
+    public int size() {
+      return size;
+    }
+
+    /**
+     * Consume integers from the iterator and place them into the `dest` array.
+     *
+     * @param dest where to put the integers
+     * @return The number of integers written to `dest`
+     */
+    public abstract int consume(int[] dest);
+  }
+
+  /** NodesIterator that accepts nodes as an integer array. */
+  public static class ArrayNodesIterator extends NodesIterator {
+    static NodesIterator EMPTY = new ArrayNodesIterator(0);
+
+    private final int[] nodes;
+    private int cur = 0;
+
+    /** Constructor for iterator based on integer array representing nodes */
+    public ArrayNodesIterator(int[] nodes, int size) {
+      super(size);
+      assert nodes != null;
+      assert size <= nodes.length;
+      this.nodes = nodes;
+    }
+
+    /** Constructor for iterator based on the size */
+    public ArrayNodesIterator(int size) {
+      super(size);
+      this.nodes = null;
+    }
+
     @Override
     public int consume(int[] dest) {
       if (hasNext() == false) {
         throw new NoSuchElementException();
@@ -182,10 +201,43 @@ public abstract class HnswGraph {
     public boolean hasNext() {
       return cur < size;
     }
+  }
 
-    /** The number of elements in this iterator * */
-    public int size() {
-      return size;
+  /** Nodes iterator based on set representation of nodes. */
+  public static class CollectionNodesIterator extends NodesIterator {
+    Iterator<Integer> nodes;
+
+    /** Constructor for iterator based on collection representing nodes */
+    public CollectionNodesIterator(Collection<Integer> nodes) {
+      super(nodes.size());
+      this.nodes = nodes.iterator();
+    }
+
+    @Override
+    public int consume(int[] dest) {
+      if (hasNext() == false) {
+        throw new NoSuchElementException();
+      }
+
+      int destIndex = 0;
+      while (hasNext() && destIndex < dest.length) {
+        dest[destIndex++] = nextInt();
+      }
+
+      return destIndex;
+    }
+
+    @Override
+    public int nextInt() {
+      if (hasNext() == false) {
+        throw new NoSuchElementException();
+      }
+      return nodes.next();
+    }
+
+    @Override
+    public boolean hasNext() {
+      return nodes.hasNext();
+    }
   }
 }
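Since NodesIterator is now an abstract base with two concrete flavors, here is a short hedged sketch of how a caller can drain either flavor through the consume contract defined above; the buffer size and the process call are arbitrary placeholders, not part of this commit:

    // Drain a NodesIterator in fixed-size batches; works the same for
    // ArrayNodesIterator and CollectionNodesIterator.
    NodesIterator it = graph.getNodesOnLevel(level); // may throw IOException
    int[] buffer = new int[128];
    while (it.hasNext()) {
      // consume() throws NoSuchElementException only when called on an empty iterator
      int written = it.consume(buffer);
      for (int i = 0; i < written; i++) {
        process(buffer[i]); // placeholder for caller logic
      }
    }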
@@ -18,10 +18,14 @@
 package org.apache.lucene.util.hnsw;
 
 import static java.lang.Math.log;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.io.IOException;
+import java.util.HashSet;
 import java.util.Locale;
+import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 import java.util.SplittableRandom;
 import java.util.concurrent.TimeUnit;
 import org.apache.lucene.index.VectorEncoding;
@@ -63,6 +67,7 @@ public final class HnswGraphBuilder<T> {
   // we need two sources of vectors in order to perform diversity check comparisons without
   // colliding
   private final RandomAccessVectorValues<T> vectorsCopy;
+  private final Set<Integer> initializedNodes;
 
   public static <T> HnswGraphBuilder<T> create(
       RandomAccessVectorValues<T> vectors,
@@ -75,6 +80,22 @@ public final class HnswGraphBuilder<T> {
     return new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
   }
 
+  public static <T> HnswGraphBuilder<T> create(
+      RandomAccessVectorValues<T> vectors,
+      VectorEncoding vectorEncoding,
+      VectorSimilarityFunction similarityFunction,
+      int M,
+      int beamWidth,
+      long seed,
+      HnswGraph initializerGraph,
+      Map<Integer, Integer> oldToNewOrdinalMap)
+      throws IOException {
+    HnswGraphBuilder<T> hnswGraphBuilder =
+        new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
+    hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
+    return hnswGraphBuilder;
+  }
+
   /**
    * Reads all the vectors from vector values, builds a graph connecting them by their dense
    * ordinals, using the given hyperparameter settings, and returns the resulting graph.
@@ -110,8 +131,7 @@ public final class HnswGraphBuilder<T> {
     // normalization factor for level generation; currently not configurable
     this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
     this.random = new SplittableRandom(seed);
-    int levelOfFirstNode = getRandomGraphLevel(ml, random);
-    this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
+    this.hnsw = new OnHeapHnswGraph(M);
     this.graphSearcher =
         new HnswGraphSearcher<>(
             vectorEncoding,
@@ -120,6 +140,7 @@ public final class HnswGraphBuilder<T> {
             new FixedBitSet(this.vectors.size()));
     // in scratch we store candidates in reverse order: worse candidates are first
     scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
+    this.initializedNodes = new HashSet<>();
   }
 
   /**
@@ -142,10 +163,64 @@ public final class HnswGraphBuilder<T> {
     return hnsw;
   }
 
+  /**
+   * Initializes the graph of this builder. Transfers the nodes and their neighbors from the
+   * initializer graph into the graph being produced by this builder, mapping ordinals from the
+   * initializer graph to their new ordinals in this builder's graph. The builder's graph must be
+   * empty before calling this method.
+   *
+   * @param initializerGraph graph used for initialization
+   * @param oldToNewOrdinalMap map for converting from ordinals in the initializerGraph to this
+   *     builder's graph
+   */
+  private void initializeFromGraph(
+      HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
+    assert hnsw.size() == 0;
+    float[] vectorValue = null;
+    byte[] binaryValue = null;
+    for (int level = 0; level < initializerGraph.numLevels(); level++) {
+      HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
+
+      while (it.hasNext()) {
+        int oldOrd = it.nextInt();
+        int newOrd = oldToNewOrdinalMap.get(oldOrd);
+
+        hnsw.addNode(level, newOrd);
+
+        if (level == 0) {
+          initializedNodes.add(newOrd);
+        }
+
+        switch (this.vectorEncoding) {
+          case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(newOrd);
+          case BYTE -> binaryValue = (byte[]) vectors.vectorValue(newOrd);
+        }
+
+        NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd);
+        initializerGraph.seek(level, oldOrd);
+        for (int oldNeighbor = initializerGraph.nextNeighbor();
+            oldNeighbor != NO_MORE_DOCS;
+            oldNeighbor = initializerGraph.nextNeighbor()) {
+          int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor);
+          float score =
+              switch (this.vectorEncoding) {
+                case FLOAT32 -> this.similarityFunction.compare(
+                    vectorValue, (float[]) vectorsCopy.vectorValue(newNeighbor));
+                case BYTE -> this.similarityFunction.compare(
+                    binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor));
+              };
+          newNeighbors.insertSorted(newNeighbor, score);
+        }
+      }
+    }
+  }
+
   private void addVectors(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
     long start = System.nanoTime(), t = start;
-    // start at node 1! node 0 is added implicitly, in the constructor
-    for (int node = 1; node < vectorsToAdd.size(); node++) {
+    for (int node = 0; node < vectorsToAdd.size(); node++) {
+      if (initializedNodes.contains(node)) {
+        continue;
+      }
+
       addGraphNode(node, vectorsToAdd);
       if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
         t = printGraphBuildStatus(node, start, t);
@@ -167,6 +242,14 @@ public final class HnswGraphBuilder<T> {
     NeighborQueue candidates;
     final int nodeLevel = getRandomGraphLevel(ml, random);
     int curMaxLevel = hnsw.numLevels() - 1;
+
+    // If entrynode is -1, then this should finish without adding neighbors
+    if (hnsw.entryNode() == -1) {
+      for (int level = nodeLevel; level >= 0; level--) {
+        hnsw.addNode(level, node);
+      }
+      return;
+    }
     int[] eps = new int[] {hnsw.entryNode()};
 
     // if a node introduces new levels to the graph, add this new node on new levels
@@ -101,7 +101,12 @@ public class HnswGraphSearcher<T> {
             new NeighborQueue(topK, true),
             new SparseFixedBitSet(vectors.size()));
     NeighborQueue results;
-    int[] eps = new int[] {graph.entryNode()};
+
+    int initialEp = graph.entryNode();
+    if (initialEp == -1) {
+      return new NeighborQueue(1, true);
+    }
+    int[] eps = new int[] {initialEp};
     int numVisited = 0;
     for (int level = graph.numLevels() - 1; level >= 1; level--) {
       results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
@@ -20,10 +20,9 @@ package org.apache.lucene.util.hnsw;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
+import java.util.TreeMap;
 import org.apache.lucene.util.Accountable;
-import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.RamUsageEstimator;
 
 /**
@@ -33,19 +32,20 @@ import org.apache.lucene.util.RamUsageEstimator;
 public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
 
   private int numLevels; // the current number of levels in the graph
-  private int entryNode; // the current graph entry node on the top level
+  private int entryNode; // the current graph entry node on the top level. -1 if not set
 
-  // Nodes by level expressed as the level 0's nodes' ordinals.
-  // As level 0 contains all nodes, nodesByLevel.get(0) is null.
-  private final List<int[]> nodesByLevel;
-
-  // graph is a list of graph levels.
-  // Each level is represented as List<NeighborArray> – nodes' connections on this level.
+  // Level 0 is represented as List<NeighborArray> – nodes' connections on level 0.
   // Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond
   // to vectors added to HnswBuilder, and the node values are the ordinals of those vectors.
   // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
-  private final List<List<NeighborArray>> graph;
+  private final List<NeighborArray> graphLevel0;
+  // Represents levels 1-N. Each level is represented with a TreeMap that maps a level's level 0
+  // ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to
+  // maintain it in this list. However, to avoid changing list indexing, we always will make the
+  // first element null.
+  private final List<TreeMap<Integer, NeighborArray>> graphUpperLevels;
   private final int nsize;
   private final int nsize0;
 
@@ -53,24 +53,17 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
   private int upto;
   private NeighborArray cur;
 
-  OnHeapHnswGraph(int M, int levelOfFirstNode) {
-    this.numLevels = levelOfFirstNode + 1;
-    this.graph = new ArrayList<>(numLevels);
-    this.entryNode = 0;
+  OnHeapHnswGraph(int M) {
+    this.numLevels = 1; // Implicitly start the graph with a single level
+    this.graphLevel0 = new ArrayList<>();
+    this.entryNode = -1; // Entry node should be negative until a node is added
     // Neighbours' size on upper levels (nsize) and level 0 (nsize0)
     // We allocate extra space for neighbours, but then prune them to keep allowed maximum
     this.nsize = M + 1;
     this.nsize0 = (M * 2 + 1);
-    for (int l = 0; l < numLevels; l++) {
-      graph.add(new ArrayList<>());
-      graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, true));
-    }
-
-    this.nodesByLevel = new ArrayList<>(numLevels);
-    nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes
-    for (int l = 1; l < numLevels; l++) {
-      nodesByLevel.add(new int[] {0});
-    }
+
+    this.graphUpperLevels = new ArrayList<>(numLevels);
+    graphUpperLevels.add(null); // we don't need this for 0th level, as it contains all nodes
   }
 
   /**
@@ -81,49 +74,52 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
    */
   public NeighborArray getNeighbors(int level, int node) {
     if (level == 0) {
-      return graph.get(level).get(node);
+      return graphLevel0.get(node);
     }
-    int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node);
-    assert nodeIndex >= 0;
-    return graph.get(level).get(nodeIndex);
+    TreeMap<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
+    assert levelMap.containsKey(node);
+    return levelMap.get(node);
   }
 
   @Override
   public int size() {
-    return graph.get(0).size(); // all nodes are located on the 0th level
+    return graphLevel0.size(); // all nodes are located on the 0th level
   }
 
   /**
-   * Add node on the given level
+   * Add node on the given level. Nodes can be inserted out of order, but it requires that the
+   * nodes preceded by the node inserted out of order are eventually added.
    *
    * @param level level to add a node on
    * @param node the node to add, represented as an ordinal on the level 0.
    */
   public void addNode(int level, int node) {
+    if (entryNode == -1) {
+      entryNode = node;
+    }
+
     if (level > 0) {
       // if the new node introduces a new level, add more levels to the graph,
       // and make this node the graph's new entry point
       if (level >= numLevels) {
         for (int i = numLevels; i <= level; i++) {
-          graph.add(new ArrayList<>());
-          nodesByLevel.add(new int[] {node});
+          graphUpperLevels.add(new TreeMap<>());
         }
         numLevels = level + 1;
        entryNode = node;
-      } else {
-        // Add this node id to this level's nodes
-        int[] nodes = nodesByLevel.get(level);
-        int idx = graph.get(level).size();
-        if (idx < nodes.length) {
-          nodes[idx] = node;
-        } else {
-          nodes = ArrayUtil.grow(nodes);
-          nodes[idx] = node;
-          nodesByLevel.set(level, nodes);
-        }
       }
+
+      graphUpperLevels.get(level).put(node, new NeighborArray(nsize, true));
+    } else {
+      // Add nodes all the way up to and including "node" in the new graph on level 0. This will
+      // cause the size of the graph to differ from the number of nodes added to the graph. The
+      // size of the graph and the number of nodes added will only be in sync once all nodes from
+      // 0...last_node are added into the graph.
+      while (node >= graphLevel0.size()) {
+        graphLevel0.add(new NeighborArray(nsize0, true));
+      }
     }
-    graph.get(level).add(new NeighborArray(level == 0 ? nsize0 : nsize, true));
   }
 
   @Override
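To make the out-of-order contract in the javadoc above concrete, here is a hedged illustration of the level-0 padding behavior; OnHeapHnswGraph and its constructor are package-private, so this fragment only runs inside org.apache.lucene.util.hnsw, and the M value is chosen arbitrarily:

    OnHeapHnswGraph graph = new OnHeapHnswGraph(16);
    graph.addNode(0, 5); // also makes 5 the entry node, since the graph was empty
    // Level 0 was padded with empty NeighborArrays for ordinals 0..4, so
    // graph.size() == 6 even though only one node has actually been added.
    // size() and the number of added nodes converge once 0..4 are added too:
    for (int n = 0; n < 5; n++) {
      graph.addNode(0, n);
    }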
@@ -164,9 +160,9 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
   @Override
   public NodesIterator getNodesOnLevel(int level) {
     if (level == 0) {
-      return new NodesIterator(size());
+      return new ArrayNodesIterator(size());
     } else {
-      return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
+      return new CollectionNodesIterator(graphUpperLevels.get(level).keySet());
     }
   }
 
@@ -184,19 +180,26 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
             + Integer.BYTES * 2;
     long total = 0;
     for (int l = 0; l < numLevels; l++) {
-      int numNodesOnLevel = graph.get(l).size();
       if (l == 0) {
         total +=
-            numNodesOnLevel * neighborArrayBytes0
+            graphLevel0.size() * neighborArrayBytes0
                 + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
       } else {
+        long numNodesOnLevel = graphUpperLevels.get(l).size();
+
+        // For levels > 0, we represent the graph structure with a tree map.
+        // A single node in the tree contains 3 references (left root, right root, value) as well
+        // as an Integer for the key and 1 extra byte for the color of the node (this is actually
+        // 1 bit, but because we do not have that granularity, we set to 1 byte). In addition, we
+        // include 1 more reference for the tree map itself.
         total +=
-            nodesByLevel.get(l).length * Integer.BYTES
-                + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
-                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for nodesByLevel
-        total +=
-            numNodesOnLevel * neighborArrayBytes
-                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
+            numNodesOnLevel * (3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + Integer.BYTES + 1)
+                + RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+
+        // Add the size neighbor of each node
+        total += numNodesOnLevel * neighborArrayBytes;
       }
     }
     return total;
@@ -48,14 +48,12 @@ import org.apache.lucene.search.SearcherManager;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.util.LuceneTestCase;
-import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.VectorUtil;
-import org.apache.lucene.util.hnsw.HnswGraph;
 import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
 import org.apache.lucene.util.hnsw.HnswGraphBuilder;
 import org.junit.After;
 import org.junit.Before;
@@ -179,21 +177,6 @@ public class TestKnnGraph extends LuceneTestCase {
     }
   }
 
-  /**
-   * Verify that we get the *same* graph by indexing one segment as we do by indexing two segments
-   * and merging.
-   */
-  public void testMergeProducesSameGraph() throws Exception {
-    long seed = random().nextLong();
-    int numDoc = atLeast(100);
-    int dimension = atLeast(10);
-    float[][] values = randomVectors(numDoc, dimension);
-    int mergePoint = random().nextInt(numDoc);
-    int[][][] mergedGraph = getIndexedGraph(values, mergePoint, seed);
-    int[][][] singleSegmentGraph = getIndexedGraph(values, -1, seed);
-    assertGraphEquals(singleSegmentGraph, mergedGraph);
-  }
-
   /** Test writing and reading of multiple vector fields * */
   public void testMultipleVectorFields() throws Exception {
     int numVectorFields = randomIntBetween(2, 5);
@@ -227,52 +210,6 @@ public class TestKnnGraph extends LuceneTestCase {
     }
   }
 
-  private void assertGraphEquals(int[][][] expected, int[][][] actual) {
-    assertEquals("graph sizes differ", expected.length, actual.length);
-    for (int level = 0; level < expected.length; level++) {
-      for (int node = 0; node < expected[level].length; node++) {
-        assertArrayEquals("difference at ord=" + node, expected[level][node], actual[level][node]);
-      }
-    }
-  }
-
-  /**
-   * Return a naive representation of an HNSW graph as a 3 dimensional array: 1st dim represents a
-   * graph layer. Each layer contains an array of arrays – a list of nodes and for each node a list
-   * of the node's neighbours. 2nd dim represents a node on a layer, and contains the node's
-   * neighbourhood, or {@code null} if a node is not present on this layer. 3rd dim represents
-   * neighbours of a node.
-   */
-  private int[][][] getIndexedGraph(float[][] values, int mergePoint, long seed)
-      throws IOException {
-    HnswGraphBuilder.randSeed = seed;
-    int[][][] graph;
-    try (Directory dir = newDirectory()) {
-      IndexWriterConfig iwc = newIndexWriterConfig();
-      iwc.setMergePolicy(new LogDocMergePolicy()); // for predictable segment ordering when merging
-      iwc.setCodec(codec); // don't use SimpleTextCodec
-      try (IndexWriter iw = new IndexWriter(dir, iwc)) {
-        for (int i = 0; i < values.length; i++) {
-          add(iw, i, values[i]);
-          if (i == mergePoint) {
-            // flush proactively to create a segment
-            iw.flush();
-          }
-        }
-        iw.forceMerge(1);
-      }
-      try (IndexReader reader = DirectoryReader.open(dir)) {
-        PerFieldKnnVectorsFormat.FieldsReader perFieldReader =
-            (PerFieldKnnVectorsFormat.FieldsReader)
-                ((CodecReader) getOnlyLeafReader(reader)).getVectorReader();
-        Lucene95HnswVectorsReader vectorReader =
-            (Lucene95HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
-        graph = copyGraph(vectorReader.getGraph(KNN_GRAPH_FIELD));
-      }
-    }
-    return graph;
-  }
-
   private float[][] randomVectors(int numDoc, int dimension) {
     float[][] values = new float[numDoc][];
     for (int i = 0; i < numDoc; i++) {
@@ -297,27 +234,6 @@ public class TestKnnGraph extends LuceneTestCase {
     return value;
   }
 
-  int[][][] copyGraph(HnswGraph graphValues) throws IOException {
-    int[][][] graph = new int[graphValues.numLevels()][][];
-    int size = graphValues.size();
-    int[] scratch = new int[M * 2];
-
-    for (int level = 0; level < graphValues.numLevels(); level++) {
-      NodesIterator nodesItr = graphValues.getNodesOnLevel(level);
-      graph[level] = new int[size][];
-      while (nodesItr.hasNext()) {
-        int node = nodesItr.nextInt();
-        graphValues.seek(level, node);
-        int n, count = 0;
-        while ((n = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
-          scratch[count++] = n;
-        }
-        graph[level][node] = ArrayUtil.copyOfSubArray(scratch, 0, count);
-      }
-    }
-    return graph;
-  }
-
   /** Verify that searching does something reasonable */
   public void testSearch() throws Exception {
     // We can't use dot product here since the vectors are laid out on a grid, not a sphere.
@@ -25,10 +25,14 @@ import com.carrotsearch.randomizedtesting.RandomizedTest;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
 import java.util.Set;
+import java.util.stream.Collectors;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.lucene95.Lucene95Codec;
 import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
@@ -84,6 +88,12 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
   abstract AbstractMockVectorValues<T> vectorValues(LeafReader reader, String fieldName)
       throws IOException;
 
+  abstract AbstractMockVectorValues<T> vectorValues(
+      int size,
+      int dimension,
+      AbstractMockVectorValues<T> pregeneratedVectorValues,
+      int pregeneratedOffset);
+
   abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction);
 
   abstract RandomAccessVectorValues<T> circularVectorValues(int nDoc);
@@ -427,6 +437,238 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     }
   }
 
+  public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException {
+    int maxNumLevels = randomIntBetween(2, 10);
+    int nodeCount = randomIntBetween(1, 100);
+
+    List<List<Integer>> nodesPerLevel = new ArrayList<>();
+    for (int i = 0; i < maxNumLevels; i++) {
+      nodesPerLevel.add(new ArrayList<>());
+    }
+
+    int numLevels = 0;
+    for (int currNode = 0; currNode < nodeCount; currNode++) {
+      int nodeMaxLevel = random().nextInt(1, maxNumLevels + 1);
+      numLevels = Math.max(numLevels, nodeMaxLevel);
+      for (int currLevel = 0; currLevel < nodeMaxLevel; currLevel++) {
+        nodesPerLevel.get(currLevel).add(currNode);
+      }
+    }
+
+    OnHeapHnswGraph topDownOrderReversedHnsw = new OnHeapHnswGraph(10);
+    for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
+      List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
+      int currLevelNodesSize = currLevelNodes.size();
+      for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
+        topDownOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
+      }
+    }
+
+    OnHeapHnswGraph bottomUpOrderReversedHnsw = new OnHeapHnswGraph(10);
+    for (int currLevel = 0; currLevel < numLevels; currLevel++) {
+      List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
+      int currLevelNodesSize = currLevelNodes.size();
+      for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
+        bottomUpOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
+      }
+    }
+
+    OnHeapHnswGraph topDownOrderRandomHnsw = new OnHeapHnswGraph(10);
+    for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
+      List<Integer> currLevelNodes = new ArrayList<>(nodesPerLevel.get(currLevel));
+      Collections.shuffle(currLevelNodes, random());
+      for (Integer currNode : currLevelNodes) {
+        topDownOrderRandomHnsw.addNode(currLevel, currNode);
+      }
+    }
+
+    OnHeapHnswGraph bottomUpExpectedHnsw = new OnHeapHnswGraph(10);
+    for (int currLevel = 0; currLevel < numLevels; currLevel++) {
+      for (Integer currNode : nodesPerLevel.get(currLevel)) {
+        bottomUpExpectedHnsw.addNode(currLevel, currNode);
+      }
+    }
+
+    assertEquals(nodeCount, bottomUpExpectedHnsw.getNodesOnLevel(0).size());
+    for (Integer node : nodesPerLevel.get(0)) {
+      assertEquals(0, bottomUpExpectedHnsw.getNeighbors(0, node).size());
+    }
+
+    for (int currLevel = 1; currLevel < numLevels; currLevel++) {
+      NodesIterator nodesIterator = bottomUpExpectedHnsw.getNodesOnLevel(currLevel);
+      List<Integer> expectedNodesOnLevel = nodesPerLevel.get(currLevel);
+      assertEquals(expectedNodesOnLevel.size(), nodesIterator.size());
+      for (Integer expectedNode : expectedNodesOnLevel) {
+        int currentNode = nodesIterator.nextInt();
+        assertEquals(expectedNode.intValue(), currentNode);
+        assertEquals(0, bottomUpExpectedHnsw.getNeighbors(currLevel, currentNode).size());
+      }
+    }
+
+    assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw);
+    assertGraphEqual(bottomUpExpectedHnsw, bottomUpOrderReversedHnsw);
+    assertGraphEqual(bottomUpExpectedHnsw, topDownOrderRandomHnsw);
+  }
+
+  public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws IOException {
+    int totalSize = atLeast(100);
+    int initializerSize = random().nextInt(5, totalSize);
+    int docIdOffset = 0;
+    int dim = atLeast(10);
+    long seed = random().nextLong();
+
+    AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim);
+    HnswGraphBuilder<T> initializerBuilder =
+        HnswGraphBuilder.create(
+            initializerVectors, getVectorEncoding(), similarityFunction, 10, 30, seed);
+
+    OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy());
+    AbstractMockVectorValues<T> finalVectorValues =
+        vectorValues(totalSize, dim, initializerVectors, docIdOffset);
+
+    Map<Integer, Integer> initializerOrdMap =
+        createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
+
+    HnswGraphBuilder<T> finalBuilder =
+        HnswGraphBuilder.create(
+            finalVectorValues,
+            getVectorEncoding(),
+            similarityFunction,
+            10,
+            30,
+            seed,
+            initializerGraph,
+            initializerOrdMap);
+
+    // When offset is 0, the graphs should be identical before vectors are added
+    assertGraphEqual(initializerGraph, finalBuilder.getGraph());
+
+    OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy());
+    assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
+  }
+
+  public void testHnswGraphBuilderInitializationFromGraph_withNonZeroOffset() throws IOException {
+    int totalSize = atLeast(100);
+    int initializerSize = random().nextInt(5, totalSize);
+    int docIdOffset = random().nextInt(1, totalSize - initializerSize + 1);
+    int dim = atLeast(10);
+    long seed = random().nextLong();
+
+    AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim);
+    HnswGraphBuilder<T> initializerBuilder =
+        HnswGraphBuilder.create(
+            initializerVectors.copy(), getVectorEncoding(), similarityFunction, 10, 30, seed);
+    OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy());
+    AbstractMockVectorValues<T> finalVectorValues =
+        vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset);
+    Map<Integer, Integer> initializerOrdMap =
+        createOffsetOrdinalMap(initializerSize, finalVectorValues.copy(), docIdOffset);
+
+    HnswGraphBuilder<T> finalBuilder =
+        HnswGraphBuilder.create(
+            finalVectorValues,
+            getVectorEncoding(),
+            similarityFunction,
+            10,
+            30,
+            seed,
+            initializerGraph,
+            initializerOrdMap);
+
+    assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap);
+
+    // Confirm that the graph is appropriately constructed by checking that the nodes in the old
+    // graph are present in the levels of the new graph
+    OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy());
+    assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
+  }
+
+  private void assertGraphContainsGraph(
+      HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
+    for (int i = 0; i < h.numLevels(); i++) {
+      int[] finalGraphNodesOnLevel = nodesIteratorToArray(g.getNodesOnLevel(i));
+      int[] initializerGraphNodesOnLevel =
+          mapArrayAndSort(nodesIteratorToArray(h.getNodesOnLevel(i)), oldToNewOrdMap);
+      int overlap = computeOverlap(finalGraphNodesOnLevel, initializerGraphNodesOnLevel);
+      assertEquals(initializerGraphNodesOnLevel.length, overlap);
+    }
+  }
+
+  private void assertGraphInitializedFromGraph(
+      HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
+    assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
+    // Confirm that the size of the new graph includes all nodes up to and including the max new
+    // ordinal in the old-to-new ordinal mapping
+    assertEquals(
+        "the number of nodes in the graphs are different!",
+        g.size(),
+        Collections.max(oldToNewOrdMap.values()) + 1);
+
+    // assert that the nodes from the previous graph were successfully added to levels > 0 in the
+    // new graph
+    for (int level = 1; level < g.numLevels(); level++) {
+      NodesIterator nodesOnLevel = g.getNodesOnLevel(level);
+      NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level);
+      while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) {
+        int node = nodesOnLevel.nextInt();
+        int node2 = oldToNewOrdMap.get(nodesOnLevel2.nextInt());
+        assertEquals("nodes in the graphs are different", node, node2);
+      }
+    }
+
+    // assert that the neighbors from the old graph are successfully transferred to the new graph
+    for (int level = 0; level < g.numLevels(); level++) {
+      NodesIterator nodesOnLevel = h.getNodesOnLevel(level);
+      while (nodesOnLevel.hasNext()) {
+        int node = nodesOnLevel.nextInt();
+        g.seek(level, oldToNewOrdMap.get(node));
+        h.seek(level, node);
+        assertEquals(
+            "arcs differ for node " + node,
+            getNeighborNodes(g),
+            getNeighborNodes(h).stream().map(oldToNewOrdMap::get).collect(Collectors.toSet()));
+      }
+    }
+  }
+
+  private Map<Integer, Integer> createOffsetOrdinalMap(
+      int docIdSize, AbstractMockVectorValues<T> totalVectorValues, int docIdOffset) {
+    // Compute the offset for the ordinal map to be the number of non-null vectors in the total
+    // vector values before the docIdOffset
+    int ordinalOffset = 0;
+    while (totalVectorValues.nextDoc() < docIdOffset) {
+      ordinalOffset++;
+    }
+
+    Map<Integer, Integer> offsetOrdinalMap = new HashMap<>();
+    for (int curr = 0;
+        totalVectorValues.docID() < docIdOffset + docIdSize;
+        totalVectorValues.nextDoc()) {
+      offsetOrdinalMap.put(curr, ordinalOffset + curr++);
+    }
+
+    return offsetOrdinalMap;
+  }
+
+  private int[] nodesIteratorToArray(NodesIterator nodesIterator) {
+    int[] arr = new int[nodesIterator.size()];
+    int i = 0;
+    while (nodesIterator.hasNext()) {
+      arr[i++] = nodesIterator.nextInt();
+    }
+    return arr;
+  }
+
+  private int[] mapArrayAndSort(int[] arr, Map<Integer, Integer> map) {
+    int[] mappedA = new int[arr.length];
+    for (int i = 0; i < arr.length; i++) {
+      mappedA[i] = map.get(arr[i]);
+    }
+    Arrays.sort(mappedA);
+    return mappedA;
+  }
+
   @SuppressWarnings("unchecked")
   public void testVisitedLimit() throws IOException {
     int nDoc = 500;
@@ -531,8 +773,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
         HnswGraphBuilder.create(
             vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt());
-    // node 0 is added by the builder constructor
-    // builder.addGraphNode(vectors.vectorValue(0));
     RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
+    builder.addGraphNode(0, vectorsCopy);
     builder.addGraphNode(1, vectorsCopy);
     builder.addGraphNode(2, vectorsCopy);
     // now every node has tried to attach every other node as a neighbor, but
@@ -586,9 +828,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     HnswGraphBuilder<T> builder =
         HnswGraphBuilder.create(
             vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
-    // node 0 is added by the builder constructor
-    // builder.addGraphNode(vectors.vectorValue(0));
     RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
+    builder.addGraphNode(0, vectorsCopy);
     builder.addGraphNode(1, vectorsCopy);
     builder.addGraphNode(2, vectorsCopy);
     assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
@@ -619,9 +860,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     HnswGraphBuilder<T> builder =
         HnswGraphBuilder.create(
             vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
-    // node 0 is added by the builder constructor
-    // builder.addGraphNode(vectors.vectorValue(0));
     RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
+    builder.addGraphNode(0, vectorsCopy);
     builder.addGraphNode(1, vectorsCopy);
     builder.addGraphNode(2, vectorsCopy);
     assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
@@ -85,6 +85,34 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {
     return MockByteVectorValues.fromValues(bValues);
   }
 
+  @Override
+  AbstractMockVectorValues<byte[]> vectorValues(
+      int size,
+      int dimension,
+      AbstractMockVectorValues<byte[]> pregeneratedVectorValues,
+      int pregeneratedOffset) {
+    byte[][] vectors = new byte[size][];
+    byte[][] randomVectors =
+        createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, random());
+
+    for (int i = 0; i < pregeneratedOffset; i++) {
+      vectors[i] = randomVectors[i];
+    }
+
+    int currentDoc;
+    while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) {
+      vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc];
+    }
+
+    for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length;
+        i < vectors.length;
+        i++) {
+      vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length];
+    }
+
+    return MockByteVectorValues.fromValues(vectors);
+  }
+
   @Override
   AbstractMockVectorValues<byte[]> vectorValues(LeafReader reader, String fieldName)
       throws IOException {
@@ -79,6 +79,35 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
     return MockVectorValues.fromValues(vectors);
   }
 
+  @Override
+  AbstractMockVectorValues<float[]> vectorValues(
+      int size,
+      int dimension,
+      AbstractMockVectorValues<float[]> pregeneratedVectorValues,
+      int pregeneratedOffset) {
+    float[][] vectors = new float[size][];
+    float[][] randomVectors =
+        createRandomFloatVectors(
+            size - pregeneratedVectorValues.values.length, dimension, random());
+
+    for (int i = 0; i < pregeneratedOffset; i++) {
+      vectors[i] = randomVectors[i];
+    }
+
+    int currentDoc;
+    while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) {
+      vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc];
+    }
+
+    for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length;
+        i < vectors.length;
+        i++) {
+      vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length];
+    }
+
+    return MockVectorValues.fromValues(vectors);
+  }
+
   @Override
   Field knnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
     return new KnnFloatVectorField(name, vector, similarityFunction);