mirror of https://github.com/apache/lucene.git

LUCENE-9004: KNN vector search using NSW graphs (#2022)

commit b36b4af22b (parent 80a0154d57)
@@ -283,7 +283,6 @@ public abstract class VectorWriter implements Closeable {
       public BytesRef binaryValue(int targetOrd) throws IOException {
         throw new UnsupportedOperationException();
       }
-
     }
   }
 }
@@ -27,15 +27,17 @@ import org.apache.lucene.index.SegmentWriteState;
 
 /**
  * Lucene 9.0 vector format, which encodes dense numeric vector values.
- * TODO: add support for approximate KNN search.
+ *
+ * @lucene.experimental
  */
 public final class Lucene90VectorFormat extends VectorFormat {
 
   static final String META_CODEC_NAME = "Lucene90VectorFormatMeta";
   static final String VECTOR_DATA_CODEC_NAME = "Lucene90VectorFormatData";
+  static final String VECTOR_INDEX_CODEC_NAME = "Lucene90VectorFormatIndex";
   static final String META_EXTENSION = "vem";
   static final String VECTOR_DATA_EXTENSION = "vec";
+  static final String VECTOR_INDEX_EXTENSION = "vex";
 
   static final int VERSION_START = 0;
   static final int VERSION_CURRENT = VERSION_START;
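
With this change the format writes three per-segment files: field metadata (.vem), raw vector data (.vec), and the new graph index (.vex). A minimal sketch of how such names are derived with Lucene's IndexFileNames helper; the segment name "_0" and the empty suffix are illustrative, not taken from the commit:

    import org.apache.lucene.index.IndexFileNames;

    // For a segment named "_0" with no segment suffix:
    String meta  = IndexFileNames.segmentFileName("_0", "", "vem"); // -> "_0.vem"
    String data  = IndexFileNames.segmentFileName("_0", "", "vec"); // -> "_0.vec"
    String graph = IndexFileNames.segmentFileName("_0", "", "vex"); // -> "_0.vex"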
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer;
 import java.nio.FloatBuffer;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Random;
 
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.VectorReader;
@@ -29,19 +30,28 @@ import org.apache.lucene.index.CorruptIndexException;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.KnnGraphValues;
 import org.apache.lucene.index.RandomAccessVectorValues;
 import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.store.ChecksumIndexInput;
+import org.apache.lucene.store.DataInput;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.hnsw.HnswGraph;
+import org.apache.lucene.util.hnsw.Neighbor;
+import org.apache.lucene.util.hnsw.Neighbors;
 
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
 /**
- * Reads vectors from the index segments.
+ * Reads vectors from the index segments along with index data structures supporting KNN search.
  * @lucene.experimental
  */
 public final class Lucene90VectorReader extends VectorReader {
@@ -49,13 +59,21 @@ public final class Lucene90VectorReader extends VectorReader {
   private final FieldInfos fieldInfos;
   private final Map<String, FieldEntry> fields = new HashMap<>();
   private final IndexInput vectorData;
-  private final int maxDoc;
+  private final IndexInput vectorIndex;
+  private final long checksumSeed;
 
   Lucene90VectorReader(SegmentReadState state) throws IOException {
     this.fieldInfos = state.fieldInfos;
-    this.maxDoc = state.segmentInfo.maxDoc();
 
-    String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.META_EXTENSION);
+    int versionMeta = readMetadata(state, Lucene90VectorFormat.META_EXTENSION);
+    long[] checksumRef = new long[1];
+    vectorData = openDataInput(state, versionMeta, Lucene90VectorFormat.VECTOR_DATA_EXTENSION, Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME, checksumRef);
+    vectorIndex = openDataInput(state, versionMeta, Lucene90VectorFormat.VECTOR_INDEX_EXTENSION, Lucene90VectorFormat.VECTOR_INDEX_CODEC_NAME, checksumRef);
+    checksumSeed = checksumRef[0];
+  }
+
+  private int readMetadata(SegmentReadState state, String fileExtension) throws IOException {
+    String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
     int versionMeta = -1;
     try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName, state.context)) {
       Throwable priorE = null;
@@ -73,29 +91,32 @@ public final class Lucene90VectorReader extends VectorReader {
         CodecUtil.checkFooter(meta, priorE);
       }
     }
+    return versionMeta;
+  }
 
+  private static IndexInput openDataInput(SegmentReadState state, int versionMeta, String fileExtension, String codecName, long[] checksumRef) throws IOException {
     boolean success = false;
-
-    String vectorDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_DATA_EXTENSION);
-    this.vectorData = state.directory.openInput(vectorDataFileName, state.context);
+    String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
+    IndexInput in = state.directory.openInput(fileName, state.context);
     try {
-      int versionVectorData = CodecUtil.checkIndexHeader(vectorData,
-          Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME,
+      int versionVectorData = CodecUtil.checkIndexHeader(in,
+          codecName,
           Lucene90VectorFormat.VERSION_START,
           Lucene90VectorFormat.VERSION_CURRENT,
           state.segmentInfo.getId(),
           state.segmentSuffix);
       if (versionMeta != versionVectorData) {
-        throw new CorruptIndexException("Format versions mismatch: meta=" + versionMeta + ", vector data=" + versionVectorData, vectorData);
+        throw new CorruptIndexException("Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, in);
       }
-      CodecUtil.retrieveChecksum(vectorData);
+      checksumRef[0] = CodecUtil.retrieveChecksum(in);
 
       success = true;
     } finally {
       if (!success) {
-        IOUtils.closeWhileHandlingException(this.vectorData);
+        IOUtils.closeWhileHandlingException(in);
       }
     }
+    return in;
   }
 
   private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
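
openDataInput has to hand back two things: the opened IndexInput and the checksum it retrieves. Java has no out-parameters, so the constructor threads a one-element long[] through both calls. A self-contained toy version of the idiom; the class and method names here are hypothetical, not part of the commit:

    class OutParamDemo {
      // Stand-in for openDataInput: "returns" a value and also fills checksumRef[0].
      static String open(long[] checksumRef) {
        checksumRef[0] = 42L; // in the real code: CodecUtil.retrieveChecksum(in)
        return "input";
      }

      public static void main(String[] args) {
        long[] checksumRef = new long[1];
        String in = open(checksumRef);
        System.out.println(in + ", checksum=" + checksumRef[0]); // input, checksum=42
      }
    }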
@@ -104,23 +125,28 @@ public final class Lucene90VectorReader extends VectorReader {
       if (info == null) {
         throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
       }
-      int searchStrategyId = meta.readInt();
-      if (searchStrategyId < 0 || searchStrategyId >= VectorValues.SearchStrategy.values().length) {
-        throw new CorruptIndexException("Invalid search strategy id: " + searchStrategyId, meta);
-      }
-      VectorValues.SearchStrategy searchStrategy = VectorValues.SearchStrategy.values()[searchStrategyId];
-      long vectorDataOffset = meta.readVLong();
-      long vectorDataLength = meta.readVLong();
-      int dimension = meta.readInt();
-      int size = meta.readInt();
-      int[] ordToDoc = new int[size];
-      for (int i = 0; i < size; i++) {
-        int doc = meta.readVInt();
-        ordToDoc[i] = doc;
-      }
-      FieldEntry fieldEntry = new FieldEntry(dimension, searchStrategy, maxDoc, vectorDataOffset, vectorDataLength,
-          ordToDoc);
-      fields.put(info.name, fieldEntry);
+      fields.put(info.name, readField(meta));
+    }
+  }
+
+  private VectorValues.SearchStrategy readSearchStrategy(DataInput input) throws IOException {
+    int searchStrategyId = input.readInt();
+    if (searchStrategyId < 0 || searchStrategyId >= VectorValues.SearchStrategy.values().length) {
+      throw new CorruptIndexException("Invalid search strategy id: " + searchStrategyId, input);
+    }
+    return VectorValues.SearchStrategy.values()[searchStrategyId];
+  }
+
+  private FieldEntry readField(DataInput input) throws IOException {
+    VectorValues.SearchStrategy searchStrategy = readSearchStrategy(input);
+    switch(searchStrategy) {
+      case NONE:
+        return new FieldEntry(input, searchStrategy);
+      case DOT_PRODUCT_HNSW:
+      case EUCLIDEAN_HNSW:
+        return new HnswGraphFieldEntry(input, searchStrategy);
+      default:
+        throw new CorruptIndexException("Unknown vector search strategy: " + searchStrategy, input);
     }
   }
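
readSearchStrategy round-trips the enum through its stored integer ordinal and rejects anything outside the valid range, so a corrupt metadata file fails with CorruptIndexException rather than an ArrayIndexOutOfBoundsException. A toy version of the same pattern, using a hypothetical enum rather than the Lucene class:

    class StrategyDemo {
      enum Strategy { NONE, EUCLIDEAN_HNSW, DOT_PRODUCT_HNSW }

      static Strategy decode(int id) {
        if (id < 0 || id >= Strategy.values().length) {
          throw new IllegalArgumentException("Invalid search strategy id: " + id);
        }
        return Strategy.values()[id];
      }

      public static void main(String[] args) {
        // round trip: ordinal out, enum constant back
        System.out.println(decode(Strategy.DOT_PRODUCT_HNSW.ordinal())); // DOT_PRODUCT_HNSW
      }
    }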
@@ -137,6 +163,7 @@ public final class Lucene90VectorReader extends VectorReader {
   @Override
   public void checkIntegrity() throws IOException {
     CodecUtil.checksumEntireFile(vectorData);
+    CodecUtil.checksumEntireFile(vectorIndex);
   }
 
   @Override
@@ -167,29 +194,58 @@ public final class Lucene90VectorReader extends VectorReader {
     return new OffHeapVectorValues(fieldEntry, bytesSlice);
   }
 
+  public KnnGraphValues getGraphValues(String field) throws IOException {
+    FieldInfo info = fieldInfos.fieldInfo(field);
+    if (info == null) {
+      throw new IllegalArgumentException("No such field '" + field + "'");
+    }
+    FieldEntry entry = fields.get(field);
+    if (entry != null && entry.indexDataLength > 0) {
+      return getGraphValues(entry);
+    } else {
+      return KnnGraphValues.EMPTY;
+    }
+  }
+
+  private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
+    if (entry.searchStrategy.isHnsw()) {
+      HnswGraphFieldEntry graphEntry = (HnswGraphFieldEntry) entry;
+      IndexInput bytesSlice = vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
+      return new IndexedKnnGraphReader(graphEntry, bytesSlice);
+    } else {
+      return KnnGraphValues.EMPTY;
+    }
+  }
+
   @Override
   public void close() throws IOException {
-    vectorData.close();
+    IOUtils.close(vectorData, vectorIndex);
   }
 
   private static class FieldEntry {
 
     final int dimension;
     final VectorValues.SearchStrategy searchStrategy;
-    final int maxDoc;
 
     final long vectorDataOffset;
     final long vectorDataLength;
+    final long indexDataOffset;
+    final long indexDataLength;
     final int[] ordToDoc;
 
-    FieldEntry(int dimension, VectorValues.SearchStrategy searchStrategy, int maxDoc,
-               long vectorDataOffset, long vectorDataLength, int[] ordToDoc) {
-      this.dimension = dimension;
+    FieldEntry(DataInput input, VectorValues.SearchStrategy searchStrategy) throws IOException {
       this.searchStrategy = searchStrategy;
-      this.maxDoc = maxDoc;
-      this.vectorDataOffset = vectorDataOffset;
-      this.vectorDataLength = vectorDataLength;
-      this.ordToDoc = ordToDoc;
+      vectorDataOffset = input.readVLong();
+      vectorDataLength = input.readVLong();
+      indexDataOffset = input.readVLong();
+      indexDataLength = input.readVLong();
+      dimension = input.readInt();
+      int size = input.readInt();
+      ordToDoc = new int[size];
+      for (int i = 0; i < size; i++) {
+        int doc = input.readVInt();
+        ordToDoc[i] = doc;
+      }
     }
 
     int size() {
@@ -197,6 +253,21 @@ public final class Lucene90VectorReader extends VectorReader {
     }
   }
 
+  private static class HnswGraphFieldEntry extends FieldEntry {
+
+    final long[] ordOffsets;
+
+    HnswGraphFieldEntry(DataInput input, VectorValues.SearchStrategy searchStrategy) throws IOException {
+      super(input, searchStrategy);
+      ordOffsets = new long[size()];
+      long offset = 0;
+      for (int i = 0; i < ordOffsets.length; i++) {
+        offset += input.readVLong();
+        ordOffsets[i] = offset;
+      }
+    }
+  }
+
   /** Read the vector values from the index input. This supports both iterated and random access. */
   private final class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValuesProducer {
 
@@ -252,11 +323,6 @@ public final class Lucene90VectorReader extends VectorReader {
       return binaryValue;
     }
 
-    @Override
-    public TopDocs search(float[] target, int k, int fanout) {
-      throw new UnsupportedOperationException();
-    }
-
     @Override
     public int docID() {
       return doc;
@@ -288,6 +354,30 @@ public final class Lucene90VectorReader extends VectorReader {
       return new OffHeapRandomAccess(dataIn.clone());
     }
 
+    @Override
+    public TopDocs search(float[] vector, int topK, int fanout) throws IOException {
+      // use a seed that is fixed for the index so we get reproducible results for the same query
+      final Random random = new Random(checksumSeed);
+      Neighbors results = HnswGraph.search(vector, topK + fanout, topK + fanout, randomAccess(), getGraphValues(fieldEntry), random);
+      while (results.size() > topK) {
+        results.pop();
+      }
+      int i = 0;
+      ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), topK)];
+      boolean reversed = searchStrategy().reversed;
+      while (results.size() > 0) {
+        Neighbor n = results.pop();
+        float score;
+        if (reversed) {
+          score = (float) Math.exp(- n.score() / vector.length);
+        } else {
+          score = n.score();
+        }
+        scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[n.node()], score);
+      }
+      // always return >= the case where we can assert == is only when there are fewer than topK vectors in the index
+      return new TopDocs(new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), scoreDocs);
+    }
+
     class OffHeapRandomAccess implements RandomAccessVectorValues {
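
Note the score conversion above: for reversed strategies (Euclidean, where a smaller raw score means a closer vector) the raw distance d is remapped to exp(-d / dimension), so the ScoreDoc scores handed to callers always satisfy "larger is better". A hedged sketch of driving this path, assuming a reader over a segment whose "vector" field was indexed with an HNSW strategy; the variable names are hypothetical:

    VectorValues values = vectorReader.getVectorValues("vector"); // the VectorReader accessor
    TopDocs hits = values.search(queryVector, 10, 50);            // topK=10, fanout=50
    for (ScoreDoc sd : hits.scoreDocs) {
      System.out.println("doc=" + sd.doc + " score=" + sd.score);
    }
    // hits.totalHits reports the number of graph nodes visited, with relation
    // GREATER_THAN_OR_EQUAL_TO, since the true match count is not computed.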
@@ -296,12 +386,10 @@ public final class Lucene90VectorReader extends VectorReader {
       final BytesRef binaryValue;
       final ByteBuffer byteBuffer;
       final FloatBuffer floatBuffer;
-      final int byteSize;
       final float[] value;
 
       OffHeapRandomAccess(IndexInput dataIn) {
         this.dataIn = dataIn;
-        byteSize = Float.BYTES * dimension();
         byteBuffer = ByteBuffer.allocate(byteSize);
         floatBuffer = byteBuffer.asFloatBuffer();
         value = new float[dimension()];
@@ -342,7 +430,41 @@ public final class Lucene90VectorReader extends VectorReader {
         dataIn.seek(offset);
         dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
       }
     }
+  }
+
+  /** Read the nearest-neighbors graph from the index input */
+  private final class IndexedKnnGraphReader extends KnnGraphValues {
+
+    final HnswGraphFieldEntry entry;
+    final IndexInput dataIn;
+
+    int arcCount;
+    int arcUpTo;
+    int arc;
+
+    IndexedKnnGraphReader(HnswGraphFieldEntry entry, IndexInput dataIn) {
+      this.entry = entry;
+      this.dataIn = dataIn;
+    }
+
+    @Override
+    public void seek(int targetOrd) throws IOException {
+      // unsafe; no bounds checking
+      dataIn.seek(entry.ordOffsets[targetOrd]);
+      arcCount = dataIn.readInt();
+      arc = -1;
+      arcUpTo = 0;
+    }
+
+    @Override
+    public int nextNeighbor() throws IOException {
+      if (arcUpTo >= arcCount) {
+        return NO_MORE_DOCS;
+      }
+      ++arcUpTo;
+      arc += dataIn.readVInt();
+      return arc;
     }
   }
 }
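
seek positions the slice at a node's adjacency list (an int count followed by delta-encoded vints), and nextNeighbor reconstructs the sorted neighbor ordinals by accumulating the gaps. A small worked example of the round trip, matching writeGraph in Lucene90VectorWriter below:

    // Writer side, sorted neighbors [2, 4, 9], lastArc starting at -1:
    //   vints written: 2 - (-1) = 3,  4 - 2 = 2,  9 - 4 = 5
    // Reader side (nextNeighbor), arc starting at -1:
    //   arc += 3 -> 2;  arc += 2 -> 4;  arc += 5 -> 9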
@@ -18,18 +18,20 @@
 package org.apache.lucene.codecs.lucene90;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.Arrays;
 
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.VectorWriter;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.hnsw.HnswGraph;
+import org.apache.lucene.util.hnsw.HnswGraphBuilder;
 
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -39,7 +41,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
  */
 public final class Lucene90VectorWriter extends VectorWriter {
 
-  private final IndexOutput meta, vectorData;
+  private final IndexOutput meta, vectorData, vectorIndex;
 
   private boolean finished;
@@ -52,6 +54,9 @@ public final class Lucene90VectorWriter extends VectorWriter {
     String vectorDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_DATA_EXTENSION);
     vectorData = state.directory.createOutput(vectorDataFileName, state.context);
 
+    String indexDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_INDEX_EXTENSION);
+    vectorIndex = state.directory.createOutput(indexDataFileName, state.context);
+
     try {
       CodecUtil.writeIndexHeader(meta,
           Lucene90VectorFormat.META_CODEC_NAME,
@@ -61,6 +66,10 @@ public final class Lucene90VectorWriter extends VectorWriter {
           Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME,
           Lucene90VectorFormat.VERSION_CURRENT,
           state.segmentInfo.getId(), state.segmentSuffix);
+      CodecUtil.writeIndexHeader(vectorIndex,
+          Lucene90VectorFormat.VECTOR_INDEX_CODEC_NAME,
+          Lucene90VectorFormat.VERSION_CURRENT,
+          state.segmentInfo.getId(), state.segmentSuffix);
     } catch (IOException e) {
       IOUtils.closeWhileHandlingException(this);
     }
@@ -69,17 +78,47 @@ public final class Lucene90VectorWriter extends VectorWriter {
   @Override
   public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException {
     long vectorDataOffset = vectorData.getFilePointer();
 
     // TODO - use a better data structure; a bitset? DocsWithFieldSet is p.p. in o.a.l.index
-    List<Integer> docIds = new ArrayList<>();
-    int docV, ord = 0;
-    for (docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), ord++) {
+    int[] docIds = new int[vectors.size()];
+    int count = 0;
+    for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) {
+      // write vector
       writeVectorValue(vectors);
-      docIds.add(docV);
-      // TODO: write knn graph value
+      docIds[count] = docV;
     }
+    // count may be < vectors.size(), e.g. if some documents were deleted
+    long[] offsets = new long[count];
     long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
+    long vectorIndexOffset = vectorIndex.getFilePointer();
+    if (vectors.searchStrategy().isHnsw()) {
+      if (vectors instanceof RandomAccessVectorValuesProducer) {
+        writeGraph(vectorIndex, (RandomAccessVectorValuesProducer) vectors, vectorIndexOffset, offsets, count);
+      } else {
+        throw new IllegalArgumentException("Indexing an HNSW graph requires a random access vector values, got " + vectors);
+      }
+    }
+    long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
     if (vectorDataLength > 0) {
-      writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds);
+      writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, vectorIndexOffset, vectorIndexLength, count, docIds);
+      if (vectors.searchStrategy().isHnsw()) {
+        writeGraphOffsets(meta, offsets);
+      }
+    }
+  }
+
+  private void writeMeta(FieldInfo field, long vectorDataOffset, long vectorDataLength, long indexDataOffset, long indexDataLength, int size, int[] docIds) throws IOException {
+    meta.writeInt(field.number);
+    meta.writeInt(field.getVectorSearchStrategy().ordinal());
+    meta.writeVLong(vectorDataOffset);
+    meta.writeVLong(vectorDataLength);
+    meta.writeVLong(indexDataOffset);
+    meta.writeVLong(indexDataLength);
+    meta.writeInt(field.getVectorDimension());
+    meta.writeInt(size);
+    for (int i = 0; i < size; i++) {
+      // TODO: delta-encode, or write as bitset
+      meta.writeVInt(docIds[i]);
     }
   }
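
Reading writeMeta together with writeGraphOffsets (next hunk) gives the full per-field entry layout in the metadata (.vem) file; summarized here as a reference comment reconstructed from the write calls above:

    // Per-field .vem entry, in write/read order:
    //   int     field number
    //   int     search strategy ordinal
    //   vlong   vectorDataOffset     (into .vec)
    //   vlong   vectorDataLength
    //   vlong   indexDataOffset      (into .vex)
    //   vlong   indexDataLength
    //   int     dimension
    //   int     size                 (number of vectors)
    //   vint  x size   docIds        (docID for each vector ordinal)
    //   vlong x size   graph offsets (delta-encoded; HNSW fields only)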
@@ -90,16 +129,28 @@ public final class Lucene90VectorWriter extends VectorWriter {
     vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
   }
 
-  private void writeMeta(FieldInfo field, long vectorDataOffset, long vectorDataLength, List<Integer> docIds) throws IOException {
-    meta.writeInt(field.number);
-    meta.writeInt(field.getVectorSearchStrategy().ordinal());
-    meta.writeVLong(vectorDataOffset);
-    meta.writeVLong(vectorDataLength);
-    meta.writeInt(field.getVectorDimension());
-    meta.writeInt(docIds.size());
-    for (Integer docId : docIds) {
-      // TODO: delta-encode, or write as bitset
-      meta.writeVInt(docId);
+  private void writeGraphOffsets(IndexOutput out, long[] offsets) throws IOException {
+    long last = 0;
+    for (long offset : offsets) {
+      out.writeVLong(offset - last);
+      last = offset;
+    }
+  }
+
+  private void writeGraph(IndexOutput graphData, RandomAccessVectorValuesProducer vectorValues, long graphDataOffset, long[] offsets, int count) throws IOException {
+    HnswGraph graph = HnswGraphBuilder.build(vectorValues);
+    for (int ord = 0; ord < count; ord++) {
+      // write graph
+      offsets[ord] = graphData.getFilePointer() - graphDataOffset;
+      int[] arcs = graph.getNeighbors(ord);
+      Arrays.sort(arcs);
+      graphData.writeInt(arcs.length);
+      int lastArc = -1; // to make the assertion work?
+      for (int arc : arcs) {
+        assert arc > lastArc : "arcs out of order: " + lastArc + "," + arc;
+        graphData.writeVInt(arc - lastArc);
+        lastArc = arc;
+      }
     }
   }
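
writeGraphOffsets stores the gap between consecutive file offsets rather than the absolute values, and HnswGraphFieldEntry in the reader restores them with a prefix sum. A tiny self-contained round trip over plain arrays:

    class DeltaOffsetsDemo {
      public static void main(String[] args) {
        long[] offsets = {0, 12, 40, 41};       // absolute offsets into the graph data
        long last = 0;
        long[] gaps = new long[offsets.length]; // what writeVLong(offset - last) emits
        for (int i = 0; i < offsets.length; i++) {
          gaps[i] = offsets[i] - last;          // 0, 12, 28, 1
          last = offsets[i];
        }
        long acc = 0;                           // the reader's prefix sum
        for (int i = 0; i < gaps.length; i++) {
          acc += gaps[i];
          System.out.println(acc == offsets[i]); // true for every entry
        }
      }
    }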
@@ -117,11 +168,12 @@ public final class Lucene90VectorWriter extends VectorWriter {
     }
     if (vectorData != null) {
       CodecUtil.writeFooter(vectorData);
+      CodecUtil.writeFooter(vectorIndex);
     }
   }
 
   @Override
   public void close() throws IOException {
-    IOUtils.close(meta, vectorData);
+    IOUtils.close(meta, vectorData, vectorIndex);
   }
 }
@@ -28,7 +28,7 @@ import org.apache.lucene.analysis.Analyzer; // javadocs
 public interface IndexableFieldType {
 
   /** True if the field's value should be stored */
-  public boolean stored();
+  boolean stored();
 
   /**
    * True if this field's value should be analyzed by the
@@ -39,7 +39,7 @@ public interface IndexableFieldType {
    */
   // TODO: shouldn't we remove this? Whether/how a field is
   // tokenized is an impl detail under Field?
-  public boolean tokenized();
+  boolean tokenized();
 
   /**
    * True if this field's indexed form should be also stored
@@ -52,7 +52,7 @@ public interface IndexableFieldType {
    * This option is illegal if {@link #indexOptions()} returns
    * IndexOptions.NONE.
    */
-  public boolean storeTermVectors();
+  boolean storeTermVectors();
 
   /**
    * True if this field's token character offsets should also
@@ -61,7 +61,7 @@ public interface IndexableFieldType {
    * This option is illegal if term vectors are not enabled for the field
    * ({@link #storeTermVectors()} is false)
    */
-  public boolean storeTermVectorOffsets();
+  boolean storeTermVectorOffsets();
 
   /**
    * True if this field's token positions should also be stored
@@ -70,7 +70,7 @@ public interface IndexableFieldType {
    * This option is illegal if term vectors are not enabled for the field
    * ({@link #storeTermVectors()} is false).
    */
-  public boolean storeTermVectorPositions();
+  boolean storeTermVectorPositions();
 
   /**
    * True if this field's token payloads should also be stored
@@ -79,7 +79,7 @@ public interface IndexableFieldType {
    * This option is illegal if term vector positions are not enabled
    * for the field ({@link #storeTermVectors()} is false).
    */
-  public boolean storeTermVectorPayloads();
+  boolean storeTermVectorPayloads();
 
   /**
    * True if normalization values should be omitted for the field.
@@ -87,42 +87,42 @@ public interface IndexableFieldType {
    * This saves memory, but at the expense of scoring quality (length normalization
    * will be disabled), and if you omit norms, you cannot use index-time boosts.
    */
-  public boolean omitNorms();
+  boolean omitNorms();
 
   /** {@link IndexOptions}, describing what should be
    * recorded into the inverted index */
-  public IndexOptions indexOptions();
+  IndexOptions indexOptions();
 
   /**
    * DocValues {@link DocValuesType}: how the field's value will be indexed
    * into docValues.
    */
-  public DocValuesType docValuesType();
+  DocValuesType docValuesType();
 
   /**
    * If this is positive (representing the number of point dimensions), the field is indexed as a point.
    */
-  public int pointDimensionCount();
+  int pointDimensionCount();
 
   /**
    * The number of dimensions used for the index key
   */
-  public int pointIndexDimensionCount();
+  int pointIndexDimensionCount();
 
   /**
    * The number of bytes in each dimension's values.
   */
-  public int pointNumBytes();
+  int pointNumBytes();
 
   /**
    * The number of dimensions of the field's vector value
   */
-  public int vectorDimension();
+  int vectorDimension();
 
   /**
    * The {@link VectorValues.SearchStrategy} of the field's vector value
   */
-  public VectorValues.SearchStrategy vectorSearchStrategy();
+  VectorValues.SearchStrategy vectorSearchStrategy();
 
   /**
    * Attributes for the field type.
@@ -132,5 +132,5 @@ public interface IndexableFieldType {
    *
    * @return Map
    */
-  public Map<String, String> getAttributes();
+  Map<String, String> getAttributes();
 }
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.index;
+
+import java.io.IOException;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+/**
+ * Access to per-document neighbor lists in a (hierarchical) knn search graph.
+ * @lucene.experimental
+ */
+public abstract class KnnGraphValues {
+
+  /** Sole constructor */
+  protected KnnGraphValues() {}
+
+  /** Move the pointer to exactly {@code target}, the id of a node in the graph.
+   * After this method returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
+   * @param target must be a valid node in the graph, ie. ≥ 0 and < {@link VectorValues#size()}.
+   */
+  public abstract void seek(int target) throws IOException;
+
+  /**
+   * Iterates over the neighbor list. It is illegal to call this method after it returns
+   * NO_MORE_DOCS without calling {@link #seek(int)}, which resets the iterator.
+   * @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
+   */
+  public abstract int nextNeighbor() throws IOException;
+
+  /** Empty graph value */
+  public static KnnGraphValues EMPTY = new KnnGraphValues() {
+
+    @Override
+    public int nextNeighbor() {
+      return NO_MORE_DOCS;
+    }
+
+    @Override
+    public void seek(int target) {
+    }
+  };
+}
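
A sketch of the intended iteration protocol over this API, with a hypothetical graph instance (for example one obtained from Lucene90VectorReader.getGraphValues(field)):

    graph.seek(0); // position at node 0
    for (int nbr = graph.nextNeighbor(); nbr != NO_MORE_DOCS; nbr = graph.nextNeighbor()) {
      System.out.println("node 0 -> node " + nbr);
    }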
@@ -190,7 +190,7 @@ public final class SlowCodecReaderWrapper {
       }
     };
   }
 
   private static NormsProducer readerToNormsProducer(final LeafReader reader) {
     return new NormsProducer() {
@@ -23,6 +23,9 @@ import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.util.BytesRef;
 
+import static org.apache.lucene.util.VectorUtil.dotProduct;
+import static org.apache.lucene.util.VectorUtil.squareDistance;
+
 /**
  * This class provides access to per-document floating point vector values indexed as {@link
  * org.apache.lucene.document.VectorField}.
@@ -91,15 +94,59 @@ public abstract class VectorValues extends DocIdSetIterator {
    * determine the nearest neighbors.
    */
   public enum SearchStrategy {
 
     /** No search strategy is provided. Note: {@link VectorValues#search(float[], int, int)}
      * is not supported for fields specifying this strategy. */
     NONE,
 
     /** HNSW graph built using Euclidean distance */
-    EUCLIDEAN_HNSW,
+    EUCLIDEAN_HNSW(true),
 
     /** HNSW graph built using dot product */
-    DOT_PRODUCT_HNSW
+    DOT_PRODUCT_HNSW;
+
+    /** If true, the scores associated with vector comparisons in this strategy are in reverse order; that is,
+     * lower scores represent more similar vectors. Otherwise, if false, higher scores represent more similar vectors.
+     */
+    public final boolean reversed;
+
+    SearchStrategy(boolean reversed) {
+      this.reversed = reversed;
+    }
+
+    SearchStrategy() {
+      reversed = false;
+    }
+
+    /**
+     * Calculates a similarity score between the two vectors with a specified function.
+     * @param v1 a vector
+     * @param v2 another vector, of the same dimension
+     * @return the value of the strategy's score function applied to the two vectors
+     */
+    public float compare(float[] v1, float[] v2) {
+      switch (this) {
+        case EUCLIDEAN_HNSW:
+          return squareDistance(v1, v2);
+        case DOT_PRODUCT_HNSW:
+          return dotProduct(v1, v2);
+        default:
+          throw new IllegalStateException("Incomparable search strategy: " + this);
+      }
+    }
+
+    /**
+     * Return true if vectors indexed using this strategy will be indexed using an HNSW graph
+     */
+    public boolean isHnsw() {
+      switch (this) {
+        case EUCLIDEAN_HNSW:
+        case DOT_PRODUCT_HNSW:
+          return true;
+        default:
+          return false;
+      }
+    }
   }
 
 /**
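
A brief sketch of the two score functions and their orderings, hand-rolled here so the snippet stands alone (the real implementations are the VectorUtil statics imported above):

    class ScoreDemo {
      static float dot(float[] a, float[] b) {
        float s = 0;
        for (int i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
      }

      static float squareDistance(float[] a, float[] b) {
        float s = 0;
        for (int i = 0; i < a.length; i++) { float d = a[i] - b[i]; s += d * d; }
        return s;
      }

      public static void main(String[] args) {
        float[] x = {1, 0}, y = {1, 0}, z = {0, 1};
        // DOT_PRODUCT_HNSW: larger is more similar (reversed == false)
        System.out.println(dot(x, y) > dot(x, z));                       // true: 1.0 > 0.0
        // EUCLIDEAN_HNSW: smaller is more similar (reversed == true)
        System.out.println(squareDistance(x, y) < squareDistance(x, z)); // true: 0.0 < 2.0
      }
    }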
@@ -176,47 +176,48 @@ class VectorValuesWriter {
           throw new UnsupportedOperationException();
         }
 
-        @Override
-        public long cost() {
-          return size();
-        }
-
         @Override
         public TopDocs search(float[] target, int k, int fanout) {
           throw new UnsupportedOperationException();
         }
 
+        @Override
+        public long cost() {
+          return size();
+        }
+
         @Override
         public RandomAccessVectorValues randomAccess() {
+          // Must make a new delegate randomAccess so that we have our own distinct float[]
+          final RandomAccessVectorValues delegateRA = ((RandomAccessVectorValuesProducer) SortingVectorValues.this.delegate).randomAccess();
           return new RandomAccessVectorValues() {
 
             @Override
             public int size() {
-              return delegate.size();
+              return delegateRA.size();
             }
 
             @Override
             public int dimension() {
-              return delegate.dimension();
+              return delegateRA.dimension();
             }
 
             @Override
             public SearchStrategy searchStrategy() {
-              return delegate.searchStrategy();
+              return delegateRA.searchStrategy();
             }
 
             @Override
             public float[] vectorValue(int targetOrd) throws IOException {
-              return randomAccess.vectorValue(ordMap[targetOrd]);
+              return delegateRA.vectorValue(ordMap[targetOrd]);
            }
 
             @Override
             public BytesRef binaryValue(int targetOrd) {
               throw new UnsupportedOperationException();
             }
           };
         }
       }
@@ -252,7 +253,7 @@ class VectorValuesWriter {
 
   @Override
   public RandomAccessVectorValues randomAccess() {
-    return this;
+    return new BufferedVectorValues(docsWithField, vectors, dimension, searchStrategy);
   }
 
   @Override
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+abstract class BoundsChecker {
+
+  float bound;
+
+  /**
+   * Update the bound if sample is better
+   */
+  abstract void update(float sample);
+
+  /**
+   * Return whether the sample exceeds (is worse than) the bound
+   */
+  abstract boolean check(float sample);
+
+  static BoundsChecker create(boolean reversed) {
+    if (reversed) {
+      return new Min();
+    } else {
+      return new Max();
+    }
+  }
+
+  static class Max extends BoundsChecker {
+    Max() {
+      bound = Float.NEGATIVE_INFINITY;
+    }
+
+    void update(float sample) {
+      if (sample > bound) {
+        bound = sample;
+      }
+    }
+
+    boolean check(float sample) {
+      return sample < bound;
+    }
+  }
+
+  static class Min extends BoundsChecker {
+
+    Min() {
+      bound = Float.POSITIVE_INFINITY;
+    }
+
+    void update(float sample) {
+      if (sample < bound) {
+        bound = sample;
+      }
+    }
+
+    boolean check(float sample) {
+      return sample > bound;
+    }
+  }
+}
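
Usage sketch: during search the bound tracks the worst score currently kept in the result queue, and check(sample) asks "is this candidate strictly worse than that?". For a reversed (distance-like) strategy:

    BoundsChecker bc = BoundsChecker.create(true); // Min: bound starts at +Infinity
    bc.update(0.5f);  // bound is now 0.5, the worst distance kept so far
    bc.check(0.9f);   // true:  0.9 exceeds (is worse than) the bound -> prune
    bc.check(0.3f);   // false: this candidate may improve the result set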
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.TreeSet;
+
+import org.apache.lucene.index.KnnGraphValues;
+import org.apache.lucene.index.RandomAccessVectorValues;
+import org.apache.lucene.index.VectorValues;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+/**
+ * Navigable Small-world graph. Provides efficient approximate nearest neighbor
+ * search for high dimensional vectors. See <a href="https://doi.org/10.1016/j.is.2013.10.006">Approximate nearest
+ * neighbor algorithm based on navigable small world graphs [2014]</a> and <a
+ * href="https://arxiv.org/abs/1603.09320">this paper [2018]</a> for details.
+ *
+ * The nomenclature is a bit different here from what's used in those papers:
+ *
+ * <h3>Hyperparameters</h3>
+ * <ul>
+ *   <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2012 paper; it controls the number of random entry points to sample.</li>
+ *   <li><code>beamWidth</code> in {@link HnswGraphBuilder} has the same meaning as <code>efConst</code> in the 2016 paper. It is the number of
+ *   nearest neighbor candidates to track while searching the graph for each newly inserted node.</li>
+ *   <li><code>maxConn</code> has the same meaning as <code>M</code> in the later paper; it controls how many of the <code>efConst</code> neighbors are
+ *   connected to the new node</li>
+ *   <li><code>fanout</code> the fanout parameter of {@link VectorValues#search(float[], int, int)}
+ *   is used to control the values of <code>numSeed</code> and <code>topK</code> that are passed to this API.
+ *   Thus <code>fanout</code> is like a combination of <code>ef</code> (search beam width) from the 2016 paper and <code>m</code> from the 2014 paper.
+ *   </li>
+ * </ul>
+ *
+ * <p>Note: The graph may be searched by multiple threads concurrently, but updates are not thread-safe. Also note: there is no notion of
+ * deletions. Document searching built on top of this must do its own deletion-filtering.</p>
+ */
+public final class HnswGraph {
+
+  private final int maxConn;
+  private final VectorValues.SearchStrategy searchStrategy;
+
+  // Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to HnswBuilder, and the
+  // node values are the ordinals of those vectors.
+  private final List<Neighbors> graph;
+
+  HnswGraph(int maxConn, VectorValues.SearchStrategy searchStrategy) {
+    graph = new ArrayList<>();
+    graph.add(Neighbors.create(maxConn, searchStrategy.reversed));
+    this.maxConn = maxConn;
+    this.searchStrategy = searchStrategy;
+  }
+
+  /**
+   * Searches for the nearest neighbors of a query vector.
+   * @param query search query vector
+   * @param topK the number of nodes to be returned
+   * @param numSeed the number of random entry points to sample
+   * @param vectors vector values
+   * @param graphValues the graph values. May represent the entire graph, or a level in a hierarchical graph.
+   * @param random a source of randomness, used for generating entry points to the graph
+   * @return a priority queue holding the neighbors found
+   */
+  public static Neighbors search(float[] query, int topK, int numSeed, RandomAccessVectorValues vectors, KnnGraphValues graphValues,
+                                 Random random) throws IOException {
+    VectorValues.SearchStrategy searchStrategy = vectors.searchStrategy();
+    // TODO: use unbounded priority queue
+    TreeSet<Neighbor> candidates;
+    if (searchStrategy.reversed) {
+      candidates = new TreeSet<>(Comparator.reverseOrder());
+    } else {
+      candidates = new TreeSet<>();
+    }
+    int size = vectors.size();
+    for (int i = 0; i < numSeed && i < size; i++) {
+      int entryPoint = random.nextInt(size);
+      candidates.add(new Neighbor(entryPoint, searchStrategy.compare(query, vectors.vectorValue(entryPoint))));
+    }
+    // set of ordinals that have been visited by search on this layer, used to avoid backtracking
+    Set<Integer> visited = new HashSet<>();
+    // TODO: use PriorityQueue's sentinel optimization?
+    Neighbors results = Neighbors.create(topK, searchStrategy.reversed);
+    for (Neighbor c : candidates) {
+      visited.add(c.node());
+      results.insertWithOverflow(c);
+    }
+    // Set the bound to the worst current result and below reject any newly-generated candidates failing
+    // to exceed this bound
+    BoundsChecker bound = BoundsChecker.create(searchStrategy.reversed);
+    bound.bound = results.top().score();
+    while (candidates.size() > 0) {
+      // get the best candidate (closest or best scoring)
+      Neighbor c = candidates.pollLast();
+      if (results.size() >= topK) {
+        if (bound.check(c.score())) {
+          break;
+        }
+      }
+      graphValues.seek(c.node());
+      int friendOrd;
+      while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
+        if (visited.contains(friendOrd)) {
+          continue;
+        }
+        visited.add(friendOrd);
+        float score = searchStrategy.compare(query, vectors.vectorValue(friendOrd));
+        if (results.size() < topK || bound.check(score) == false) {
+          Neighbor n = new Neighbor(friendOrd, score);
+          candidates.add(n);
+          results.insertWithOverflow(n);
+          bound.bound = results.top().score();
+        }
+      }
+    }
+    results.setVisitedCount(visited.size());
+    return results;
+  }
+
+  /**
+   * Returns the nodes connected to the given node by its outgoing neighborNodes in an unpredictable order. Each node inserted
+   * by HnswGraphBuilder corresponds to a vector, and the node is the vector's ordinal.
+   * @param node the node whose friends are returned
+   */
+  public int[] getNeighbors(int node) {
+    Neighbors neighbors = graph.get(node);
+    int[] result = new int[neighbors.size()];
+    int i = 0;
+    for (Neighbor n : neighbors) {
+      result[i++] = n.node();
+    }
+    return result;
+  }
+
+  /** Connects two nodes symmetrically, limiting the maximum number of connections from either node.
+   * node1 must be less than node2 and must already have been inserted to the graph */
+  void connectNodes(int node1, int node2, float score) {
+    connect(node1, node2, score);
+    if (node2 == graph.size()) {
+      addNode();
+    }
+    connect(node2, node1, score);
+  }
+
+  KnnGraphValues getGraphValues() {
+    return new HnswGraphValues();
+  }
+
+  /**
+   * Makes a connection from the node to a neighbor, dropping the worst connection when maxConn is exceeded
+   * @param node1 node to connect *from*
+   * @param node2 node to connect *to*
+   * @param score searchStrategy.score() of the vectors associated with the two nodes
+   */
+  boolean connect(int node1, int node2, float score) {
+    //System.out.println("  HnswGraph.connect " + node1 + " -> " + node2);
+    assert node1 >= 0 && node2 >= 0;
+    Neighbors nn = graph.get(node1);
+    assert nn != null;
+    if (nn.size() == maxConn) {
+      Neighbor top = nn.top();
+      if (score < top.score() == nn.reversed()) {
+        top.update(node2, score);
+        nn.updateTop();
+        return true;
+      }
+    } else {
+      nn.add(new Neighbor(node2, score));
+      return true;
+    }
+    return false;
+  }
+
+  int addNode() {
+    graph.add(Neighbors.create(maxConn, searchStrategy.reversed));
+    return graph.size() - 1;
+  }
+
+  /**
+   * Present this graph as KnnGraphValues, used for searching while inserting new nodes.
+   */
+  private class HnswGraphValues extends KnnGraphValues {
+
+    private int arcUpTo;
+    private int[] neighborNodes;
+
+    @Override
+    public void seek(int targetNode) {
+      arcUpTo = 0;
+      neighborNodes = HnswGraph.this.getNeighbors(targetNode);
+    }
+
+    @Override
+    public int nextNeighbor() {
+      if (arcUpTo >= neighborNodes.length) {
+        return NO_MORE_DOCS;
+      }
+      return neighborNodes[arcUpTo++];
+    }
+  }
+}
@ -0,0 +1,188 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.Random;

import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;

/**
 * Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the hyperparameters.
 */
public final class HnswGraphBuilder {

  // default random seed for level generation
  private static final long DEFAULT_RAND_SEED = System.currentTimeMillis();

  // expose for testing.
  public static long randSeed = DEFAULT_RAND_SEED;

  // These "default" hyperparameter settings are exposed (and nonfinal) to enable performance testing
  // since the indexing API doesn't provide any control over them.

  // default max connections per node
  public static int DEFAULT_MAX_CONN = 16;

  // default candidate list size
  static int DEFAULT_BEAM_WIDTH = 16;

  private final int maxConn;
  private final int beamWidth;

  private final BoundedVectorValues boundedVectors;
  private final VectorValues.SearchStrategy searchStrategy;
  private final HnswGraph hnsw;
  private final Random random;

  /**
   * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense ordinals, using default
   * hyperparameter settings, and returns the resulting graph.
   * @param vectorValues the vectors whose relations are represented by the graph
   */
  public static HnswGraph build(RandomAccessVectorValuesProducer vectorValues) throws IOException {
    HnswGraphBuilder builder = new HnswGraphBuilder(vectorValues);
    return builder.build(vectorValues.randomAccess());
  }

  /**
   * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense ordinals, using the given
   * hyperparameter settings, and returns the resulting graph.
   * @param vectorValues the vectors whose relations are represented by the graph
   * @param maxConn the number of connections to make when adding a new graph node; roughly speaking the graph fanout.
   * @param beamWidth the size of the beam search to use when finding nearest neighbors.
   * @param seed the seed for a random number generator used during graph construction. Provide this to ensure repeatable construction.
   */
  public static HnswGraph build(RandomAccessVectorValuesProducer vectorValues, int maxConn, int beamWidth, long seed) throws IOException {
    HnswGraphBuilder builder = new HnswGraphBuilder(vectorValues, maxConn, beamWidth, seed);
    return builder.build(vectorValues.randomAccess());
  }
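  // A hedged usage sketch (not part of the patch): build with explicit
  // hyperparameters and a fixed seed for repeatable construction;
  // "vectorSource" is an assumed RandomAccessVectorValuesProducer and
  // IOException handling is elided:
  //
  //   HnswGraph graph = HnswGraphBuilder.build(vectorSource, 16, 16, 42L);
  //   int[] arcs = graph.getNeighbors(0);  // neighbor ordinals of node 0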

  /**
   * Reads all the vectors from two copies of a random access VectorValues. Providing two copies enables efficient retrieval
   * without extra data copying, while avoiding collision of the returned values.
   * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent accessor for the vectors
   */
  HnswGraph build(RandomAccessVectorValues vectors) throws IOException {
    if (vectors == boundedVectors.raDelegate) {
      throw new IllegalArgumentException("Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
    }
    for (int node = 1; node < vectors.size(); node++) {
      insert(vectors.vectorValue(node));
    }
    return hnsw;
  }

  /** Construct the builder with default configurations */
  private HnswGraphBuilder(RandomAccessVectorValuesProducer vectors) {
    this(vectors, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, randSeed);
  }

  /** Full constructor */
  HnswGraphBuilder(RandomAccessVectorValuesProducer vectors, int maxConn, int beamWidth, long seed) {
    RandomAccessVectorValues vectorValues = vectors.randomAccess();
    searchStrategy = vectorValues.searchStrategy();
    if (searchStrategy == VectorValues.SearchStrategy.NONE) {
      throw new IllegalStateException("No distance function");
    }
    this.maxConn = maxConn;
    this.beamWidth = beamWidth;
    boundedVectors = new BoundedVectorValues(vectorValues);
    this.hnsw = new HnswGraph(maxConn, searchStrategy);
    random = new Random(seed);
  }

  /** Inserts a doc with vector value to the graph */
  private void insert(float[] value) throws IOException {
    addGraphNode(value);

    // add the vector value
    boundedVectors.inc();
  }

  private void addGraphNode(float[] value) throws IOException {
    KnnGraphValues graphValues = hnsw.getGraphValues();
    Neighbors candidates = HnswGraph.search(value, beamWidth, 2 * beamWidth, boundedVectors, graphValues, random);

    int node = hnsw.addNode();

    // connect the nearest neighbors to the just inserted node
    addNearestNeighbors(node, candidates);
  }

  private void addNearestNeighbors(int newNode, Neighbors neighbors) {
    // connect the nearest neighbors, relying on the graph's Neighbors' priority queues to drop off distant neighbors
    for (Neighbor neighbor : neighbors) {
      if (hnsw.connect(newNode, neighbor.node(), neighbor.score())) {
        hnsw.connect(neighbor.node(), newNode, neighbor.score());
      }
    }
  }

  /**
   * Provides a random access VectorValues view over a delegate VectorValues, bounding the maximum ord.
   * TODO: get rid of this, all it does is track a counter
   */
  private static class BoundedVectorValues implements RandomAccessVectorValues {

    final RandomAccessVectorValues raDelegate;

    int size;

    BoundedVectorValues(RandomAccessVectorValues delegate) {
      raDelegate = delegate;
      if (delegate.size() > 0) {
        // we implicitly add the first node
        size = 1;
      }
    }

    void inc() {
      ++size;
    }

    @Override
    public int size() {
      return size;
    }

    @Override
    public int dimension() { return raDelegate.dimension(); }

    @Override
    public VectorValues.SearchStrategy searchStrategy() {
      return raDelegate.searchStrategy();
    }

    @Override
    public float[] vectorValue(int target) throws IOException {
      return raDelegate.vectorValue(target);
    }

    @Override
    public BytesRef binaryValue(int targetOrd) throws IOException {
      throw new UnsupportedOperationException();
    }
  }

}
@@ -0,0 +1,70 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.lucene.util.hnsw;

/** A neighbor node in the HNSW graph; holds the node ordinal and its distance score. */
public class Neighbor implements Comparable<Neighbor> {

  private int node;

  private float score;

  public Neighbor(int node, float score) {
    this.node = node;
    this.score = score;
  }

  public int node() {
    return node;
  }

  public float score() {
    return score;
  }

  void update(int node, float score) {
    this.node = node;
    this.score = score;
  }

  @Override
  public int compareTo(Neighbor o) {
    if (score == o.score) {
      return o.node - node;
    } else {
      assert node != o.node : "attempt to add the same node " + node + " twice with different scores: " + score + " != " + o.score;
      return Float.compare(score, o.score);
    }
  }

  @Override
  public boolean equals(Object other) {
    return other instanceof Neighbor
      && ((Neighbor) other).node == node;
  }

  @Override
  public int hashCode() {
    return 39 + 61 * node;
  }

  @Override
  public String toString() {
    return "(" + node + ", " + score + ")";
  }
}
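One subtlety worth calling out: ordering (compareTo) is driven by score, while equals and hashCode consider only the node ordinal. A small hedged sketch of the consequence (values chosen for illustration only):

  Neighbor a = new Neighbor(3, 0.5f);
  Neighbor b = new Neighbor(3, 0.5f);
  Neighbor c = new Neighbor(7, 0.1f);
  assert a.equals(b);         // same node => equal; score is ignored
  assert c.compareTo(a) < 0;  // lower score sorts first (ties broken by node)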
@@ -0,0 +1,93 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.lucene.util.hnsw;

import org.apache.lucene.util.PriorityQueue;

/** Neighbors queue. */
public abstract class Neighbors extends PriorityQueue<Neighbor> {

  public static Neighbors create(int maxSize, boolean reversed) {
    if (reversed) {
      return new ReverseNeighbors(maxSize);
    } else {
      return new ForwardNeighbors(maxSize);
    }
  }

  public abstract boolean reversed();

  // Used to track the number of neighbors visited during a single graph traversal
  private int visitedCount;

  private Neighbors(int maxSize) {
    super(maxSize);
  }

  private static class ForwardNeighbors extends Neighbors {
    ForwardNeighbors(int maxSize) {
      super(maxSize);
    }

    @Override
    protected boolean lessThan(Neighbor a, Neighbor b) {
      if (a.score() == b.score()) {
        return a.node() > b.node();
      }
      return a.score() < b.score();
    }

    @Override
    public boolean reversed() { return false; }
  }

  private static class ReverseNeighbors extends Neighbors {
    ReverseNeighbors(int maxSize) {
      super(maxSize);
    }

    @Override
    protected boolean lessThan(Neighbor a, Neighbor b) {
      if (a.score() == b.score()) {
        return a.node() > b.node();
      }
      return b.score() < a.score();
    }

    @Override
    public boolean reversed() { return true; }
  }

  void setVisitedCount(int visitedCount) {
    this.visitedCount = visitedCount;
  }

  public int visitedCount() {
    return visitedCount;
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append("Neighbors=[");
    this.iterator().forEachRemaining(sb::append);
    sb.append("]");
    return sb.toString();
  }

}
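To make the queue orientation concrete, a hedged sketch of a forward (non-reversed) queue: the head is the *worst* surviving entry, which is what lets the searcher use top() as the bound to beat and insertWithOverflow() to evict the worst neighbor. The values below are illustrative only:

  Neighbors queue = Neighbors.create(2, false);
  queue.insertWithOverflow(new Neighbor(0, 0.9f));
  queue.insertWithOverflow(new Neighbor(1, 0.2f));
  queue.insertWithOverflow(new Neighbor(2, 0.5f));  // evicts node 1 (score 0.2)
  assert queue.top().node() == 2;                   // head is now the lowest score, 0.5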
@@ -0,0 +1,22 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * Navigable Small-World graph, nominally Hierarchical but currently with only a single
 * layer. Provides efficient approximate nearest neighbor search for high-dimensional vectors.
 */
package org.apache.lucene.util.hnsw;
@@ -0,0 +1,352 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.index;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;

import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.randSeed;

/** Tests indexing of a knn-graph */
public class TestKnnGraph extends LuceneTestCase {

  private static final String KNN_GRAPH_FIELD = "vector";

  @Before
  public void setup() {
    randSeed = random().nextLong();
  }

  /**
   * Basic test of creating documents in a graph
   */
  public void testBasic() throws Exception {
    try (Directory dir = newDirectory();
         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
      int numDoc = atLeast(10);
      int dimension = atLeast(3);
      float[][] values = new float[numDoc][];
      for (int i = 0; i < numDoc; i++) {
        if (random().nextBoolean()) {
          values[i] = new float[dimension];
          for (int j = 0; j < dimension; j++) {
            values[i][j] = random().nextFloat();
          }
        }
        add(iw, i, values[i]);
      }
      assertConsistentGraph(iw, values);
    }
  }

  public void testSingleDocument() throws Exception {
    try (Directory dir = newDirectory();
         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
      float[][] values = new float[][]{new float[]{0, 1, 2}};
      add(iw, 0, values[0]);
      assertConsistentGraph(iw, values);
      iw.commit();
      assertConsistentGraph(iw, values);
    }
  }

  /**
   * Verify that the graph properties are preserved when merging
   */
  public void testMerge() throws Exception {
    try (Directory dir = newDirectory();
         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
      int numDoc = atLeast(100);
      int dimension = atLeast(10);
      float[][] values = new float[numDoc][];
      for (int i = 0; i < numDoc; i++) {
        if (random().nextBoolean()) {
          values[i] = new float[dimension];
          for (int j = 0; j < dimension; j++) {
            values[i][j] = random().nextFloat();
          }
        }
        add(iw, i, values[i]);
        if (random().nextInt(10) == 3) {
          //System.out.println("commit @" + i);
          iw.commit();
        }
      }
      if (random().nextBoolean()) {
        iw.forceMerge(1);
      }
      assertConsistentGraph(iw, values);
    }
  }

  // TODO: testSorted
  // TODO: testDeletions

  /**
   * Verify that searching does something reasonable
   */
  public void testSearch() throws Exception {
    try (Directory dir = newDirectory();
         // don't allow random merges; they mess up the docid tie-breaking assertion
         IndexWriter iw = new IndexWriter(dir, new IndexWriterConfig().setCodec(Codec.forName("Lucene90")))) {
      // Add a document for every cartesian point in an NxN square so we can
      // easily know which are the nearest neighbors to every point. Insert by iterating
      // using a prime number that is not a divisor of N*N so that we will hit each point once,
      // and chosen so that points will be inserted in a deterministic
      // but somewhat distributed pattern
      int n = 5, stepSize = 17;
      float[][] values = new float[n * n][];
      int index = 0;
      for (int i = 0; i < values.length; i++) {
        // System.out.printf("%d: (%d, %d)\n", i, index % n, index / n);
        values[i] = new float[]{index % n, index / n};
        index = (index + stepSize) % (n * n);
        add(iw, i, values[i]);
        if (i == 13) {
          // create 2 segments
          iw.commit();
        }
      }
      boolean forceMerge = random().nextBoolean();
      //System.out.println("");
      if (forceMerge) {
        iw.forceMerge(1);
      }
      assertConsistentGraph(iw, values);
      try (DirectoryReader dr = DirectoryReader.open(iw)) {
        // results are ordered by score (descending) and docid (ascending);
        // This is the insertion order:
        // column major, origin at upper left
        // 0 15 5 20 10
        // 3 18 8 23 13
        // 6 21 11 1 16
        // 9 24 14 4 19
        // 12 2 17 7 22

        // For this small graph the "search" is exhaustive, so this mostly tests the APIs, the orientation of the
        // various priority queues, the scoring function, but not so much the approximate KNN search algo
        assertGraphSearch(new int[]{0, 15, 3, 18, 5}, new float[]{0f, 0.1f}, dr);
        // test tiebreaking by docid
        assertGraphSearch(new int[]{11, 1, 8, 14, 21}, new float[]{2, 2}, dr);
        assertGraphSearch(new int[]{15, 18, 0, 3, 5}, new float[]{0.3f, 0.8f}, dr);
      }
    }
  }
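  // A note on the insertion pattern used in testSearch above: stepSize = 17 is
  // coprime with n * n = 25, so index -> (index + 17) % 25 is a full-cycle
  // permutation visiting all 25 cells exactly once (0, 17, 9, 1, 18, 10, 2, ...),
  // giving a deterministic but scattered insertion order (a sketch of the claim,
  // not part of the patch).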

  private void assertGraphSearch(int[] expected, float[] vector, IndexReader reader) throws IOException {
    TopDocs results = doKnnSearch(reader, vector, 5);
    for (ScoreDoc doc : results.scoreDocs) {
      // map docId to insertion id
      int id = Integer.parseInt(reader.document(doc.doc).get("id"));
      doc.doc = id;
    }
    assertResults(expected, results);
  }

  private static TopDocs doKnnSearch(IndexReader reader, float[] vector, int k) throws IOException {
    TopDocs[] results = new TopDocs[reader.leaves().size()];
    for (LeafReaderContext ctx : reader.leaves()) {
      results[ctx.ord] = ctx.reader().getVectorValues(KNN_GRAPH_FIELD)
          .search(vector, k, 10);
      if (ctx.docBase > 0) {
        for (ScoreDoc doc : results[ctx.ord].scoreDocs) {
          doc.doc += ctx.docBase;
        }
      }
    }
    return TopDocs.merge(k, results);
  }

  private void assertResults(int[] expected, TopDocs results) {
    assertEquals(results.toString(), expected.length, results.scoreDocs.length);
    for (int i = expected.length - 1; i >= 0; i--) {
      assertEquals(Arrays.toString(results.scoreDocs), expected[i], results.scoreDocs[i].doc);
    }
  }

  // For each leaf, verify that its graph nodes are 1-1 with vectors, that the vectors are the expected values,
  // and that the graph is fully connected and symmetric.
  // NOTE: when we impose max-fanout on the graph it will no longer be symmetric, but should still
  // be fully connected. Is there any other invariant we can test? Well, we can check that max fanout
  // is respected. We can test *desirable* properties of the graph like small-world (the graph diameter
  // should be tightly bounded).
  private void assertConsistentGraph(IndexWriter iw, float[][] values) throws IOException {
    int totalGraphDocs = 0;
    try (DirectoryReader dr = DirectoryReader.open(iw)) {
      for (LeafReaderContext ctx : dr.leaves()) {
        LeafReader reader = ctx.reader();
        VectorValues vectorValues = reader.getVectorValues(KNN_GRAPH_FIELD);
        Lucene90VectorReader vectorReader = ((Lucene90VectorReader) ((CodecReader) reader).getVectorReader());
        if (vectorReader == null) {
          continue;
        }
        KnnGraphValues graphValues = vectorReader.getGraphValues(KNN_GRAPH_FIELD);
        assertTrue((vectorValues == null) == (graphValues == null));
        if (vectorValues == null) {
          continue;
        }
        int[][] graph = new int[reader.maxDoc()][];
        boolean foundOrphan = false;
        int graphSize = 0;
        int node = -1;
        for (int i = 0; i < reader.maxDoc(); i++) {
          int nextDocWithVectors = vectorValues.advance(i);
          //System.out.println("advanced to " + nextDocWithVectors);
          while (i < nextDocWithVectors && i < reader.maxDoc()) {
            int id = Integer.parseInt(reader.document(i).get("id"));
            assertNull("document " + id + " has no vector, but was expected to", values[id]);
            ++i;
          }
          if (nextDocWithVectors == NO_MORE_DOCS) {
            break;
          }
          int id = Integer.parseInt(reader.document(i).get("id"));
          graphValues.seek(++node);
          // documents with KnnGraphValues have the expected vectors
          float[] scratch = vectorValues.vectorValue();
          assertArrayEquals("vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
              values[id], scratch, 0f);
          // We collect neighbors for analysis below
          List<Integer> friends = new ArrayList<>();
          int arc;
          while ((arc = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
            friends.add(arc);
          }
          if (friends.size() == 0) {
            //System.out.printf("knngraph @%d is singleton (advance returns %d)\n", i, nextWithNeighbors);
            foundOrphan = true;
          } else {
            // NOTE: these friends are dense ordinals, not docIds.
            int[] friendCopy = new int[friends.size()];
            for (int j = 0; j < friends.size(); j++) {
              friendCopy[j] = friends.get(j);
            }
            graph[graphSize] = friendCopy;
            //System.out.printf("knngraph @%d => %s\n", i, Arrays.toString(graph[i]));
          }
          graphSize++;
        }
        assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
        if (foundOrphan) {
          assertEquals("graph is not fully connected", 1, graphSize);
        } else {
          assertTrue("Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1);
        }
        // assert that the graph in each leaf is connected and undirected (ie links are reciprocated)
        // assertReciprocal(graph);
        assertConnected(graph);
        assertMaxConn(graph, HnswGraphBuilder.DEFAULT_MAX_CONN);
        totalGraphDocs += graphSize;
      }
    }
    int expectedCount = 0;
    for (float[] friends : values) {
      if (friends != null) {
        ++expectedCount;
      }
    }
    assertEquals(expectedCount, totalGraphDocs);
  }

  private void assertMaxConn(int[][] graph, int maxConn) {
    for (int i = 0; i < graph.length; i++) {
      if (graph[i] != null) {
        assert (graph[i].length <= maxConn);
        for (int j = 0; j < graph[i].length; j++) {
          int k = graph[i][j];
          assertNotNull(graph[k]);
        }
      }
    }
  }

  private void assertReciprocal(int[][] graph) {
    // The graph is undirected: if a -> b then b -> a.
    for (int i = 0; i < graph.length; i++) {
      if (graph[i] != null) {
        for (int j = 0; j < graph[i].length; j++) {
          int k = graph[i][j];
          assertNotNull(graph[k]);
          assertTrue("" + i + "->" + k + " is not reciprocated", Arrays.binarySearch(graph[k], i) >= 0);
        }
      }
    }
  }

  private void assertConnected(int[][] graph) {
    // every node in the graph is reachable from every other node
    Set<Integer> visited = new HashSet<>();
    List<Integer> queue = new LinkedList<>();
    int count = 0;
    for (int[] entry : graph) {
      if (entry != null) {
        if (queue.isEmpty()) {
          queue.add(entry[0]); // start from any node
          //System.out.println("start at " + entry[0]);
        }
        ++count;
      }
    }
    while (queue.isEmpty() == false) {
      int i = queue.remove(0);
      assertNotNull("expected neighbors of " + i, graph[i]);
      visited.add(i);
      for (int j : graph[i]) {
        if (visited.contains(j) == false) {
          //System.out.println("  ... " + j);
          queue.add(j);
        }
      }
    }
    // we visited each node exactly once
    assertEquals("Attempted to walk entire graph but only visited " + visited.size(), count, visited.size());
  }

  private void add(IndexWriter iw, int id, float[] vector) throws IOException {
    Document doc = new Document();
    if (vector != null) {
      // TODO: choose random search strategy
      doc.add(new VectorField(KNN_GRAPH_FIELD, vector, VectorValues.SearchStrategy.EUCLIDEAN_HNSW));
    }
    doc.add(new StringField("id", Integer.toString(id), Field.Store.YES));
    //System.out.println("add " + id + " " + Arrays.toString(vector));
    iw.addDocument(doc);
  }

}
@@ -601,6 +601,51 @@ public class TestVectorValues extends LuceneTestCase {
      }
    }
  }

  public void testIndexMultipleVectorFields() throws Exception {
    try (Directory dir = newDirectory();
         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) {
      Document doc = new Document();
      float[] v = new float[]{1};
      doc.add(new VectorField("field1", v, SearchStrategy.EUCLIDEAN_HNSW));
      doc.add(new VectorField("field2", new float[]{1, 2, 3}, SearchStrategy.NONE));
      iw.addDocument(doc);
      v[0] = 2;
      iw.addDocument(doc);
      doc = new Document();
      doc.add(new VectorField("field3", new float[]{1, 2, 3}, SearchStrategy.DOT_PRODUCT_HNSW));
      iw.addDocument(doc);
      iw.forceMerge(1);
      try (IndexReader reader = iw.getReader()) {
        LeafReader leaf = reader.leaves().get(0).reader();

        VectorValues vectorValues = leaf.getVectorValues("field1");
        assertEquals(1, vectorValues.dimension());
        assertEquals(2, vectorValues.size());
        vectorValues.nextDoc();
        assertEquals(1f, vectorValues.vectorValue()[0], 0);
        vectorValues.nextDoc();
        assertEquals(2f, vectorValues.vectorValue()[0], 0);
        assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());

        VectorValues vectorValues2 = leaf.getVectorValues("field2");
        assertEquals(3, vectorValues2.dimension());
        assertEquals(2, vectorValues2.size());
        vectorValues2.nextDoc();
        assertEquals(2f, vectorValues2.vectorValue()[1], 0);
        vectorValues2.nextDoc();
        assertEquals(2f, vectorValues2.vectorValue()[1], 0);
        assertEquals(NO_MORE_DOCS, vectorValues2.nextDoc());

        VectorValues vectorValues3 = leaf.getVectorValues("field3");
        assertEquals(3, vectorValues3.dimension());
        assertEquals(1, vectorValues3.size());
        vectorValues3.nextDoc();
        assertEquals(1f, vectorValues3.vectorValue()[0], 0);
        assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc());
      }
    }
  }

  /**
   * Index random vectors, sometimes skipping documents, sometimes deleting a document,
   * sometimes merging, sometimes sorting the index,

@@ -0,0 +1,494 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.io.OutputStream;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;

import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.apache.lucene.util.SuppressForbidden;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

/** For testing indexing and search performance of a knn-graph
 *
 * java -cp .../lib/*.jar org.apache.lucene.util.hnsw.KnnGraphTester -ndoc 1000000 -search .../vectors.bin
 */
public class KnnGraphTester {

  private final static String KNN_FIELD = "knn";
  private final static String ID_FIELD = "id";
  private final static VectorValues.SearchStrategy SEARCH_STRATEGY = VectorValues.SearchStrategy.DOT_PRODUCT_HNSW;

  private int numDocs;
  private int dim;
  private int topK;
  private int numIters;
  private int fanout;
  private Path indexPath;
  private boolean quiet;
  private boolean reindex;
  private int reindexTimeMsec;

  @SuppressForbidden(reason="uses Random()")
  private KnnGraphTester() {
    // set defaults
    numDocs = 1000;
    numIters = 1000;
    dim = 256;
    topK = 100;
    fanout = topK;
    indexPath = Paths.get("knn_test_index");
  }

  public static void main(String... args) throws Exception {
    new KnnGraphTester().run(args);
  }

  private void run(String... args) throws Exception {
    String operation = null, docVectorsPath = null, queryPath = null;
    for (int iarg = 0; iarg < args.length; iarg++) {
      String arg = args[iarg];
      switch (arg) {
        case "-generate":
        case "-search":
        case "-check":
        case "-stats":
          if (operation != null) {
            throw new IllegalArgumentException("Specify only one operation, not both " + arg + " and " + operation);
          }
          if (iarg == args.length - 1) {
            throw new IllegalArgumentException("Operation " + arg + " requires a following pathname");
          }
          operation = arg;
          docVectorsPath = args[++iarg];
          if (operation.equals("-search")) {
            queryPath = args[++iarg];
          }
          break;
        case "-fanout":
          if (iarg == args.length - 1) {
            throw new IllegalArgumentException("-fanout requires a following number");
          }
          fanout = Integer.parseInt(args[++iarg]);
          break;
        case "-beamWidthIndex":
          if (iarg == args.length - 1) {
            throw new IllegalArgumentException("-beamWidthIndex requires a following number");
          }
          HnswGraphBuilder.DEFAULT_BEAM_WIDTH = Integer.parseInt(args[++iarg]);
          break;
        case "-maxConn":
          if (iarg == args.length - 1) {
            throw new IllegalArgumentException("-maxConn requires a following number");
          }
          HnswGraphBuilder.DEFAULT_MAX_CONN = Integer.parseInt(args[++iarg]);
          break;
        case "-dim":
          if (iarg == args.length - 1) {
            throw new IllegalArgumentException("-dim requires a following number");
          }
          dim = Integer.parseInt(args[++iarg]);
          break;
        case "-ndoc":
          if (iarg == args.length - 1) {
            throw new IllegalArgumentException("-ndoc requires a following number");
          }
          numDocs = Integer.parseInt(args[++iarg]);
          break;
        case "-niter":
          if (iarg == args.length - 1) {
            throw new IllegalArgumentException("-niter requires a following number");
          }
          numIters = Integer.parseInt(args[++iarg]);
          break;
        case "-reindex":
          reindex = true;
          break;
        case "-forceMerge":
          operation = "-forceMerge";
          break;
        case "-quiet":
          quiet = true;
          break;
        default:
          throw new IllegalArgumentException("unknown argument " + arg);
          //usage();
      }
    }
    if (operation == null) {
      usage();
    }
    if (reindex) {
      reindexTimeMsec = createIndex(Paths.get(docVectorsPath), indexPath);
    }
    switch (operation) {
      case "-search":
        testSearch(indexPath, Paths.get(queryPath), getNN(Paths.get(docVectorsPath), Paths.get(queryPath)));
        break;
      case "-forceMerge":
        forceMerge();
        break;
      case "-stats":
        printFanoutHist(indexPath);
        break;
    }
  }

  @SuppressForbidden(reason="Prints stuff")
  private void printFanoutHist(Path indexPath) throws IOException {
    try (Directory dir = FSDirectory.open(indexPath);
         DirectoryReader reader = DirectoryReader.open(dir)) {
      // int[] globalHist = new int[reader.maxDoc()];
      for (LeafReaderContext context : reader.leaves()) {
        LeafReader leafReader = context.reader();
        KnnGraphValues knnValues = ((Lucene90VectorReader) ((CodecReader) leafReader).getVectorReader()).getGraphValues(KNN_FIELD);
        System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
        printGraphFanout(knnValues, leafReader.maxDoc());
      }
    }
  }

  @SuppressForbidden(reason="Prints stuff")
  private void forceMerge() throws IOException {
    IndexWriterConfig iwc = new IndexWriterConfig()
        .setOpenMode(IndexWriterConfig.OpenMode.APPEND);
    iwc.setInfoStream(new PrintStreamInfoStream(System.out));
    System.out.println("Force merge index in " + indexPath);
    try (IndexWriter iw = new IndexWriter(FSDirectory.open(indexPath), iwc)) {
      iw.forceMerge(1);
    }
  }

  @SuppressForbidden(reason="Prints stuff")
  private void printGraphFanout(KnnGraphValues knnValues, int numDocs) throws IOException {
    int min = Integer.MAX_VALUE, max = 0, total = 0;
    int count = 0;
    int[] leafHist = new int[numDocs];
    for (int node = 0; node < numDocs; node++) {
      knnValues.seek(node);
      int n = 0;
      while (knnValues.nextNeighbor() != NO_MORE_DOCS) {
        ++n;
      }
      ++leafHist[n];
      max = Math.max(max, n);
      min = Math.min(min, n);
      if (n > 0) {
        ++count;
        total += n;
      }
    }
    System.out.printf("Graph size=%d, Fanout min=%d, mean=%.2f, max=%d\n", count, min, total / (float) count, max);
    printHist(leafHist, max, count, 10);
  }

  @SuppressForbidden(reason="Prints stuff")
  private void printHist(int[] hist, int max, int count, int nbuckets) {
    System.out.print("%");
    for (int i = 0; i <= nbuckets; i++) {
      System.out.printf("%4d", i * 100 / nbuckets);
    }
    System.out.printf("\n %4d", hist[0]);
    int total = 0, ibucket = 1;
    for (int i = 1; i <= max && ibucket <= nbuckets; i++) {
      total += hist[i];
      while (total >= count * ibucket / nbuckets) {
        System.out.printf("%4d", i);
        ++ibucket;
      }
    }
    System.out.println();
  }

  @SuppressForbidden(reason="Prints stuff")
  private void testSearch(Path indexPath, Path queryPath, int[][] nn) throws IOException {
    TopDocs[] results = new TopDocs[numIters];
    long elapsed, totalCpuTime, totalVisited = 0;
    try (FileChannel q = FileChannel.open(queryPath)) {
      FloatBuffer targets = q.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
          .order(ByteOrder.LITTLE_ENDIAN)
          .asFloatBuffer();
      float[] target = new float[dim];
      if (quiet == false) {
        System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
      }
      long start;
      ThreadMXBean bean = ManagementFactory.getThreadMXBean();
      long cpuTimeStartNs;
      try (Directory dir = FSDirectory.open(indexPath);
           DirectoryReader reader = DirectoryReader.open(dir)) {

        for (int i = 0; i < 1000; i++) {
          // warm up
          targets.get(target);
          results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
        }
        targets.position(0);
        start = System.nanoTime();
        cpuTimeStartNs = bean.getCurrentThreadCpuTime();
        for (int i = 0; i < numIters; i++) {
          targets.get(target);
          results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
        }
        totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
        elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
        for (int i = 0; i < numIters; i++) {
          totalVisited += results[i].totalHits.value;
          for (ScoreDoc doc : results[i].scoreDocs) {
            doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
          }
        }
      }
      if (quiet == false) {
        System.out.println("completed " + numIters + " searches in " + elapsed + " ms: " + ((1000 * numIters) / elapsed) + " QPS "
            + "CPU time=" + totalCpuTime + "ms");
      }
    }
    if (quiet == false) {
      System.out.println("checking results");
    }
    float recall = checkResults(results, nn);
    totalVisited /= numIters;
    if (quiet) {
      System.out.printf(Locale.ROOT, "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n", recall, totalCpuTime / (float) numIters,
          numDocs, fanout, HnswGraphBuilder.DEFAULT_MAX_CONN, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, totalVisited, reindexTimeMsec);
    }
  }

  private static TopDocs doKnnSearch(IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException {
    TopDocs[] results = new TopDocs[reader.leaves().size()];
    for (LeafReaderContext ctx : reader.leaves()) {
      results[ctx.ord] = ctx.reader().getVectorValues(field).search(vector, k, fanout);
      int docBase = ctx.docBase;
      for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
        scoreDoc.doc += docBase;
      }
    }
    return TopDocs.merge(k, results);
  }

  private float checkResults(TopDocs[] results, int[][] nn) {
    int totalMatches = 0;
    int totalResults = 0;
    for (int i = 0; i < results.length; i++) {
      int n = results[i].scoreDocs.length;
      totalResults += n;
      //System.out.println(Arrays.toString(nn[i]));
      //System.out.println(Arrays.toString(results[i].scoreDocs));
      totalMatches += compareNN(nn[i], results[i]);
    }
    if (quiet == false) {
      System.out.println("total matches = " + totalMatches + " out of " + totalResults);
      System.out.printf(Locale.ROOT, "Average overlap = %.2f%%\n", ((100.0 * totalMatches) / totalResults));
    }
    return totalMatches / (float) totalResults;
  }

  private int compareNN(int[] expected, TopDocs results) {
    int matched = 0;
    /*
    System.out.print("expected=");
    for (int j = 0; j < expected.length; j++) {
      System.out.print(expected[j]);
      System.out.print(", ");
    }
    System.out.print('\n');
    System.out.println("results=");
    for (int j = 0; j < results.scoreDocs.length; j++) {
      System.out.print("" + results.scoreDocs[j].doc + ":" + results.scoreDocs[j].score + ", ");
    }
    System.out.print('\n');
    */
    Set<Integer> expectedSet = new HashSet<>();
    for (int i = 0; i < results.scoreDocs.length; i++) {
      expectedSet.add(expected[i]);
    }
    for (ScoreDoc scoreDoc : results.scoreDocs) {
      if (expectedSet.contains(scoreDoc.doc)) {
        ++matched;
      }
    }
    return matched;
  }

  private int[][] getNN(Path docPath, Path queryPath) throws IOException {
    // look in working directory for cached nn file
    String nnFileName = "nn-" + numDocs + "-" + numIters + "-" + topK + "-" + dim + ".bin";
    Path nnPath = Paths.get(nnFileName);
    if (Files.exists(nnPath)) {
      return readNN(nnPath);
    } else {
      int[][] nn = computeNN(docPath, queryPath);
      writeNN(nn, nnPath);
      return nn;
    }
  }

  private int[][] readNN(Path nnPath) throws IOException {
    int[][] result = new int[numIters][];
    try (FileChannel in = FileChannel.open(nnPath)) {
      IntBuffer intBuffer = in.map(FileChannel.MapMode.READ_ONLY, 0, numIters * topK * Integer.BYTES)
          .order(ByteOrder.LITTLE_ENDIAN)
          .asIntBuffer();
      for (int i = 0; i < numIters; i++) {
        result[i] = new int[topK];
        intBuffer.get(result[i]);
      }
    }
    return result;
  }

  private void writeNN(int[][] nn, Path nnPath) throws IOException {
    if (quiet == false) {
      System.out.println("writing true nearest neighbors to " + nnPath);
    }
    ByteBuffer tmp = ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
    try (OutputStream out = Files.newOutputStream(nnPath)) {
      for (int i = 0; i < numIters; i++) {
        tmp.asIntBuffer().put(nn[i]);
        out.write(tmp.array());
      }
    }
  }

  private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
    int[][] result = new int[numIters][];
    if (quiet == false) {
      System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
    }
    try (FileChannel in = FileChannel.open(docPath);
         FileChannel qIn = FileChannel.open(queryPath)) {
      FloatBuffer queries = qIn.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
          .order(ByteOrder.LITTLE_ENDIAN)
          .asFloatBuffer();
      float[] vector = new float[dim];
      float[] query = new float[dim];
      for (int i = 0; i < numIters; i++) {
        queries.get(query);
        long totalBytes = (long) numDocs * dim * Float.BYTES;
        int blockSize = (int) Math.min(totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES)), offset = 0;
        int j = 0;
        //System.out.println("totalBytes=" + totalBytes);
        while (j < numDocs) {
          FloatBuffer vectors = in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
              .order(ByteOrder.LITTLE_ENDIAN)
              .asFloatBuffer();
          offset += blockSize;
          Neighbors queue = Neighbors.create(topK, SEARCH_STRATEGY.reversed);
          for (; j < numDocs && vectors.hasRemaining(); j++) {
            vectors.get(vector);
            float d = SEARCH_STRATEGY.compare(query, vector);
            queue.insertWithOverflow(new Neighbor(j, d));
          }
          result[i] = new int[topK];
          for (int k = topK - 1; k >= 0; k--) {
            Neighbor n = queue.pop();
            result[i][k] = n.node();
            //System.out.print(" " + n);
          }
          if (quiet == false && (i + 1) % 10 == 0) {
            System.out.print(" " + (i + 1));
            System.out.flush();
          }
        }
      }
    }
    return result;
  }

  private int createIndex(Path docsPath, Path indexPath) throws IOException {
    IndexWriterConfig iwc = new IndexWriterConfig()
        .setOpenMode(IndexWriterConfig.OpenMode.CREATE);
    // iwc.setMergePolicy(NoMergePolicy.INSTANCE);
    iwc.setRAMBufferSizeMB(1994d);
    if (quiet == false) {
      iwc.setInfoStream(new PrintStreamInfoStream(System.out));
      System.out.println("creating index in " + indexPath);
    }
    long start = System.nanoTime();
    long totalBytes = (long) numDocs * dim * Float.BYTES, offset = 0;
    try (FSDirectory dir = FSDirectory.open(indexPath);
         IndexWriter iw = new IndexWriter(dir, iwc)) {
      int blockSize = (int) Math.min(totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES));
      float[] vector = new float[dim];
      try (FileChannel in = FileChannel.open(docsPath)) {
        int i = 0;
        while (i < numDocs) {
          FloatBuffer vectors = in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
              .order(ByteOrder.LITTLE_ENDIAN)
              .asFloatBuffer();
          offset += blockSize;
          for (; vectors.hasRemaining() && i < numDocs; i++) {
            vectors.get(vector);
            Document doc = new Document();
            //System.out.println("vector=" + vector[0] + "," + vector[1] + "...");
            doc.add(new VectorField(KNN_FIELD, vector, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW));
            doc.add(new StoredField(ID_FIELD, i));
            iw.addDocument(doc);
          }
        }
        if (quiet == false) {
          System.out.println("Done indexing " + numDocs + " documents; now flush");
        }
      }
    }
    long elapsed = System.nanoTime() - start;
    if (quiet == false) {
      System.out.println("Indexed " + numDocs + " documents in " + elapsed / 1_000_000_000 + "s");
    }
    return (int) (elapsed / 1_000_000);
  }

  private static void usage() {
    String error = "Usage: KnnGraphTester -generate|-search|-stats|-check {datafile} [-beamWidth N]";
    System.err.println(error);
    System.exit(1);
  }

}
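A hedged sketch of driving the tester programmatically (file names are placeholders; the vector files are raw little-endian float32, dim floats per vector, matching the readers above):

  // Hypothetical run, not part of the patch: reindex 10000 doc vectors, then
  // measure recall and latency over query vectors with the default topK of 100.
  KnnGraphTester.main("-ndoc", "10000", "-reindex", "-search", "docs.bin", "queries.bin");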
|
|
@ -0,0 +1,459 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

/** Tests HNSW KNN graphs */
public class TestHnsw extends LuceneTestCase {

  // test writing out and reading in a graph gives the same graph
  public void testReadWrite() throws IOException {
    int dim = random().nextInt(100) + 1;
    int nDoc = random().nextInt(100) + 1;
    RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
    RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
    long seed = random().nextLong();
    HnswGraphBuilder.randSeed = seed;
    HnswGraph hnsw = HnswGraphBuilder.build((RandomAccessVectorValuesProducer) vectors);
    // Recreate the graph while indexing with the same random seed and write it out
    HnswGraphBuilder.randSeed = seed;
    try (Directory dir = newDirectory()) {
      int nVec = 0, indexedDoc = 0;
      // Don't merge randomly, create a single segment because we rely on the docid ordering for this test
      IndexWriterConfig iwc = new IndexWriterConfig()
          .setCodec(Codec.forName("Lucene90"));
      try (IndexWriter iw = new IndexWriter(dir, iwc)) {
        while (v2.nextDoc() != NO_MORE_DOCS) {
          while (indexedDoc < v2.docID()) {
            // increment docId in the index by adding empty documents
            iw.addDocument(new Document());
            indexedDoc++;
          }
          Document doc = new Document();
          doc.add(new VectorField("field", v2.vectorValue(), v2.searchStrategy));
          doc.add(new StoredField("id", v2.docID()));
          iw.addDocument(doc);
          nVec++;
          indexedDoc++;
        }
      }
      try (IndexReader reader = DirectoryReader.open(dir)) {
        for (LeafReaderContext ctx : reader.leaves()) {
          VectorValues values = ctx.reader().getVectorValues("field");
          assertEquals(vectors.searchStrategy, values.searchStrategy());
          assertEquals(dim, values.dimension());
          assertEquals(nVec, values.size());
          assertEquals(indexedDoc, ctx.reader().maxDoc());
          assertEquals(indexedDoc, ctx.reader().numDocs());
          assertVectorsEqual(v3, values);
          KnnGraphValues graphValues = ((Lucene90VectorReader) ((CodecReader) ctx.reader()).getVectorReader()).getGraphValues("field");
          assertGraphEqual(hnsw.getGraphValues(), graphValues, nVec);
        }
      }
    }
  }

  // Make sure we actually approximately find the closest k elements. Mostly this is about
  // ensuring that we have all the distance functions, comparators, priority queues and so on
  // oriented in the right directions
  public void testAknn() throws IOException {
    int nDoc = 100;
    RandomAccessVectorValuesProducer vectors = new CircularVectorValues(nDoc);
    HnswGraph hnsw = HnswGraphBuilder.build(vectors);
    // run some searches
    Neighbors nn = HnswGraph.search(new float[]{1, 0}, 10, 5, vectors.randomAccess(), hnsw.getGraphValues(), random());
    int sum = 0;
    for (Neighbor n : nn) {
      sum += n.node();
    }
|
||||||
|
// We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) = 45
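    // (CircularVectorValues places ord i at angle pi * i / nDoc on the upper half-circle,
    // so the true top-10 for the query (1, 0) by dot product are ords 0..9, summing to 45;
    // allowing sum < 75 tolerates a few approximate near-misses.)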
    assertTrue("sum(result docs)=" + sum, sum < 75);
  }

  public void testMaxConnections() throws Exception {
    // verify that maxConnections is observed, and that the retained arcs point to the best-scoring neighbors
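    // With maxConnections = 1, each connectNodes call below can displace an existing arc:
    // under DOT_PRODUCT a higher score means closer, so the higher-scoring arc wins, while
    // under EUCLIDEAN a lower score means closer, and the lower-scoring arc is retained.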
    HnswGraph graph = new HnswGraph(1, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW);
    graph.connectNodes(0, 1, 1);
    assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
    assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
    graph.connectNodes(0, 2, 2);
    assertArrayEquals(new int[]{2}, graph.getNeighbors(0));
    assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
    assertArrayEquals(new int[]{0}, graph.getNeighbors(2));
    graph.connectNodes(2, 3, 1);
    assertArrayEquals(new int[]{2}, graph.getNeighbors(0));
    assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
    assertArrayEquals(new int[]{0}, graph.getNeighbors(2));
    assertArrayEquals(new int[]{2}, graph.getNeighbors(3));

    graph = new HnswGraph(1, VectorValues.SearchStrategy.EUCLIDEAN_HNSW);
    graph.connectNodes(0, 1, 1);
    assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
    assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
    graph.connectNodes(0, 2, 2);
    assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
    assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
    assertArrayEquals(new int[]{0}, graph.getNeighbors(2));
    graph.connectNodes(2, 3, 1);
    assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
    assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
    assertArrayEquals(new int[]{3}, graph.getNeighbors(2));
    assertArrayEquals(new int[]{2}, graph.getNeighbors(3));
  }

  /** Returns vectors evenly distributed over the upper half of the unit circle,
   * at angle pi * ord / size, so that ord order matches angular distance from (1, 0). */
  class CircularVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
    private final int size;
    private final float[] value;

    int doc = -1;

    CircularVectorValues(int size) {
      this.size = size;
      value = new float[2];
    }

    public CircularVectorValues copy() {
      return new CircularVectorValues(size);
    }

    @Override
    public SearchStrategy searchStrategy() {
      return SearchStrategy.DOT_PRODUCT_HNSW;
    }

    @Override
    public int dimension() {
      return 2;
    }

    @Override
    public int size() {
      return size;
    }

    @Override
    public float[] vectorValue() {
      return vectorValue(doc);
    }

    @Override
    public RandomAccessVectorValues randomAccess() {
      return new CircularVectorValues(size);
    }

    @Override
    public int docID() {
      return doc;
    }

    @Override
    public int nextDoc() {
      return advance(doc + 1);
    }

    @Override
    public int advance(int target) {
      if (target >= 0 && target < size) {
        doc = target;
      } else {
        doc = NO_MORE_DOCS;
      }
      return doc;
    }

    @Override
    public long cost() {
      return size;
    }

    @Override
    public float[] vectorValue(int ord) {
      value[0] = (float) Math.cos(Math.PI * ord / (double) size);
      value[1] = (float) Math.sin(Math.PI * ord / (double) size);
      return value;
    }

    @Override
    public BytesRef binaryValue(int ord) {
      return null;
    }

    @Override
    public TopDocs search(float[] target, int k, int fanout) {
      return null;
    }

  }

  private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException {
    for (int node = 0; node < size; node++) {
      g.seek(node);
      h.seek(node);
      assertEquals("arcs differ for node " + node, getNeighbors(g), getNeighbors(h));
    }
  }

  private Set<Integer> getNeighbors(KnnGraphValues g) throws IOException {
    Set<Integer> neighbors = new HashSet<>();
    for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) {
      neighbors.add(n);
    }
    return neighbors;
  }

  private void assertVectorsEqual(VectorValues u, VectorValues v) throws IOException {
    int uDoc, vDoc;
    while (true) {
      uDoc = u.nextDoc();
      vDoc = v.nextDoc();
      assertEquals(uDoc, vDoc);
      if (uDoc == NO_MORE_DOCS) {
        break;
      }
      assertArrayEquals("vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), 1e-4f);
    }
  }

  public void testNeighbors() {
    // make sure we have the sign correct
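    // As exercised below, Neighbors.create(k, false) keeps the k highest-scoring entries
    // (insertWithOverflow evicts the current lowest, and pop() returns lowest-first);
    // create(k, true) is the reversed variant that keeps the k lowest-scoring entries.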
    Neighbors nn = Neighbors.create(2, false);
    Neighbor a = new Neighbor(1, 10);
    Neighbor b = new Neighbor(2, 20);
    Neighbor c = new Neighbor(3, 30);
    assertNull(nn.insertWithOverflow(b));
    assertNull(nn.insertWithOverflow(a));
    assertSame(a, nn.insertWithOverflow(c));
    assertEquals(20, (int) nn.top().score());
    assertEquals(20, (int) nn.pop().score());
    assertEquals(30, (int) nn.top().score());
    assertEquals(30, (int) nn.pop().score());

    Neighbors fn = Neighbors.create(2, true);
    assertNull(fn.insertWithOverflow(b));
    assertNull(fn.insertWithOverflow(a));
    assertSame(c, fn.insertWithOverflow(c));
    assertEquals(20, (int) fn.top().score());
    assertEquals(20, (int) fn.pop().score());
    assertEquals(10, (int) fn.top().score());
    assertEquals(10, (int) fn.pop().score());
  }

  @SuppressWarnings("SelfComparison")
  public void testNeighbor() {
    Neighbor a = new Neighbor(1, 10);
    Neighbor b = new Neighbor(2, 20);
    Neighbor c = new Neighbor(3, 20);
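    // b and c carry equal scores but must not compare as equal; the tie appears to be
    // broken on node id (here b, the smaller id, sorts after c), keeping the order total.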
    assertEquals(0, a.compareTo(a));
    assertEquals(-1, a.compareTo(b));
    assertEquals(1, b.compareTo(a));
    assertEquals(1, b.compareTo(c));
    assertEquals(-1, c.compareTo(b));
  }

  private static float[] randomVector(Random random, int dim) {
    float[] vec = new float[dim];
    for (int i = 0; i < dim; i++) {
      vec[i] = random.nextFloat();
    }
    return vec;
  }

  /**
   * Produces random vectors and caches them for random-access.
   */
  class RandomVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {

    private final int dimension;
    private final float[][] denseValues;
    private final float[][] values;
    private final float[] scratch;
    private final SearchStrategy searchStrategy;

    final int numVectors;
    final int maxDoc;

    private int pos = -1;

    RandomVectorValues(int size, int dimension, Random random) {
      this.dimension = dimension;
      values = new float[size][];
      denseValues = new float[size][];
      scratch = new float[dimension];
      int sz = 0;
      int md = -1;
      for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
        values[offset] = randomVector(random, dimension);
        denseValues[sz++] = values[offset];
        md = offset;
      }
      numVectors = sz;
      maxDoc = md;
      // get a random SearchStrategy other than NONE (0)
      searchStrategy = SearchStrategy.values()[random.nextInt(SearchStrategy.values().length - 1) + 1];
    }
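    // values is indexed by docid and contains random gaps (null entries); denseValues packs
    // the same vectors contiguously by ordinal, mimicking how a sparse vector field is seen
    // both as a doc iterator and as random access by ord.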

    private RandomVectorValues(int dimension, SearchStrategy searchStrategy, float[][] denseValues, float[][] values, int size) {
      this.dimension = dimension;
      this.searchStrategy = searchStrategy;
      this.values = values;
      this.denseValues = denseValues;
      scratch = new float[dimension];
      numVectors = size;
      maxDoc = values.length - 1;
    }

    public RandomVectorValues copy() {
      return new RandomVectorValues(dimension, searchStrategy, denseValues, values, numVectors);
    }

    @Override
    public int size() {
      return numVectors;
    }

    @Override
    public SearchStrategy searchStrategy() {
      return searchStrategy;
    }

    @Override
    public int dimension() {
      return dimension;
    }

    @Override
    public float[] vectorValue() {
      if (random().nextBoolean()) {
        return values[pos];
      } else {
        // Sometimes use the same scratch array repeatedly, mimicking what the codec will do.
        // This should help us catch cases of aliasing where the same VectorValues source is used twice in a
        // single computation.
        System.arraycopy(values[pos], 0, scratch, 0, dimension);
        return scratch;
      }
    }

    @Override
    public RandomAccessVectorValues randomAccess() {
      return copy();
    }

    @Override
    public float[] vectorValue(int targetOrd) {
      return denseValues[targetOrd];
    }

    @Override
    public BytesRef binaryValue(int targetOrd) {
      return null;
    }

    @Override
    public TopDocs search(float[] target, int k, int fanout) {
      return null;
    }

    private boolean seek(int target) {
      if (target >= 0 && target < values.length && values[target] != null) {
        pos = target;
        return true;
      } else {
        return false;
      }
    }

    @Override
    public int docID() {
      return pos;
    }

    @Override
    public int nextDoc() {
      return advance(pos + 1);
    }

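    // Note: this advance implementation ignores its target argument and just scans forward
    // from the current position to the next non-null slot, which is enough for the
    // nextDoc()-style iteration used in these tests.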
    public int advance(int target) {
      while (++pos < values.length) {
        if (seek(pos)) {
          return pos;
        }
      }
      return NO_MORE_DOCS;
    }

    @Override
    public long cost() {
      return size();
    }

  }

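  // BoundsChecker maintains a running bound on candidate scores during graph search:
  // as the assertions below show, create(false) yields a "max" checker whose check(f)
  // returns true once f falls below the best (largest) value passed to update(), and
  // create(true) is the symmetric "min" variant that flags values above the smallest seen.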
  public void testBoundsCheckerMax() {
    BoundsChecker max = BoundsChecker.create(false);
    float f = random().nextFloat() - 0.5f;
    // any float > -MAX_VALUE is in bounds
    assertFalse(max.check(f));
    // f is now the bound (minus some delta)
    max.update(f);
    assertFalse(max.check(f)); // f is not out of bounds
    assertFalse(max.check(f + 1)); // anything greater than f is in bounds
    assertTrue(max.check(f - 1e-5f)); // delta is zero initially
  }

  public void testBoundsCheckerMin() {
    BoundsChecker min = BoundsChecker.create(true);
    float f = random().nextFloat() - 0.5f;
    // any float < MAX_VALUE is in bounds
    assertFalse(min.check(f));
    // f is now the bound (plus some delta)
    min.update(f);
    assertFalse(min.check(f)); // f is not out of bounds
    assertFalse(min.check(f - 1)); // anything less than f is in bounds
    assertTrue(min.check(f + 1e-5f)); // delta is zero initially
  }

}

@@ -99,7 +99,7 @@ public final class IntervalQuery extends Query {
   private IntervalQuery(String field, IntervalsSource intervalsSource, IntervalScoreFunction scoreFunction) {
     Objects.requireNonNull(field, "null field aren't accepted");
     Objects.requireNonNull(intervalsSource, "null intervalsSource aren't accepted");
-    Objects.requireNonNull(scoreFunction, "null scoreFunction aren't accepted");
+    Objects.requireNonNull(scoreFunction, "null searchStrategy aren't accepted");
     this.field = field;
     this.intervalsSource = intervalsSource;
     this.scoreFunction = scoreFunction;