LUCENE-9004: KNN vector search using NSW graphs (#2022)

This commit is contained in:
Michael Sokolov 2020-11-13 08:53:51 -05:00 committed by GitHub
parent 80a0154d57
commit b36b4af22b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 2400 additions and 99 deletions

View File

@ -283,7 +283,6 @@ public abstract class VectorWriter implements Closeable {
public BytesRef binaryValue(int targetOrd) throws IOException { public BytesRef binaryValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
} }
} }
} }

View File

@ -27,15 +27,17 @@ import org.apache.lucene.index.SegmentWriteState;
/** /**
* Lucene 9.0 vector format, which encodes dense numeric vector values. * Lucene 9.0 vector format, which encodes dense numeric vector values.
* TODO: add support for approximate KNN search. *
* @lucene.experimental
*/ */
public final class Lucene90VectorFormat extends VectorFormat { public final class Lucene90VectorFormat extends VectorFormat {
static final String META_CODEC_NAME = "Lucene90VectorFormatMeta"; static final String META_CODEC_NAME = "Lucene90VectorFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "Lucene90VectorFormatData"; static final String VECTOR_DATA_CODEC_NAME = "Lucene90VectorFormatData";
static final String VECTOR_INDEX_CODEC_NAME = "Lucene90VectorFormatIndex";
static final String META_EXTENSION = "vem"; static final String META_EXTENSION = "vem";
static final String VECTOR_DATA_EXTENSION = "vec"; static final String VECTOR_DATA_EXTENSION = "vec";
static final String VECTOR_INDEX_EXTENSION = "vex";
static final int VERSION_START = 0; static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START; static final int VERSION_CURRENT = VERSION_START;

View File

@ -22,6 +22,7 @@ import java.nio.ByteBuffer;
import java.nio.FloatBuffer; import java.nio.FloatBuffer;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Random;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.VectorReader; import org.apache.lucene.codecs.VectorReader;
@ -29,19 +30,28 @@ import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.Neighbor;
import org.apache.lucene.util.hnsw.Neighbors;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/** /**
* Reads vectors from the index segments. * Reads vectors from the index segments along with index data structures supporting KNN search.
* @lucene.experimental * @lucene.experimental
*/ */
public final class Lucene90VectorReader extends VectorReader { public final class Lucene90VectorReader extends VectorReader {
@ -49,13 +59,21 @@ public final class Lucene90VectorReader extends VectorReader {
private final FieldInfos fieldInfos; private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>(); private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData; private final IndexInput vectorData;
private final int maxDoc; private final IndexInput vectorIndex;
private final long checksumSeed;
Lucene90VectorReader(SegmentReadState state) throws IOException { Lucene90VectorReader(SegmentReadState state) throws IOException {
this.fieldInfos = state.fieldInfos; this.fieldInfos = state.fieldInfos;
this.maxDoc = state.segmentInfo.maxDoc();
String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.META_EXTENSION); int versionMeta = readMetadata(state, Lucene90VectorFormat.META_EXTENSION);
long[] checksumRef = new long[1];
vectorData = openDataInput(state, versionMeta, Lucene90VectorFormat.VECTOR_DATA_EXTENSION, Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME, checksumRef);
vectorIndex = openDataInput(state, versionMeta, Lucene90VectorFormat.VECTOR_INDEX_EXTENSION, Lucene90VectorFormat.VECTOR_INDEX_CODEC_NAME, checksumRef);
checksumSeed = checksumRef[0];
}
private int readMetadata(SegmentReadState state, String fileExtension) throws IOException {
String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
int versionMeta = -1; int versionMeta = -1;
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName, state.context)) { try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName, state.context)) {
Throwable priorE = null; Throwable priorE = null;
@ -73,29 +91,32 @@ public final class Lucene90VectorReader extends VectorReader {
CodecUtil.checkFooter(meta, priorE); CodecUtil.checkFooter(meta, priorE);
} }
} }
return versionMeta;
}
private static IndexInput openDataInput(SegmentReadState state, int versionMeta, String fileExtension, String codecName, long[] checksumRef) throws IOException {
boolean success = false; boolean success = false;
String vectorDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_DATA_EXTENSION); String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
this.vectorData = state.directory.openInput(vectorDataFileName, state.context); IndexInput in = state.directory.openInput(fileName, state.context);
try { try {
int versionVectorData = CodecUtil.checkIndexHeader(vectorData, int versionVectorData = CodecUtil.checkIndexHeader(in,
Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME, codecName,
Lucene90VectorFormat.VERSION_START, Lucene90VectorFormat.VERSION_START,
Lucene90VectorFormat.VERSION_CURRENT, Lucene90VectorFormat.VERSION_CURRENT,
state.segmentInfo.getId(), state.segmentInfo.getId(),
state.segmentSuffix); state.segmentSuffix);
if (versionMeta != versionVectorData) { if (versionMeta != versionVectorData) {
throw new CorruptIndexException("Format versions mismatch: meta=" + versionMeta + ", vector data=" + versionVectorData, vectorData); throw new CorruptIndexException("Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, in);
} }
CodecUtil.retrieveChecksum(vectorData); checksumRef[0] = CodecUtil.retrieveChecksum(in);
success = true; success = true;
} finally { } finally {
if (!success) { if (!success) {
IOUtils.closeWhileHandlingException(this.vectorData); IOUtils.closeWhileHandlingException(in);
} }
} }
return in;
} }
private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
@ -104,23 +125,28 @@ public final class Lucene90VectorReader extends VectorReader {
if (info == null) { if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
} }
int searchStrategyId = meta.readInt(); fields.put(info.name, readField(meta));
if (searchStrategyId < 0 || searchStrategyId >= VectorValues.SearchStrategy.values().length) { }
throw new CorruptIndexException("Invalid search strategy id: " + searchStrategyId, meta); }
}
VectorValues.SearchStrategy searchStrategy = VectorValues.SearchStrategy.values()[searchStrategyId]; private VectorValues.SearchStrategy readSearchStrategy(DataInput input) throws IOException {
long vectorDataOffset = meta.readVLong(); int searchStrategyId = input.readInt();
long vectorDataLength = meta.readVLong(); if (searchStrategyId < 0 || searchStrategyId >= VectorValues.SearchStrategy.values().length) {
int dimension = meta.readInt(); throw new CorruptIndexException("Invalid search strategy id: " + searchStrategyId, input);
int size = meta.readInt(); }
int[] ordToDoc = new int[size]; return VectorValues.SearchStrategy.values()[searchStrategyId];
for (int i = 0; i < size; i++) { }
int doc = meta.readVInt();
ordToDoc[i] = doc; private FieldEntry readField(DataInput input) throws IOException {
} VectorValues.SearchStrategy searchStrategy = readSearchStrategy(input);
FieldEntry fieldEntry = new FieldEntry(dimension, searchStrategy, maxDoc, vectorDataOffset, vectorDataLength, switch(searchStrategy) {
ordToDoc); case NONE:
fields.put(info.name, fieldEntry); return new FieldEntry(input, searchStrategy);
case DOT_PRODUCT_HNSW:
case EUCLIDEAN_HNSW:
return new HnswGraphFieldEntry(input, searchStrategy);
default:
throw new CorruptIndexException("Unknown vector search strategy: " + searchStrategy, input);
} }
} }
@ -137,6 +163,7 @@ public final class Lucene90VectorReader extends VectorReader {
@Override @Override
public void checkIntegrity() throws IOException { public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorData); CodecUtil.checksumEntireFile(vectorData);
CodecUtil.checksumEntireFile(vectorIndex);
} }
@Override @Override
@ -167,29 +194,58 @@ public final class Lucene90VectorReader extends VectorReader {
return new OffHeapVectorValues(fieldEntry, bytesSlice); return new OffHeapVectorValues(fieldEntry, bytesSlice);
} }
/** Returns the KNN graph (per-node neighbor lists) for the given field, or
 * {@link KnnGraphValues#EMPTY} when no graph data exists for it (no field entry,
 * or zero-length index data).
 * @throws IllegalArgumentException if the field is unknown to this segment's FieldInfos */
public KnnGraphValues getGraphValues(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
}
FieldEntry entry = fields.get(field);
if (entry != null && entry.indexDataLength > 0) {
return getGraphValues(entry);
} else {
return KnnGraphValues.EMPTY;
}
}

// Slices this field's graph region out of the vector-index file and wraps it in a reader.
// Non-HNSW strategies store no graph, so they get the EMPTY sentinel.
private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
if (entry.searchStrategy.isHnsw()) {
HnswGraphFieldEntry graphEntry = (HnswGraphFieldEntry) entry;
IndexInput bytesSlice = vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
return new IndexedKnnGraphReader(graphEntry, bytesSlice);
} else {
return KnnGraphValues.EMPTY;
}
}
@Override @Override
public void close() throws IOException { public void close() throws IOException {
vectorData.close(); IOUtils.close(vectorData, vectorIndex);
} }
private static class FieldEntry { private static class FieldEntry {
final int dimension; final int dimension;
final VectorValues.SearchStrategy searchStrategy; final VectorValues.SearchStrategy searchStrategy;
final int maxDoc;
final long vectorDataOffset; final long vectorDataOffset;
final long vectorDataLength; final long vectorDataLength;
final long indexDataOffset;
final long indexDataLength;
final int[] ordToDoc; final int[] ordToDoc;
FieldEntry(int dimension, VectorValues.SearchStrategy searchStrategy, int maxDoc, FieldEntry(DataInput input, VectorValues.SearchStrategy searchStrategy) throws IOException {
long vectorDataOffset, long vectorDataLength, int[] ordToDoc) {
this.dimension = dimension;
this.searchStrategy = searchStrategy; this.searchStrategy = searchStrategy;
this.maxDoc = maxDoc; vectorDataOffset = input.readVLong();
this.vectorDataOffset = vectorDataOffset; vectorDataLength = input.readVLong();
this.vectorDataLength = vectorDataLength; indexDataOffset = input.readVLong();
this.ordToDoc = ordToDoc; indexDataLength = input.readVLong();
dimension = input.readInt();
int size = input.readInt();
ordToDoc = new int[size];
for (int i = 0; i < size; i++) {
int doc = input.readVInt();
ordToDoc[i] = doc;
}
} }
int size() { int size() {
@ -197,6 +253,21 @@ public final class Lucene90VectorReader extends VectorReader {
} }
} }
// FieldEntry specialization for HNSW fields: additionally records, per vector ordinal,
// the absolute offset of that node's neighbor list within the graph-data slice.
private static class HnswGraphFieldEntry extends FieldEntry {

final long[] ordOffsets;

HnswGraphFieldEntry(DataInput input, VectorValues.SearchStrategy searchStrategy) throws IOException {
super(input, searchStrategy);
ordOffsets = new long[size()];
// offsets are written delta-encoded as vlongs (see the writer's writeGraphOffsets);
// accumulate the deltas to recover absolute offsets
long offset = 0;
for (int i = 0; i < ordOffsets.length; i++) {
offset += input.readVLong();
ordOffsets[i] = offset;
}
}
}
/** Read the vector values from the index input. This supports both iterated and random access. */ /** Read the vector values from the index input. This supports both iterated and random access. */
private final class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValuesProducer { private final class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValuesProducer {
@ -252,11 +323,6 @@ public final class Lucene90VectorReader extends VectorReader {
return binaryValue; return binaryValue;
} }
@Override
public TopDocs search(float[] target, int k, int fanout) {
throw new UnsupportedOperationException();
}
@Override @Override
public int docID() { public int docID() {
return doc; return doc;
@ -288,6 +354,30 @@ public final class Lucene90VectorReader extends VectorReader {
return new OffHeapRandomAccess(dataIn.clone()); return new OffHeapRandomAccess(dataIn.clone());
} }
@Override
public TopDocs search(float[] vector, int topK, int fanout) throws IOException {
// use a seed that is fixed for the index so we get reproducible results for the same query
final Random random = new Random(checksumSeed);
// over-collect by fanout, then prune back down to topK below
Neighbors results = HnswGraph.search(vector, topK + fanout, topK + fanout, randomAccess(), getGraphValues(fieldEntry), random);
// NOTE(review): assumes Neighbors.pop() removes the least-similar entry — confirm against Neighbors
while (results.size() > topK) {
results.pop();
}
int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), topK)];
boolean reversed = searchStrategy().reversed;
while (results.size() > 0) {
Neighbor n = results.pop();
float score;
if (reversed) {
// distance-like strategy (lower = more similar): convert to a descending
// similarity score, normalizing by dimension to keep exp() in a usable range
score = (float) Math.exp(- n.score() / vector.length);
} else {
score = n.score();
}
// fill from the end so the returned docs are ordered best-first
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[n.node()], score);
}
// visitedCount is a lower bound on the true hit count, hence GREATER_THAN_OR_EQUAL_TO;
// it would be exact only when the index holds fewer than topK vectors
return new TopDocs(new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), scoreDocs);
}
class OffHeapRandomAccess implements RandomAccessVectorValues { class OffHeapRandomAccess implements RandomAccessVectorValues {
@ -296,12 +386,10 @@ public final class Lucene90VectorReader extends VectorReader {
final BytesRef binaryValue; final BytesRef binaryValue;
final ByteBuffer byteBuffer; final ByteBuffer byteBuffer;
final FloatBuffer floatBuffer; final FloatBuffer floatBuffer;
final int byteSize;
final float[] value; final float[] value;
OffHeapRandomAccess(IndexInput dataIn) { OffHeapRandomAccess(IndexInput dataIn) {
this.dataIn = dataIn; this.dataIn = dataIn;
byteSize = Float.BYTES * dimension();
byteBuffer = ByteBuffer.allocate(byteSize); byteBuffer = ByteBuffer.allocate(byteSize);
floatBuffer = byteBuffer.asFloatBuffer(); floatBuffer = byteBuffer.asFloatBuffer();
value = new float[dimension()]; value = new float[dimension()];
@ -342,7 +430,41 @@ public final class Lucene90VectorReader extends VectorReader {
dataIn.seek(offset); dataIn.seek(offset);
dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
} }
}
}
/** Read the nearest-neighbors graph from the index input */
// Reads the on-disk nearest-neighbors graph: seek() positions on a node's neighbor list,
// nextNeighbor() decodes successive delta-encoded neighbor ordinals.
private final class IndexedKnnGraphReader extends KnnGraphValues {

final HnswGraphFieldEntry entry;
final IndexInput dataIn;

// iteration state for the current node's neighbor list (reset by seek)
int arcCount; // number of neighbors of the current node
int arcUpTo; // how many neighbors have been returned so far
int arc; // last neighbor ordinal returned; deltas accumulate into this

IndexedKnnGraphReader(HnswGraphFieldEntry entry, IndexInput dataIn) {
this.entry = entry;
this.dataIn = dataIn;
}

@Override
public void seek(int targetOrd) throws IOException {
// unsafe; no bounds checking
dataIn.seek(entry.ordOffsets[targetOrd]);
arcCount = dataIn.readInt();
arc = -1;
arcUpTo = 0;
}

@Override
public int nextNeighbor() throws IOException {
if (arcUpTo >= arcCount) {
return NO_MORE_DOCS;
}
++arcUpTo;
// neighbors are stored as ascending vint deltas (see the writer's writeGraph)
arc += dataIn.readVInt();
return arc;
} }
} }
} }

View File

@ -18,18 +18,20 @@
package org.apache.lucene.codecs.lucene90; package org.apache.lucene.codecs.lucene90;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.Arrays;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.VectorWriter; import org.apache.lucene.codecs.VectorWriter;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@ -39,7 +41,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
*/ */
public final class Lucene90VectorWriter extends VectorWriter { public final class Lucene90VectorWriter extends VectorWriter {
private final IndexOutput meta, vectorData; private final IndexOutput meta, vectorData, vectorIndex;
private boolean finished; private boolean finished;
@ -52,6 +54,9 @@ public final class Lucene90VectorWriter extends VectorWriter {
String vectorDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_DATA_EXTENSION); String vectorDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_DATA_EXTENSION);
vectorData = state.directory.createOutput(vectorDataFileName, state.context); vectorData = state.directory.createOutput(vectorDataFileName, state.context);
String indexDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_INDEX_EXTENSION);
vectorIndex = state.directory.createOutput(indexDataFileName, state.context);
try { try {
CodecUtil.writeIndexHeader(meta, CodecUtil.writeIndexHeader(meta,
Lucene90VectorFormat.META_CODEC_NAME, Lucene90VectorFormat.META_CODEC_NAME,
@ -61,6 +66,10 @@ public final class Lucene90VectorWriter extends VectorWriter {
Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME, Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME,
Lucene90VectorFormat.VERSION_CURRENT, Lucene90VectorFormat.VERSION_CURRENT,
state.segmentInfo.getId(), state.segmentSuffix); state.segmentInfo.getId(), state.segmentSuffix);
CodecUtil.writeIndexHeader(vectorIndex,
Lucene90VectorFormat.VECTOR_INDEX_CODEC_NAME,
Lucene90VectorFormat.VERSION_CURRENT,
state.segmentInfo.getId(), state.segmentSuffix);
} catch (IOException e) { } catch (IOException e) {
IOUtils.closeWhileHandlingException(this); IOUtils.closeWhileHandlingException(this);
} }
@ -69,17 +78,47 @@ public final class Lucene90VectorWriter extends VectorWriter {
@Override @Override
public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException { public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException {
long vectorDataOffset = vectorData.getFilePointer(); long vectorDataOffset = vectorData.getFilePointer();
// TODO - use a better data structure; a bitset? DocsWithFieldSet is p.p. in o.a.l.index // TODO - use a better data structure; a bitset? DocsWithFieldSet is p.p. in o.a.l.index
List<Integer> docIds = new ArrayList<>(); int[] docIds = new int[vectors.size()];
int docV, ord = 0; int count = 0;
for (docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), ord++) { for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) {
// write vector
writeVectorValue(vectors); writeVectorValue(vectors);
docIds.add(docV); docIds[count] = docV;
// TODO: write knn graph value
} }
// count may be < vectors.size(), e.g. if some documents were deleted
long[] offsets = new long[count];
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
long vectorIndexOffset = vectorIndex.getFilePointer();
if (vectors.searchStrategy().isHnsw()) {
if (vectors instanceof RandomAccessVectorValuesProducer) {
writeGraph(vectorIndex, (RandomAccessVectorValuesProducer) vectors, vectorIndexOffset, offsets, count);
} else {
throw new IllegalArgumentException("Indexing an HNSW graph requires a random access vector values, got " + vectors);
}
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
if (vectorDataLength > 0) { if (vectorDataLength > 0) {
writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds); writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, vectorIndexOffset, vectorIndexLength, count, docIds);
if (vectors.searchStrategy().isHnsw()) {
writeGraphOffsets(meta, offsets);
}
}
}
// Writes per-field metadata: field number, search strategy ordinal, the byte extents of
// this field's regions in the vector-data and vector-index files, the vector dimension,
// the number of documents with a vector value, and each such docid.
private void writeMeta(FieldInfo field, long vectorDataOffset, long vectorDataLength, long indexDataOffset, long indexDataLength, int size, int[] docIds) throws IOException {
meta.writeInt(field.number);
meta.writeInt(field.getVectorSearchStrategy().ordinal());
meta.writeVLong(vectorDataOffset);
meta.writeVLong(vectorDataLength);
meta.writeVLong(indexDataOffset);
meta.writeVLong(indexDataLength);
meta.writeInt(field.getVectorDimension());
meta.writeInt(size);
for (int i = 0; i < size; i ++) {
// TODO: delta-encode, or write as bitset
meta.writeVInt(docIds[i]);
} }
} }
@ -90,16 +129,28 @@ public final class Lucene90VectorWriter extends VectorWriter {
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length); vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
} }
private void writeMeta(FieldInfo field, long vectorDataOffset, long vectorDataLength, List<Integer> docIds) throws IOException { private void writeGraphOffsets(IndexOutput out, long[] offsets) throws IOException {
meta.writeInt(field.number); long last = 0;
meta.writeInt(field.getVectorSearchStrategy().ordinal()); for (long offset : offsets) {
meta.writeVLong(vectorDataOffset); out.writeVLong(offset - last);
meta.writeVLong(vectorDataLength); last = offset;
meta.writeInt(field.getVectorDimension()); }
meta.writeInt(docIds.size()); }
for (Integer docId : docIds) {
// TODO: delta-encode, or write as bitset private void writeGraph(IndexOutput graphData, RandomAccessVectorValuesProducer vectorValues, long graphDataOffset, long[] offsets, int count) throws IOException {
meta.writeVInt(docId); HnswGraph graph = HnswGraphBuilder.build(vectorValues);
for (int ord = 0; ord < count; ord++) {
// write graph
offsets[ord] = graphData.getFilePointer() - graphDataOffset;
int[] arcs = graph.getNeighbors(ord);
Arrays.sort(arcs);
graphData.writeInt(arcs.length);
int lastArc = -1; // start below any valid ordinal so the first delta is positive and the ordering assert holds
for (int arc : arcs) {
assert arc > lastArc : "arcs out of order: " + lastArc + "," + arc;
graphData.writeVInt(arc - lastArc);
lastArc = arc;
}
} }
} }
@ -117,11 +168,12 @@ public final class Lucene90VectorWriter extends VectorWriter {
} }
if (vectorData != null) { if (vectorData != null) {
CodecUtil.writeFooter(vectorData); CodecUtil.writeFooter(vectorData);
CodecUtil.writeFooter(vectorIndex);
} }
} }
@Override @Override
public void close() throws IOException { public void close() throws IOException {
IOUtils.close(meta, vectorData); IOUtils.close(meta, vectorData, vectorIndex);
} }
} }

View File

@ -28,7 +28,7 @@ import org.apache.lucene.analysis.Analyzer; // javadocs
public interface IndexableFieldType { public interface IndexableFieldType {
/** True if the field's value should be stored */ /** True if the field's value should be stored */
public boolean stored(); boolean stored();
/** /**
* True if this field's value should be analyzed by the * True if this field's value should be analyzed by the
@ -39,7 +39,7 @@ public interface IndexableFieldType {
*/ */
// TODO: shouldn't we remove this? Whether/how a field is // TODO: shouldn't we remove this? Whether/how a field is
// tokenized is an impl detail under Field? // tokenized is an impl detail under Field?
public boolean tokenized(); boolean tokenized();
/** /**
* True if this field's indexed form should be also stored * True if this field's indexed form should be also stored
@ -52,7 +52,7 @@ public interface IndexableFieldType {
* This option is illegal if {@link #indexOptions()} returns * This option is illegal if {@link #indexOptions()} returns
* IndexOptions.NONE. * IndexOptions.NONE.
*/ */
public boolean storeTermVectors(); boolean storeTermVectors();
/** /**
* True if this field's token character offsets should also * True if this field's token character offsets should also
@ -61,7 +61,7 @@ public interface IndexableFieldType {
* This option is illegal if term vectors are not enabled for the field * This option is illegal if term vectors are not enabled for the field
* ({@link #storeTermVectors()} is false) * ({@link #storeTermVectors()} is false)
*/ */
public boolean storeTermVectorOffsets(); boolean storeTermVectorOffsets();
/** /**
* True if this field's token positions should also be stored * True if this field's token positions should also be stored
@ -70,7 +70,7 @@ public interface IndexableFieldType {
* This option is illegal if term vectors are not enabled for the field * This option is illegal if term vectors are not enabled for the field
* ({@link #storeTermVectors()} is false). * ({@link #storeTermVectors()} is false).
*/ */
public boolean storeTermVectorPositions(); boolean storeTermVectorPositions();
/** /**
* True if this field's token payloads should also be stored * True if this field's token payloads should also be stored
@ -79,7 +79,7 @@ public interface IndexableFieldType {
* This option is illegal if term vector positions are not enabled * This option is illegal if term vector positions are not enabled
* for the field ({@link #storeTermVectors()} is false). * for the field ({@link #storeTermVectors()} is false).
*/ */
public boolean storeTermVectorPayloads(); boolean storeTermVectorPayloads();
/** /**
* True if normalization values should be omitted for the field. * True if normalization values should be omitted for the field.
@ -87,42 +87,42 @@ public interface IndexableFieldType {
* This saves memory, but at the expense of scoring quality (length normalization * This saves memory, but at the expense of scoring quality (length normalization
* will be disabled), and if you omit norms, you cannot use index-time boosts. * will be disabled), and if you omit norms, you cannot use index-time boosts.
*/ */
public boolean omitNorms(); boolean omitNorms();
/** {@link IndexOptions}, describing what should be /** {@link IndexOptions}, describing what should be
* recorded into the inverted index */ * recorded into the inverted index */
public IndexOptions indexOptions(); IndexOptions indexOptions();
/** /**
* DocValues {@link DocValuesType}: how the field's value will be indexed * DocValues {@link DocValuesType}: how the field's value will be indexed
* into docValues. * into docValues.
*/ */
public DocValuesType docValuesType(); DocValuesType docValuesType();
/** /**
* If this is positive (representing the number of point dimensions), the field is indexed as a point. * If this is positive (representing the number of point dimensions), the field is indexed as a point.
*/ */
public int pointDimensionCount(); int pointDimensionCount();
/** /**
* The number of dimensions used for the index key * The number of dimensions used for the index key
*/ */
public int pointIndexDimensionCount(); int pointIndexDimensionCount();
/** /**
* The number of bytes in each dimension's values. * The number of bytes in each dimension's values.
*/ */
public int pointNumBytes(); int pointNumBytes();
/** /**
* The number of dimensions of the field's vector value * The number of dimensions of the field's vector value
*/ */
public int vectorDimension(); int vectorDimension();
/** /**
* The {@link VectorValues.SearchStrategy} of the field's vector value * The {@link VectorValues.SearchStrategy} of the field's vector value
*/ */
public VectorValues.SearchStrategy vectorSearchStrategy(); VectorValues.SearchStrategy vectorSearchStrategy();
/** /**
* Attributes for the field type. * Attributes for the field type.
@ -132,5 +132,5 @@ public interface IndexableFieldType {
* *
* @return Map * @return Map
*/ */
public Map<String, String> getAttributes(); Map<String, String> getAttributes();
} }

View File

@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import java.io.IOException;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/**
 * Access to per-document neighbor lists in a (hierarchical) knn search graph.
 *
 * <p>Usage: position on a graph node with {@link #seek(int)}, then step through that node's
 * neighbor list with repeated calls to {@link #nextNeighbor()} until it returns NO_MORE_DOCS.
 *
 * @lucene.experimental
 */
public abstract class KnnGraphValues {

  /** Sole constructor */
  protected KnnGraphValues() {}

  /** Move the pointer to exactly {@code target}, the id of a node in the graph.
   * After this method returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
   * @param target must be a valid node in the graph, ie. &ge; 0 and &lt; {@link VectorValues#size()}.
   */
  public abstract void seek(int target) throws IOException;

  /**
   * Iterates over the neighbor list. It is illegal to call this method after it returns
   * NO_MORE_DOCS without calling {@link #seek(int)}, which resets the iterator.
   * @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
   */
  public abstract int nextNeighbor() throws IOException;

  /** Empty graph value: seek is a no-op and iteration terminates immediately.
   * Declared {@code final} so the shared sentinel cannot be reassigned. */
  public static final KnnGraphValues EMPTY = new KnnGraphValues() {

    @Override
    public int nextNeighbor() {
      return NO_MORE_DOCS;
    }

    @Override
    public void seek(int target) {
    }
  };
}

View File

@ -190,7 +190,7 @@ public final class SlowCodecReaderWrapper {
} }
}; };
} }
private static NormsProducer readerToNormsProducer(final LeafReader reader) { private static NormsProducer readerToNormsProducer(final LeafReader reader) {
return new NormsProducer() { return new NormsProducer() {

View File

@ -23,6 +23,9 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import static org.apache.lucene.util.VectorUtil.dotProduct;
import static org.apache.lucene.util.VectorUtil.squareDistance;
/** /**
* This class provides access to per-document floating point vector values indexed as {@link * This class provides access to per-document floating point vector values indexed as {@link
* org.apache.lucene.document.VectorField}. * org.apache.lucene.document.VectorField}.
@ -91,15 +94,59 @@ public abstract class VectorValues extends DocIdSetIterator {
* determine the nearest neighbors. * determine the nearest neighbors.
*/ */
public enum SearchStrategy { public enum SearchStrategy {
/** No search strategy is provided. Note: {@link VectorValues#search(float[], int, int)} /** No search strategy is provided. Note: {@link VectorValues#search(float[], int, int)}
* is not supported for fields specifying this strategy. */ * is not supported for fields specifying this strategy. */
NONE, NONE,
/** HNSW graph built using Euclidean distance */ /** HNSW graph built using Euclidean distance */
EUCLIDEAN_HNSW, EUCLIDEAN_HNSW(true),
/** HNSW graph built using dot product */ /** HNSW graph built using dot product */
DOT_PRODUCT_HNSW DOT_PRODUCT_HNSW;
/** If true, the scores associated with vector comparisons in this strategy are in reverse order; that is,
* lower scores represent more similar vectors. Otherwise, if false, higher scores represent more similar vectors.
*/
public final boolean reversed;
SearchStrategy(boolean reversed) {
this.reversed = reversed;
}
SearchStrategy() {
reversed = false;
}
/**
* Calculates a similarity score between the two vectors with a specified function.
* @param v1 a vector
* @param v2 another vector, of the same dimension
* @return the value of the strategy's score function applied to the two vectors
*/
public float compare(float[] v1, float[] v2) {
switch (this) {
case EUCLIDEAN_HNSW:
return squareDistance(v1, v2);
case DOT_PRODUCT_HNSW:
return dotProduct(v1, v2);
default:
throw new IllegalStateException("Incomparable search strategy: " + this);
}
}
/**
* Return true if vectors indexed using this strategy will be indexed using an HNSW graph
*/
public boolean isHnsw() {
switch (this) {
case EUCLIDEAN_HNSW:
case DOT_PRODUCT_HNSW:
return true;
default:
return false;
}
}
} }
/** /**

View File

@ -176,47 +176,48 @@ class VectorValuesWriter {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public long cost() {
return size();
}
@Override @Override
public TopDocs search(float[] target, int k, int fanout) { public TopDocs search(float[] target, int k, int fanout) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public long cost() {
return size();
}
@Override @Override
public RandomAccessVectorValues randomAccess() { public RandomAccessVectorValues randomAccess() {
// Must make a new delegate randomAccess so that we have our own distinct float[]
final RandomAccessVectorValues delegateRA = ((RandomAccessVectorValuesProducer) SortingVectorValues.this.delegate).randomAccess();
return new RandomAccessVectorValues() { return new RandomAccessVectorValues() {
@Override @Override
public int size() { public int size() {
return delegate.size(); return delegateRA.size();
} }
@Override @Override
public int dimension() { public int dimension() {
return delegate.dimension(); return delegateRA.dimension();
} }
@Override @Override
public SearchStrategy searchStrategy() { public SearchStrategy searchStrategy() {
return delegate.searchStrategy(); return delegateRA.searchStrategy();
} }
@Override @Override
public float[] vectorValue(int targetOrd) throws IOException { public float[] vectorValue(int targetOrd) throws IOException {
return randomAccess.vectorValue(ordMap[targetOrd]); return delegateRA.vectorValue(ordMap[targetOrd]);
} }
@Override @Override
public BytesRef binaryValue(int targetOrd) { public BytesRef binaryValue(int targetOrd) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
}; };
} }
} }
@ -252,7 +253,7 @@ class VectorValuesWriter {
@Override @Override
public RandomAccessVectorValues randomAccess() { public RandomAccessVectorValues randomAccess() {
return this; return new BufferedVectorValues(docsWithField, vectors, dimension, searchStrategy);
} }
@Override @Override

View File

@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
/**
 * Tracks the best-so-far value in a stream of float scores and rejects candidates that
 * cannot improve on it. Direction depends on the search strategy: a {@link Max} checker
 * treats larger samples as better (higher score = more similar), while a {@link Min}
 * checker treats smaller samples as better (reversed strategies such as Euclidean
 * distance, where lower scores are more similar).
 */
abstract class BoundsChecker {

  /** The best value seen so far; candidates are compared against this. */
  float bound;

  /**
   * Update the bound if sample is better
   */
  abstract void update(float sample);

  /**
   * Return whether the sample exceeds (is worse than) the bound
   */
  abstract boolean check(float sample);

  /**
   * Creates a checker oriented for the given score order.
   * @param reversed if true, lower scores are better, so a {@link Min} checker is
   *        returned; otherwise a {@link Max} checker is returned
   */
  static BoundsChecker create(boolean reversed) {
    if (reversed) {
      return new Min();
    } else {
      return new Max();
    }
  }

  /** Keeps the maximum value seen; samples strictly below the bound fail the check. */
  static class Max extends BoundsChecker {

    Max() {
      // any real sample is better than this initial bound
      bound = Float.NEGATIVE_INFINITY;
    }

    @Override
    void update(float sample) {
      if (sample > bound) {
        bound = sample;
      }
    }

    @Override
    boolean check(float sample) {
      return sample < bound;
    }
  }

  /** Keeps the minimum value seen; samples strictly above the bound fail the check. */
  static class Min extends BoundsChecker {

    Min() {
      // any real sample is better than this initial bound
      bound = Float.POSITIVE_INFINITY;
    }

    @Override
    void update(float sample) {
      if (sample < bound) {
        bound = sample;
      }
    }

    @Override
    boolean check(float sample) {
      return sample > bound;
    }
  }
}

View File

@ -0,0 +1,223 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorValues;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/**
* Navigable Small-world graph. Provides efficient approximate nearest neighbor
* search for high dimensional vectors. See <a href="https://doi.org/10.1016/j.is.2013.10.006">Approximate nearest
* neighbor algorithm based on navigable small world graphs [2014]</a> and <a
* href="https://arxiv.org/abs/1603.09320">this paper [2018]</a> for details.
*
* The nomenclature is a bit different here from what's used in those papers:
*
* <h3>Hyperparameters</h3>
* <ul>
* <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2012 paper; it controls the number of random entry points to sample.</li>
* <li><code>beamWidth</code> in {@link HnswGraphBuilder} has the same meaning as <code>efConst</code> in the 2016 paper. It is the number of
* nearest neighbor candidates to track while searching the graph for each newly inserted node.</li>
* <li><code>maxConn</code> has the same meaning as <code>M</code> in the later paper; it controls how many of the <code>efConst</code> neighbors are
* connected to the new node</li>
* <li><code>fanout</code> the fanout parameter of {@link VectorValues#search(float[], int, int)}
* is used to control the values of <code>numSeed</code> and <code>topK</code> that are passed to this API.
* Thus <code>fanout</code> is like a combination of <code>ef</code> (search beam width) from the 2016 paper and <code>m</code> from the 2014 paper.
* </li>
* </ul>
*
* <p>Note: The graph may be searched by multiple threads concurrently, but updates are not thread-safe. Also note: there is no notion of
* deletions. Document searching built on top of this must do its own deletion-filtering.</p>
*/
public final class HnswGraph {

  // Maximum number of outgoing connections (arcs) retained per node.
  private final int maxConn;

  // Supplies the score function and its orientation (reversed => lower scores are more similar).
  private final VectorValues.SearchStrategy searchStrategy;

  // Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to HnswBuilder, and the
  // node values are the ordinals of those vectors.
  private final List<Neighbors> graph;

  /**
   * Creates a graph containing a single node (ordinal 0) with an empty neighbor list.
   * @param maxConn maximum connections retained per node
   * @param searchStrategy determines scoring and the orientation of the neighbor queues
   */
  HnswGraph(int maxConn, VectorValues.SearchStrategy searchStrategy) {
    graph = new ArrayList<>();
    graph.add(Neighbors.create(maxConn, searchStrategy.reversed));
    this.maxConn = maxConn;
    this.searchStrategy = searchStrategy;
  }

  /**
   * Searches for the nearest neighbors of a query vector.
   * @param query search query vector
   * @param topK the number of nodes to be returned
   * @param numSeed the number of random entry points to sample
   * @param vectors vector values
   * @param graphValues the graph values. May represent the entire graph, or a level in a hierarchical graph.
   * @param random a source of randomness, used for generating entry points to the graph
   * @return a priority queue holding the neighbors found
   */
  public static Neighbors search(float[] query, int topK, int numSeed, RandomAccessVectorValues vectors, KnnGraphValues graphValues,
                                 Random random) throws IOException {
    VectorValues.SearchStrategy searchStrategy = vectors.searchStrategy();
    // TODO: use unbounded priority queue
    // The candidate set is oriented so pollLast() always yields the best-scoring candidate.
    TreeSet<Neighbor> candidates;
    if (searchStrategy.reversed) {
      candidates = new TreeSet<>(Comparator.reverseOrder());
    } else {
      candidates = new TreeSet<>();
    }
    int size = vectors.size();
    // seed the search with random entry points; duplicate draws collapse in the TreeSet
    for (int i = 0; i < numSeed && i < size; i++) {
      int entryPoint = random.nextInt(size);
      candidates.add(new Neighbor(entryPoint, searchStrategy.compare(query, vectors.vectorValue(entryPoint))));
    }
    // set of ordinals that have been visited by search on this layer, used to avoid backtracking
    Set<Integer> visited = new HashSet<>();
    // TODO: use PriorityQueue's sentinel optimization?
    Neighbors results = Neighbors.create(topK, searchStrategy.reversed);
    for (Neighbor c : candidates) {
      visited.add(c.node());
      results.insertWithOverflow(c);
    }
    // Set the bound to the worst current result and below reject any newly-generated candidates failing
    // to exceed this bound
    BoundsChecker bound = BoundsChecker.create(searchStrategy.reversed);
    bound.bound = results.top().score();
    while (candidates.size() > 0) {
      // get the best candidate (closest or best scoring)
      Neighbor c = candidates.pollLast();
      if (results.size() >= topK) {
        // results is full; stop as soon as even the best remaining candidate is worse than the bound
        if (bound.check(c.score())) {
          break;
        }
      }
      // expand the candidate: score each not-yet-visited neighbor, keeping those that could improve results
      graphValues.seek(c.node());
      int friendOrd;
      while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
        if (visited.contains(friendOrd)) {
          continue;
        }
        visited.add(friendOrd);
        float score = searchStrategy.compare(query, vectors.vectorValue(friendOrd));
        if (results.size() < topK || bound.check(score) == false) {
          Neighbor n = new Neighbor(friendOrd, score);
          candidates.add(n);
          results.insertWithOverflow(n);
          // the worst retained result defines the new bound
          bound.bound = results.top().score();
        }
      }
    }
    results.setVisitedCount(visited.size());
    return results;
  }

  /**
   * Returns the nodes connected to the given node by its outgoing neighborNodes in an unpredictable order. Each node inserted
   * by HnswGraphBuilder corresponds to a vector, and the node is the vector's ordinal.
   * @param node the node whose friends are returned
   */
  public int[] getNeighbors(int node) {
    Neighbors neighbors = graph.get(node);
    int[] result = new int[neighbors.size()];
    int i = 0;
    for (Neighbor n : neighbors) {
      result[i++] = n.node();
    }
    return result;
  }

  /** Connects two nodes symmetrically, limiting the maximum number of connections from either node.
   * node1 must be less than node2 and must already have been inserted to the graph */
  void connectNodes(int node1, int node2, float score) {
    connect(node1, node2, score);
    // grow the graph if node2 is the next ordinal past the current last node
    if (node2 == graph.size()) {
      addNode();
    }
    connect(node2, node1, score);
  }

  /** Exposes this graph's adjacency lists through the {@link KnnGraphValues} iteration API. */
  KnnGraphValues getGraphValues() {
    return new HnswGraphValues();
  }

  /**
   * Makes a connection from the node to a neighbor, dropping the worst connection when maxConn is exceeded
   * @param node1 node to connect *from*
   * @param node2 node to connect *to*
   * @param score searchStrategy.score() of the vectors associated with the two nodes
   * @return true if the connection was made; false if node1's neighbor list is full and
   *         node2's score does not beat the current worst neighbor
   */
  boolean connect(int node1, int node2, float score) {
    //System.out.println(" HnswGraph.connect " + node1 + " -> " + node2);
    assert node1 >= 0 && node2 >= 0;
    Neighbors nn = graph.get(node1);
    assert nn != null;
    if (nn.size() == maxConn) {
      // List is full: replace the worst neighbor (the queue's top) only when the new score is
      // better; "better" flips with the queue orientation, hence the comparison against reversed().
      Neighbor top = nn.top();
      if (score < top.score() == nn.reversed()) {
        top.update(node2, score);
        nn.updateTop();
        return true;
      }
    } else {
      nn.add(new Neighbor(node2, score));
      return true;
    }
    return false;
  }

  /** Appends a new node with an empty neighbor list and returns its ordinal. */
  int addNode() {
    graph.add(Neighbors.create(maxConn, searchStrategy.reversed));
    return graph.size() - 1;
  }

  /**
   * Present this graph as KnnGraphValues, used for searching while inserting new nodes.
   */
  private class HnswGraphValues extends KnnGraphValues {

    // index of the next neighbor to return from neighborNodes
    private int arcUpTo;

    // snapshot of the current node's neighbor ordinals, captured at seek()
    private int[] neighborNodes;

    @Override
    public void seek(int targetNode) {
      arcUpTo = 0;
      neighborNodes = HnswGraph.this.getNeighbors(targetNode);
    }

    @Override
    public int nextNeighbor() {
      if (arcUpTo >= neighborNodes.length) {
        return NO_MORE_DOCS;
      }
      return neighborNodes[arcUpTo++];
    }
  }
}

View File

@ -0,0 +1,188 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.Random;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the hyperparameters.
*/
public final class HnswGraphBuilder {

  // default random seed for level generation
  private static final long DEFAULT_RAND_SEED = System.currentTimeMillis();
  // expose for testing.
  public static long randSeed = DEFAULT_RAND_SEED;

  // These "default" hyperparameter settings are exposed (and nonfinal) to enable performance testing
  // since the indexing API doesn't provide any control over them.

  // default max connections per node
  public static int DEFAULT_MAX_CONN = 16;
  // default candidate list size
  static int DEFAULT_BEAM_WIDTH = 16;

  private final int maxConn;
  private final int beamWidth;
  // counts vectors as they are inserted, bounding the portion of the graph visible to search during build
  private final BoundedVectorValues boundedVectors;
  private final VectorValues.SearchStrategy searchStrategy;
  private final HnswGraph hnsw;
  private final Random random;

  /**
   * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense ordinals, using default
   * hyperparameter settings, and returns the resulting graph.
   * @param vectorValues the vectors whose relations are represented by the graph
   */
  public static HnswGraph build(RandomAccessVectorValuesProducer vectorValues) throws IOException {
    HnswGraphBuilder builder = new HnswGraphBuilder(vectorValues);
    return builder.build(vectorValues.randomAccess());
  }

  /**
   * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense ordinals, using the given
   * hyperparameter settings, and returns the resulting graph.
   * @param vectorValues the vectors whose relations are represented by the graph
   * @param maxConn the number of connections to make when adding a new graph node; roughly speaking the graph fanout.
   * @param beamWidth the size of the beam search to use when finding nearest neighbors.
   * @param seed the seed for a random number generator used during graph construction. Provide this to ensure repeatable construction.
   */
  public static HnswGraph build(RandomAccessVectorValuesProducer vectorValues, int maxConn, int beamWidth, long seed) throws IOException {
    HnswGraphBuilder builder = new HnswGraphBuilder(vectorValues, maxConn, beamWidth, seed);
    return builder.build(vectorValues.randomAccess());
  }

  /**
   * Reads all the vectors from two copies of a random access VectorValues. Providing two copies enables efficient retrieval
   * without extra data copying, while avoiding collision of the returned values.
   * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent accessor for the vectors
   */
  HnswGraph build(RandomAccessVectorValues vectors) throws IOException {
    if (vectors == boundedVectors.raDelegate) {
      throw new IllegalArgumentException("Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
    }
    // node 0 is added implicitly when the graph is created, so insertion starts at ordinal 1
    for (int node = 1; node < vectors.size(); node++) {
      insert(vectors.vectorValue(node));
    }
    return hnsw;
  }

  /** Construct the builder with default configurations */
  private HnswGraphBuilder(RandomAccessVectorValuesProducer vectors) {
    this(vectors, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, randSeed);
  }

  /** Full constructor */
  HnswGraphBuilder(RandomAccessVectorValuesProducer vectors, int maxConn, int beamWidth, long seed) {
    RandomAccessVectorValues vectorValues = vectors.randomAccess();
    searchStrategy = vectorValues.searchStrategy();
    if (searchStrategy == VectorValues.SearchStrategy.NONE) {
      throw new IllegalStateException("No distance function");
    }
    this.maxConn = maxConn;
    this.beamWidth = beamWidth;
    boundedVectors = new BoundedVectorValues(vectorValues);
    this.hnsw = new HnswGraph(maxConn, searchStrategy);
    random = new Random(seed);
  }

  /** Inserts a doc with vector value to the graph */
  private void insert(float[] value) throws IOException {
    addGraphNode(value);
    // add the vector value
    boundedVectors.inc();
  }

  private void addGraphNode(float[] value) throws IOException {
    // search the graph as built so far (bounded to the vectors already inserted)
    // for the nearest candidates to the new vector
    KnnGraphValues graphValues = hnsw.getGraphValues();
    Neighbors candidates = HnswGraph.search(value, beamWidth, 2 * beamWidth, boundedVectors, graphValues, random);
    int node = hnsw.addNode();
    // connect the nearest neighbors to the just inserted node
    addNearestNeighbors(node, candidates);
  }

  private void addNearestNeighbors(int newNode, Neighbors neighbors) {
    // connect the nearest neighbors, relying on the graph's Neighbors' priority queues to drop off distant neighbors
    for (Neighbor neighbor : neighbors) {
      // only make the reverse connection when the forward connection was accepted
      if (hnsw.connect(newNode, neighbor.node(), neighbor.score())) {
        hnsw.connect(neighbor.node(), newNode, neighbor.score());
      }
    }
  }

  /**
   * Provides a random access VectorValues view over a delegate VectorValues, bounding the maximum ord.
   * TODO: get rid of this, all it does is track a counter
   */
  private static class BoundedVectorValues implements RandomAccessVectorValues {

    final RandomAccessVectorValues raDelegate;

    // number of vectors visible through this view; grows as nodes are inserted
    int size;

    BoundedVectorValues(RandomAccessVectorValues delegate) {
      raDelegate = delegate;
      if (delegate.size() > 0) {
        // we implicitly add the first node
        size = 1;
      }
    }

    void inc() {
      ++size;
    }

    @Override
    public int size() {
      return size;
    }

    @Override
    public int dimension() { return raDelegate.dimension(); }

    @Override
    public VectorValues.SearchStrategy searchStrategy() {
      return raDelegate.searchStrategy();
    }

    @Override
    public float[] vectorValue(int target) throws IOException {
      return raDelegate.vectorValue(target);
    }

    @Override
    public BytesRef binaryValue(int targetOrd) throws IOException {
      throw new UnsupportedOperationException();
    }
  }
}

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
/** A neighbor node in the HNSW graph; holds the node ordinal and its distance score. */
public class Neighbor implements Comparable<Neighbor> {
private int node;
private float score;
public Neighbor(int node, float score) {
this.node = node;
this.score = score;
}
public int node() {
return node;
}
public float score() {
return score;
}
void update(int node, float score) {
this.node = node;
this.score = score;
}
@Override
public int compareTo(Neighbor o) {
if (score == o.score) {
return o.node - node;
} else {
assert node != o.node : "attempt to add the same node " + node + " twice with different scores: " + score + " != " + o.score;
return Float.compare(score, o.score);
}
}
@Override
public boolean equals(Object other) {
return other instanceof Neighbor
&& ((Neighbor) other).node == node;
}
@Override
public int hashCode() {
return 39 + 61 * node;
}
@Override
public String toString() {
return "(" + node + ", " + score + ")";
}
}

View File

@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import org.apache.lucene.util.PriorityQueue;
/** Neighbors queue. */
public abstract class Neighbors extends PriorityQueue<Neighbor> {

  /**
   * Factory for the two orderings: when {@code reversed}, lower scores rank as more similar;
   * otherwise higher scores rank as more similar.
   */
  public static Neighbors create(int maxSize, boolean reversed) {
    return reversed ? new ReverseNeighbors(maxSize) : new ForwardNeighbors(maxSize);
  }

  /** True if this queue treats lower scores as more similar. */
  public abstract boolean reversed();

  // Used to track the number of neighbors visited during a single graph traversal
  private int visitedCount;

  private Neighbors(int maxSize) {
    super(maxSize);
  }

  /** Ordering for strategies where a higher score means a closer match. */
  private static class ForwardNeighbors extends Neighbors {

    ForwardNeighbors(int maxSize) {
      super(maxSize);
    }

    @Override
    protected boolean lessThan(Neighbor a, Neighbor b) {
      // on equal scores, the larger node ordinal sorts lower
      return a.score() == b.score() ? a.node() > b.node() : a.score() < b.score();
    }

    @Override
    public boolean reversed() {
      return false;
    }
  }

  /** Ordering for strategies where a lower score means a closer match. */
  private static class ReverseNeighbors extends Neighbors {

    ReverseNeighbors(int maxSize) {
      super(maxSize);
    }

    @Override
    protected boolean lessThan(Neighbor a, Neighbor b) {
      // on equal scores, the larger node ordinal sorts lower
      return a.score() == b.score() ? a.node() > b.node() : a.score() > b.score();
    }

    @Override
    public boolean reversed() {
      return true;
    }
  }

  /** Records how many neighbors the most recent traversal visited. */
  void setVisitedCount(int visitedCount) {
    this.visitedCount = visitedCount;
  }

  public int visitedCount() {
    return visitedCount;
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append("Neighbors=[");
    this.iterator().forEachRemaining(sb::append);
    sb.append("]");
    return sb.toString();
  }
}

View File

@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Navigable Small-World graph, nominally Hierarchical but currently only has a single
* layer. Provides efficient approximate nearest neighbor search for high dimensional vectors.
*/
package org.apache.lucene.util.hnsw;

View File

@ -0,0 +1,352 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.randSeed;
/** Tests indexing of a knn-graph */
public class TestKnnGraph extends LuceneTestCase {
private static final String KNN_GRAPH_FIELD = "vector";
@Before
public void setup() {
  // randomize the graph-construction seed each run so graph-building bugs surface across seeds
  randSeed = random().nextLong();
}
/**
 * Basic test of creating documents in a graph
 */
public void testBasic() throws Exception {
  try (Directory dir = newDirectory();
       IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
    int numDoc = atLeast(10);
    int dimension = atLeast(3);
    float[][] values = new float[numDoc][];
    for (int i = 0; i < numDoc; i++) {
      // leave some documents without a vector (null entry) so the sparse case is covered
      if (random().nextBoolean()) {
        values[i] = new float[dimension];
        for (int j = 0; j < dimension; j++) {
          values[i][j] = random().nextFloat();
        }
      }
      add(iw, i, values[i]);
    }
    assertConsistentGraph(iw, values);
  }
}
/** Indexes a single vector and checks graph consistency both before and after commit. */
public void testSingleDocument() throws Exception {
  try (Directory dir = newDirectory();
       IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
    float[][] values = new float[][]{new float[]{0, 1, 2}};
    add(iw, 0, values[0]);
    assertConsistentGraph(iw, values);
    // re-check after the segment is flushed to disk
    iw.commit();
    assertConsistentGraph(iw, values);
  }
}
/**
 * Verify that the graph properties are preserved when merging
 */
public void testMerge() throws Exception {
  try (Directory dir = newDirectory();
       IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
    int numDoc = atLeast(100);
    int dimension = atLeast(10);
    float[][] values = new float[numDoc][];
    for (int i = 0; i < numDoc; i++) {
      // leave some documents without a vector (null entry) so the sparse case is covered
      if (random().nextBoolean()) {
        values[i] = new float[dimension];
        for (int j = 0; j < dimension; j++) {
          values[i][j] = random().nextFloat();
        }
      }
      add(iw, i, values[i]);
      // commit at random intervals to create multiple segments for merging
      if (random().nextInt(10) == 3) {
        //System.out.println("commit @" + i);
        iw.commit();
      }
    }
    if (random().nextBoolean()) {
      iw.forceMerge(1);
    }
    assertConsistentGraph(iw, values);
  }
}
// TODO: testSorted
// TODO: testDeletions
/**
 * Verify that searching does something reasonable
 */
public void testSearch() throws Exception {
  try (Directory dir = newDirectory();
       // don't allow random merges; they mess up the docid tie-breaking assertion
       IndexWriter iw = new IndexWriter(dir, new IndexWriterConfig().setCodec(Codec.forName("Lucene90")))) {
    // Add a document for every cartesian point in an NxN square so we can
    // easily know which are the nearest neighbors to every point. Insert by iterating
    // using a prime number that is not a divisor of N*N so that we will hit each point once,
    // and chosen so that points will be inserted in a deterministic
    // but somewhat distributed pattern
    int n = 5, stepSize = 17;
    float[][] values = new float[n * n][];
    int index = 0;
    for (int i = 0; i < values.length; i++) {
      // System.out.printf("%d: (%d, %d)\n", i, index % n, index / n);
      // each vector is the 2-d coordinate of the point
      values[i] = new float[]{index % n, index / n};
      index = (index + stepSize) % (n * n);
      add(iw, i, values[i]);
      if (i == 13) {
        // create 2 segments
        iw.commit();
      }
    }
    boolean forceMerge = random().nextBoolean();
    //System.out.println("");
    if (forceMerge) {
      iw.forceMerge(1);
    }
    assertConsistentGraph(iw, values);
    try (DirectoryReader dr = DirectoryReader.open(iw)) {
      // results are ordered by score (descending) and docid (ascending);
      // This is the insertion order:
      // column major, origin at upper left
      //  0 15  5 20 10
      //  3 18  8 23 13
      //  6 21 11  1 16
      //  9 24 14  4 19
      // 12  2 17  7 22
      // For this small graph the "search" is exhaustive, so this mostly tests the APIs, the orientation of the
      // various priority queues, the scoring function, but not so much the approximate KNN search algo
      assertGraphSearch(new int[]{0, 15, 3, 18, 5}, new float[]{0f, 0.1f}, dr);
      // test tiebreaking by docid
      assertGraphSearch(new int[]{11, 1, 8, 14, 21}, new float[]{2, 2}, dr);
      assertGraphSearch(new int[]{15, 18, 0, 3, 5},new float[]{0.3f, 0.8f}, dr);
    }
  }
}
/** Runs a knn search for {@code vector} and checks that the top-5 hit ids match {@code expected}, in order. */
private void assertGraphSearch(int[] expected, float[] vector, IndexReader reader) throws IOException {
  TopDocs topDocs = doKnnSearch(reader, vector, 5);
  // rewrite each hit's docid as its insertion id (stored in the "id" field) before comparing
  for (ScoreDoc hit : topDocs.scoreDocs) {
    hit.doc = Integer.parseInt(reader.document(hit.doc).get("id"));
  }
  assertResults(expected, topDocs);
}
/** Searches every leaf, rebases leaf-relative docids to global docids, and merges the hits. */
private static TopDocs doKnnSearch(IndexReader reader, float[] vector, int k) throws IOException {
  TopDocs[] perLeaf = new TopDocs[reader.leaves().size()];
  for (LeafReaderContext ctx : reader.leaves()) {
    TopDocs leafHits = ctx.reader().getVectorValues(KNN_GRAPH_FIELD).search(vector, k, 10);
    // adding ctx.docBase is a no-op for the first leaf (docBase == 0)
    for (ScoreDoc hit : leafHits.scoreDocs) {
      hit.doc += ctx.docBase;
    }
    perLeaf[ctx.ord] = leafHits;
  }
  return TopDocs.merge(k, perLeaf);
}
/** Asserts the result docids (already remapped to insertion ids) match {@code expected} exactly, in order. */
private void assertResults(int[] expected, TopDocs results) {
  ScoreDoc[] hits = results.scoreDocs;
  assertEquals(results.toString(), expected.length, hits.length);
  for (int i = expected.length; --i >= 0; ) {
    assertEquals(Arrays.toString(hits), expected[i], hits[i].doc);
  }
}
// For each leaf, verify that its graph nodes are 1-1 with vectors, that the vectors are the expected values,
// and that the graph is fully connected and symmetric.
// NOTE: when we impose max-fanout on the graph it will no longer be symmetric, but should still
// be fully connected. Is there any other invariant we can test? Well, we can check that max fanout
// is respected. We can test *desirable* properties of the graph like small-world (the graph diameter
// should be tightly bounded).
private void assertConsistentGraph(IndexWriter iw, float[][] values) throws IOException {
  int totalGraphDocs = 0;
  try (DirectoryReader dr = DirectoryReader.open(iw)) {
    for (LeafReaderContext ctx: dr.leaves()) {
      LeafReader reader = ctx.reader();
      VectorValues vectorValues = reader.getVectorValues(KNN_GRAPH_FIELD);
      Lucene90VectorReader vectorReader = ((Lucene90VectorReader) ((CodecReader) reader).getVectorReader());
      if (vectorReader == null) {
        continue;
      }
      KnnGraphValues graphValues = vectorReader.getGraphValues(KNN_GRAPH_FIELD);
      // vectors and graph must be present or absent together
      assertTrue((vectorValues == null) == (graphValues == null));
      if (vectorValues == null) {
        continue;
      }
      // graph[ordinal] -> neighbor ordinals; entries stay null for singleton (orphan) nodes
      int[][] graph = new int[reader.maxDoc()][];
      boolean foundOrphan= false;
      int graphSize = 0;
      // dense graph-node ordinal, incremented once per doc that has a vector
      int node = -1;
      for (int i = 0; i < reader.maxDoc(); i++) {
        int nextDocWithVectors = vectorValues.advance(i);
        //System.out.println("advanced to " + nextDocWithVectors);
        // docs skipped over by advance() must not have an expected vector
        while (i < nextDocWithVectors && i < reader.maxDoc()) {
          int id = Integer.parseInt(reader.document(i).get("id"));
          assertNull("document " + id + " has no vector, but was expected to", values[id]);
          ++i;
        }
        if (nextDocWithVectors == NO_MORE_DOCS) {
          break;
        }
        int id = Integer.parseInt(reader.document(i).get("id"));
        // graph nodes are dense ordinals assigned in docid order
        graphValues.seek(++node);
        // documents with KnnGraphValues have the expected vectors
        float[] scratch = vectorValues.vectorValue();
        assertArrayEquals("vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
            values[id], scratch, 0f);
        // We collect neighbors for analysis below
        List<Integer> friends = new ArrayList<>();
        int arc;
        while ((arc = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
          friends.add(arc);
        }
        if (friends.size() == 0) {
          //System.out.printf("knngraph @%d is singleton (advance returns %d)\n", i, nextWithNeighbors);
          foundOrphan = true;
        } else {
          // NOTE: these friends are dense ordinals, not docIds.
          int[] friendCopy = new int[friends.size()];
          for (int j = 0; j < friends.size(); j++) {
            friendCopy[j] = friends.get(j);
          }
          graph[graphSize] = friendCopy;
          //System.out.printf("knngraph @%d => %s\n", i, Arrays.toString(graph[i]));
        }
        graphSize++;
      }
      // the vector iterator must be exhausted once we've walked all docs
      assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
      // an orphan node is only acceptable when the graph has exactly one node
      if (foundOrphan) {
        assertEquals("graph is not fully connected", 1, graphSize);
      } else {
        assertTrue("Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1);
      }
      // assert that the graph in each leaf is connected and undirected (ie links are reciprocated)
      // assertReciprocal(graph);
      assertConnected(graph);
      assertMaxConn(graph, HnswGraphBuilder.DEFAULT_MAX_CONN);
      totalGraphDocs += graphSize;
    }
  }
  // every document that was given a vector must appear in exactly one leaf's graph
  int expectedCount = 0;
  for (float[] friends : values) {
    if (friends != null) {
      ++expectedCount;
    }
  }
  assertEquals(expectedCount, totalGraphDocs);
}
/**
 * Verifies that every node has at most {@code maxConn} neighbors, and that every neighbor
 * ordinal points at a node that itself exists in the graph.
 */
private void assertMaxConn(int[][] graph, int maxConn) {
  for (int i = 0; i < graph.length; i++) {
    if (graph[i] != null) {
      // use assertTrue rather than a bare JVM `assert`, which is silently skipped
      // unless the JVM is started with -ea
      assertTrue("node " + i + " has " + graph[i].length + " neighbors; expected at most " + maxConn,
          graph[i].length <= maxConn);
      for (int j = 0; j < graph[i].length; j++) {
        int k = graph[i][j];
        assertNotNull("node " + i + " links to node " + k + " which has no neighbor list", graph[k]);
      }
    }
  }
}
/**
 * Asserts the graph is undirected: every arc i -> k has a matching arc k -> i.
 * Currently unused — the caller's invocation is commented out, since max-fanout pruning
 * makes the graph directed.
 * NOTE(review): Arrays.binarySearch requires graph[k] to be sorted ascending; the neighbor
 * arrays are collected in nextNeighbor() order — confirm they are sorted before re-enabling.
 */
private void assertReciprocal(int[][] graph) {
  // The graph is undirected: if a -> b then b -> a.
  for (int i = 0; i < graph.length; i++) {
    if (graph[i] != null) {
      for (int j = 0; j < graph[i].length; j++) {
        int k = graph[i][j];
        assertNotNull(graph[k]);
        assertTrue("" + i + "->" + k + " is not reciprocated", Arrays.binarySearch(graph[k], i) >= 0);
      }
    }
  }
}
/**
 * Asserts every node with a neighbor list is reachable from every other via a BFS walk
 * started at an arbitrary node. Nodes are marked visited when enqueued (not when dequeued)
 * so the same node is never queued twice — the original marked at dequeue and could
 * re-enqueue a node reachable through multiple neighbors, doing redundant work.
 */
private void assertConnected(int[][] graph) {
  // every node in the graph is reachable from every other node
  Set<Integer> visited = new HashSet<>();
  List<Integer> queue = new LinkedList<>();
  int count = 0;
  for (int[] entry : graph) {
    if (entry != null) {
      if (queue.isEmpty()) {
        // start from any node's first neighbor
        int start = entry[0];
        queue.add(start);
        visited.add(start);
      }
      ++count;
    }
  }
  while (queue.isEmpty() == false) {
    int i = queue.remove(0);
    assertNotNull("expected neighbors of " + i, graph[i]);
    for (int j : graph[i]) {
      // Set.add returns true only the first time we see j
      if (visited.add(j)) {
        queue.add(j);
      }
    }
  }
  // we visited each node exactly once
  assertEquals("Attempted to walk entire graph but only visited " + visited.size(), count, visited.size());
}
/**
 * Adds one document with a stored "id" field (used to recover insertion order at
 * verification time) and, when {@code vector} is non-null, a vector in the graph field.
 */
private void add(IndexWriter iw, int id, float[] vector) throws IOException {
  Document doc = new Document();
  if (vector != null) {
    // TODO: choose random search strategy
    doc.add(new VectorField(KNN_GRAPH_FIELD, vector, VectorValues.SearchStrategy.EUCLIDEAN_HNSW));
  }
  doc.add(new StringField("id", Integer.toString(id), Field.Store.YES));
  //System.out.println("add " + id + " " + Arrays.toString(vector));
  iw.addDocument(doc);
}
}

View File

@ -601,6 +601,51 @@ public class TestVectorValues extends LuceneTestCase {
} }
} }
/**
 * Indexes three vector fields with differing dimensions and search strategies across
 * three documents, then verifies each field's dimension, size, and per-doc values
 * after a force-merge into a single segment.
 */
public void testIndexMultipleVectorFields() throws Exception {
  try (Directory dir = newDirectory();
       IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) {
    Document doc = new Document();
    float[] v = new float[]{1};
    doc.add(new VectorField("field1", v, SearchStrategy.EUCLIDEAN_HNSW));
    doc.add(new VectorField("field2", new float[]{1, 2, 3}, SearchStrategy.NONE));
    iw.addDocument(doc);
    // mutate the shared array and reuse the same Document object: the second
    // document is indexed with field1 == {2}
    v[0] = 2;
    iw.addDocument(doc);
    doc = new Document();
    doc.add(new VectorField("field3", new float[]{1, 2, 3}, SearchStrategy.DOT_PRODUCT_HNSW));
    iw.addDocument(doc);
    iw.forceMerge(1);
    try (IndexReader reader = iw.getReader()) {
      LeafReader leaf = reader.leaves().get(0).reader();
      // field1: docs 0 and 1, one-dimensional, values 1 then 2
      VectorValues vectorValues = leaf.getVectorValues("field1");
      assertEquals(1, vectorValues.dimension());
      assertEquals(2, vectorValues.size());
      vectorValues.nextDoc();
      assertEquals(1f, vectorValues.vectorValue()[0], 0);
      vectorValues.nextDoc();
      assertEquals(2f, vectorValues.vectorValue()[0], 0);
      assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
      // field2: same two docs, three-dimensional, identical value both times
      VectorValues vectorValues2 = leaf.getVectorValues("field2");
      assertEquals(3, vectorValues2.dimension());
      assertEquals(2, vectorValues2.size());
      vectorValues2.nextDoc();
      assertEquals(2f, vectorValues2.vectorValue()[1], 0);
      vectorValues2.nextDoc();
      assertEquals(2f, vectorValues2.vectorValue()[1], 0);
      assertEquals(NO_MORE_DOCS, vectorValues2.nextDoc());
      // field3: present only on the third doc
      VectorValues vectorValues3 = leaf.getVectorValues("field3");
      assertEquals(3, vectorValues3.dimension());
      assertEquals(1, vectorValues3.size());
      vectorValues3.nextDoc();
      assertEquals(1f, vectorValues3.vectorValue()[0], 0);
      assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc());
    }
  }
}
/** /**
* Index random vectors, sometimes skipping documents, sometimes deleting a document, * Index random vectors, sometimes skipping documents, sometimes deleting a document,
* sometimes merging, sometimes sorting the index, * sometimes merging, sometimes sorting the index,

View File

@ -0,0 +1,494 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.apache.lucene.util.SuppressForbidden;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/** For testing indexing and search performance of a knn-graph
*
* java -cp .../lib/*.jar org.apache.lucene.util.hnsw.KnnGraphTester -ndoc 1000000 -search .../vectors.bin
*/
public class KnnGraphTester {
private final static String KNN_FIELD = "knn";
private final static String ID_FIELD = "id";
private final static VectorValues.SearchStrategy SEARCH_STRATEGY = VectorValues.SearchStrategy.DOT_PRODUCT_HNSW;
private int numDocs;
private int dim;
private int topK;
private int numIters;
private int fanout;
private Path indexPath;
private boolean quiet;
private boolean reindex;
private int reindexTimeMsec;
@SuppressForbidden(reason="uses Random()")
private KnnGraphTester() {
  // set defaults; all may be overridden by command-line flags parsed in run()
  numDocs = 1000;
  numIters = 1000;
  dim = 256;
  topK = 100;
  fanout = topK; // default fanout mirrors topK unless -fanout is given
  indexPath = Paths.get("knn_test_index");
}
/** Command-line entry point; see {@link #usage} and {@link #run} for accepted arguments. */
public static void main(String... args) throws Exception {
  new KnnGraphTester().run(args);
}
/**
 * Parses command-line flags, optionally (re)builds the index, then dispatches
 * to the selected operation. Exactly one operation flag is allowed per invocation.
 */
private void run(String... args) throws Exception {
  String operation = null, docVectorsPath = null, queryPath = null;
  for (int iarg = 0; iarg < args.length; iarg++) {
    String arg = args[iarg];
    switch(arg) {
      // operation flags: each consumes a following data-file path;
      // -search additionally consumes a query-file path
      case "-generate":
      case "-search":
      case "-check":
      case "-stats":
        if (operation != null) {
          throw new IllegalArgumentException("Specify only one operation, not both " + arg + " and " + operation);
        }
        if (iarg == args.length - 1) {
          throw new IllegalArgumentException("Operation " + arg + " requires a following pathname");
        }
        operation = arg;
        docVectorsPath = args[++iarg];
        if (operation.equals("-search")) {
          queryPath = args[++iarg];
        }
        break;
      case "-fanout":
        if (iarg == args.length - 1) {
          throw new IllegalArgumentException("-fanout requires a following number");
        }
        fanout = Integer.parseInt(args[++iarg]);
        break;
      case "-beamWidthIndex":
        if (iarg == args.length - 1) {
          throw new IllegalArgumentException("-beamWidthIndex requires a following number");
        }
        // mutates a global builder default, affecting all subsequent indexing in this JVM
        HnswGraphBuilder.DEFAULT_BEAM_WIDTH = Integer.parseInt(args[++iarg]);
        break;
      case "-maxConn":
        if (iarg == args.length - 1) {
          throw new IllegalArgumentException("-maxConn requires a following number");
        }
        HnswGraphBuilder.DEFAULT_MAX_CONN = Integer.parseInt(args[++iarg]);
        break;
      case "-dim":
        if (iarg == args.length - 1) {
          throw new IllegalArgumentException("-dim requires a following number");
        }
        dim = Integer.parseInt(args[++iarg]);
        break;
      case "-ndoc":
        if (iarg == args.length - 1) {
          throw new IllegalArgumentException("-ndoc requires a following number");
        }
        numDocs = Integer.parseInt(args[++iarg]);
        break;
      case "-niter":
        if (iarg == args.length - 1) {
          throw new IllegalArgumentException("-niter requires a following number");
        }
        numIters = Integer.parseInt(args[++iarg]);
        break;
      case "-reindex":
        reindex = true;
        break;
      case "-forceMerge":
        operation = "-forceMerge";
        break;
      case "-quiet":
        quiet = true;
        break;
      default:
        throw new IllegalArgumentException("unknown argument " + arg);
        //usage();
    }
  }
  if (operation == null) {
    usage();
  }
  if (reindex) {
    reindexTimeMsec = createIndex(Paths.get(docVectorsPath), indexPath);
  }
  // NOTE(review): "-generate" and "-check" are accepted above but have no case here,
  // so they silently do nothing beyond an optional reindex — confirm whether intended
  switch (operation) {
    case "-search":
      testSearch(indexPath, Paths.get(queryPath), getNN(Paths.get(docVectorsPath), Paths.get(queryPath)));
      break;
    case "-forceMerge":
      forceMerge();
      break;
    case "-stats":
      printFanoutHist(indexPath);
      break;
  }
}
/** Prints, per leaf, the document count and a fanout histogram of its KNN graph. */
@SuppressForbidden(reason="Prints stuff")
private void printFanoutHist(Path indexPath) throws IOException {
  try (Directory dir = FSDirectory.open(indexPath);
       DirectoryReader reader = DirectoryReader.open(dir)) {
    for (LeafReaderContext context : reader.leaves()) {
      LeafReader leafReader = context.reader();
      Lucene90VectorReader vectorReader =
          (Lucene90VectorReader) ((CodecReader) leafReader).getVectorReader();
      KnnGraphValues knnValues = vectorReader.getGraphValues(KNN_FIELD);
      System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
      printGraphFanout(knnValues, leafReader.maxDoc());
    }
  }
}
/** Opens the existing index at {@link #indexPath} in APPEND mode and merges it down to one segment. */
@SuppressForbidden(reason="Prints stuff")
private void forceMerge() throws IOException {
  System.out.println("Force merge index in " + indexPath);
  IndexWriterConfig iwc =
      new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND);
  iwc.setInfoStream(new PrintStreamInfoStream(System.out));
  try (IndexWriter iw = new IndexWriter(FSDirectory.open(indexPath), iwc)) {
    iw.forceMerge(1);
  }
}
/**
 * Walks every graph node, counting its neighbors, and prints min/mean/max fanout
 * plus a decile histogram. The mean counts only nodes with at least one neighbor.
 */
@SuppressForbidden(reason="Prints stuff")
private void printGraphFanout(KnnGraphValues knnValues, int numDocs) throws IOException {
  int min = Integer.MAX_VALUE, max = 0, total = 0;
  int count = 0;
  // leafHist[n] = number of nodes having exactly n neighbors
  int[] leafHist = new int[numDocs];
  for (int node = 0; node < numDocs; node++) {
    knnValues.seek(node);
    int n = 0;
    while (knnValues.nextNeighbor() != NO_MORE_DOCS) {
      ++n;
    }
    ++leafHist[n];
    max = Math.max(max, n);
    min = Math.min(min, n);
    if (n > 0) {
      ++count;
      total += n;
    }
  }
  // NOTE(review): if no node has any neighbors, count == 0 and the mean prints as NaN —
  // confirm that is acceptable for this diagnostic tool
  System.out.printf("Graph size=%d, Fanout min=%d, mean=%.2f, max=%d\n", count, min, total / (float) count, max);
  printHist(leafHist, max, count, 10);
}
/**
 * Prints an approximate cumulative fanout distribution: a percentile header row, the
 * count of zero-fanout nodes, then for each bucket boundary the fanout value at which
 * the running total of nodes crosses that percentile of {@code count}.
 *
 * @param hist     hist[n] = number of nodes with exactly n neighbors
 * @param max      the largest observed fanout (last meaningful index of hist)
 * @param count    number of nodes with fanout &gt; 0 (denominator of the percentiles)
 * @param nbuckets number of percentile buckets to print
 */
@SuppressForbidden(reason="Prints stuff")
private void printHist(int[] hist, int max, int count, int nbuckets) {
  System.out.print("%");
  // header row: 0, 10, 20, ... 100 (for nbuckets == 10)
  for (int i=0; i <= nbuckets; i ++) {
    System.out.printf("%4d", i * 100 / nbuckets);
  }
  System.out.printf("\n %4d", hist[0]);
  // emit the fanout value each time the cumulative node total crosses a bucket boundary
  int total = 0, ibucket = 1;
  for (int i = 1; i <= max && ibucket <= nbuckets; i++) {
    total += hist[i];
    while (total >= count * ibucket / nbuckets) {
      System.out.printf("%4d", i);
      ++ibucket;
    }
  }
  System.out.println();
}
/**
 * Runs {@link #numIters} KNN searches against the index, reading target vectors from
 * {@code queryPath} (raw little-endian float32, {@link #dim} floats per target), then
 * reports wall-clock/CPU time and recall versus the true nearest neighbors {@code nn}.
 */
@SuppressForbidden(reason="Prints stuff")
private void testSearch(Path indexPath, Path queryPath, int[][] nn) throws IOException {
  TopDocs[] results = new TopDocs[numIters];
  long elapsed, totalCpuTime, totalVisited = 0;
  try (FileChannel q = FileChannel.open(queryPath)) {
    FloatBuffer targets = q.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
        .order(ByteOrder.LITTLE_ENDIAN)
        .asFloatBuffer();
    float[] target = new float[dim];
    if (quiet == false) {
      System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
    }
    long start;
    ThreadMXBean bean = ManagementFactory.getThreadMXBean();
    long cpuTimeStartNs;
    try (Directory dir = FSDirectory.open(indexPath);
         DirectoryReader reader = DirectoryReader.open(dir)) {
      // warm up, bounded by numIters: both results[] and the mapped targets buffer hold
      // only numIters entries, so the previous fixed 1000 iterations overran them
      // (ArrayIndexOutOfBounds / BufferUnderflow) whenever numIters < 1000
      int warmUps = Math.min(numIters, 1000);
      for (int i = 0; i < warmUps; i++) {
        targets.get(target);
        results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
      }
      targets.position(0);
      start = System.nanoTime();
      cpuTimeStartNs = bean.getCurrentThreadCpuTime();
      for (int i = 0; i < numIters; i++) {
        targets.get(target);
        results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
      }
      totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
      elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
      // map result docids back to stored ids so they can be compared with the true neighbors
      for (int i = 0; i < numIters; i++) {
        totalVisited += results[i].totalHits.value;
        for (ScoreDoc doc : results[i].scoreDocs) {
          doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
        }
      }
    }
    if (quiet == false) {
      System.out.println("completed " + numIters + " searches in " + elapsed + " ms: " + ((1000 * numIters) / elapsed) + " QPS "
          + "CPU time=" + totalCpuTime + "ms");
    }
  }
  if (quiet == false) {
    System.out.println("checking results");
  }
  float recall = checkResults(results, nn);
  totalVisited /= numIters;
  if (quiet) {
    // one machine-readable summary line for scripted benchmark runs
    System.out.printf(Locale.ROOT, "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n", recall, totalCpuTime / (float) numIters,
        numDocs, fanout, HnswGraphBuilder.DEFAULT_MAX_CONN, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, totalVisited, reindexTimeMsec);
  }
}
/** Searches each leaf for the given field, rebases leaf docids to global docids, and merges. */
private static TopDocs doKnnSearch(IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException {
  TopDocs[] perLeaf = new TopDocs[reader.leaves().size()];
  for (LeafReaderContext ctx: reader.leaves()) {
    TopDocs leafHits = ctx.reader().getVectorValues(field).search(vector, k, fanout);
    for (ScoreDoc scoreDoc : leafHits.scoreDocs) {
      scoreDoc.doc += ctx.docBase;
    }
    perLeaf[ctx.ord] = leafHits;
  }
  return TopDocs.merge(k, perLeaf);
}
/** Returns the fraction of returned docs that appear among the true nearest neighbors. */
private float checkResults(TopDocs[] results, int[][] nn) {
  int totalMatches = 0;
  int totalResults = 0;
  for (int i = 0; i < results.length; i++) {
    totalResults += results[i].scoreDocs.length;
    totalMatches += compareNN(nn[i], results[i]);
  }
  if (quiet == false) {
    System.out.println("total matches = " + totalMatches + " out of " + totalResults);
    System.out.printf(Locale.ROOT, "Average overlap = %.2f%%\n", ((100.0 * totalMatches) / totalResults));
  }
  return totalMatches / (float) totalResults;
}
/**
 * Returns how many of the result docs appear among the same number of true nearest
 * neighbors (set overlap between the top-|results| of {@code expected} and the hits).
 */
private int compareNN(int[] expected, TopDocs results) {
  int matched = 0;
  // Bound by both lengths: the original indexed expected[i] for every hit, which threw
  // ArrayIndexOutOfBounds when a search returned more hits than cached true neighbors.
  int limit = Math.min(expected.length, results.scoreDocs.length);
  Set<Integer> expectedSet = new HashSet<>();
  for (int i = 0; i < limit; i++) {
    expectedSet.add(expected[i]);
  }
  for (ScoreDoc scoreDoc : results.scoreDocs) {
    if (expectedSet.contains(scoreDoc.doc)) {
      ++matched;
    }
  }
  return matched;
}
/**
 * Returns the true nearest neighbors, loading them from a cache file in the working
 * directory when present, otherwise computing them by brute force and caching the result.
 */
private int[][] getNN(Path docPath, Path queryPath) throws IOException {
  // the cache key encodes every parameter that affects the brute-force computation
  String nnFileName = "nn-" + numDocs + "-" + numIters + "-" + topK + "-" + dim + ".bin";
  Path nnPath = Paths.get(nnFileName);
  if (Files.exists(nnPath) == false) {
    int[][] nn = computeNN(docPath, queryPath);
    writeNN(nn, nnPath);
    return nn;
  }
  return readNN(nnPath);
}
/** Reads the cached true-NN file: {@link #numIters} rows of {@link #topK} little-endian ints. */
private int[][] readNN(Path nnPath) throws IOException {
  int[][] result = new int[numIters][];
  try (FileChannel in = FileChannel.open(nnPath)) {
    IntBuffer intBuffer = in.map(FileChannel.MapMode.READ_ONLY, 0, numIters * topK * Integer.BYTES)
        .order(ByteOrder.LITTLE_ENDIAN)
        .asIntBuffer();
    for (int i = 0; i < numIters; i++) {
      int[] row = new int[topK];
      intBuffer.get(row);
      result[i] = row;
    }
  }
  return result;
}
/** Writes the true-NN matrix to {@code nnPath} as rows of little-endian ints (see readNN). */
private void writeNN(int[][] nn, Path nnPath) throws IOException {
  if (quiet == false) {
    System.out.println("writing true nearest neighbors to " + nnPath);
  }
  // one reusable row buffer; every row has the same length (topK)
  ByteBuffer rowBuffer = ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
  try (OutputStream out = Files.newOutputStream(nnPath)) {
    for (int i = 0; i < numIters; i++) {
      rowBuffer.asIntBuffer().put(nn[i]);
      out.write(rowBuffer.array());
    }
  }
}
/**
 * Computes the exact top-K neighbors of the first {@link #numIters} query vectors by
 * brute force over all {@link #numDocs} document vectors. Both inputs are raw
 * little-endian float32 files with {@link #dim} floats per vector.
 */
private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
  int[][] result = new int[numIters][];
  if (quiet == false) {
    System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
  }
  try (FileChannel in = FileChannel.open(docPath);
       FileChannel qIn = FileChannel.open(queryPath)) {
    FloatBuffer queries = qIn.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
        .order(ByteOrder.LITTLE_ENDIAN)
        .asFloatBuffer();
    float[] vector = new float[dim];
    float[] query = new float[dim];
    for (int i = 0; i < numIters; i++) {
      queries.get(query);
      long totalBytes = (long) numDocs * dim * Float.BYTES;
      // map the docs file in blocks (FileChannel.map is limited to ~2GB); block size is the
      // largest whole multiple of the vector byte-size, so vectors never straddle a boundary
      int blockSize = (int) Math.min(totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES)), offset = 0;
      int j = 0;
      //System.out.println("totalBytes=" + totalBytes);
      while (j < numDocs) {
        FloatBuffer vectors = in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
            .order(ByteOrder.LITTLE_ENDIAN)
            .asFloatBuffer();
        offset += blockSize;
        // NOTE(review): the queue is recreated per block, so when the data spans multiple
        // blocks only the final block's candidates survive in result[i] — confirm intent
        Neighbors queue = Neighbors.create(topK, SEARCH_STRATEGY.reversed);
        for (; j < numDocs && vectors.hasRemaining(); j++) {
          vectors.get(vector);
          float d = SEARCH_STRATEGY.compare(query, vector);
          queue.insertWithOverflow(new Neighbor(j, d));
        }
        // pop worst-first, filling the row from the back so it ends up best-first
        result[i] = new int[topK];
        for (int k = topK - 1; k >= 0; k--) {
          Neighbor n = queue.pop();
          result[i][k] = n.node();
          //System.out.print(" " + n);
        }
        if (quiet == false && (i + 1) % 10 == 0) {
          // progress indicator every 10 queries
          System.out.print(" " + (i + 1));
          System.out.flush();
        }
      }
    }
  }
  return result;
}
/**
 * Builds a fresh KNN index at {@code indexPath} from raw little-endian float32 vectors
 * in {@code docsPath} ({@link #dim} floats per document). Each document gets a vector
 * field and a stored id. Returns elapsed wall-clock indexing time in milliseconds.
 */
private int createIndex(Path docsPath, Path indexPath) throws IOException {
  IndexWriterConfig iwc = new IndexWriterConfig()
      .setOpenMode(IndexWriterConfig.OpenMode.CREATE);
  // iwc.setMergePolicy(NoMergePolicy.INSTANCE);
  // large RAM buffer to minimize intermediate flushes during bulk indexing
  iwc.setRAMBufferSizeMB(1994d);
  if (quiet == false) {
    iwc.setInfoStream(new PrintStreamInfoStream(System.out));
    System.out.println("creating index in " + indexPath);
  }
  long start = System.nanoTime();
  long totalBytes = (long) numDocs * dim * Float.BYTES, offset = 0;
  try (FSDirectory dir = FSDirectory.open(indexPath);
       IndexWriter iw = new IndexWriter(dir, iwc)) {
    // map the input in blocks sized to a whole number of vectors (see computeNN)
    int blockSize = (int) Math.min(totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES));
    float[] vector = new float[dim];
    try (FileChannel in = FileChannel.open(docsPath)) {
      int i = 0;
      while (i < numDocs) {
        FloatBuffer vectors = in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
            .order(ByteOrder.LITTLE_ENDIAN)
            .asFloatBuffer();
        offset += blockSize;
        for (; vectors.hasRemaining() && i < numDocs ; i++) {
          vectors.get(vector);
          Document doc = new Document();
          //System.out.println("vector=" + vector[0] + "," + vector[1] + "...");
          doc.add(new VectorField(KNN_FIELD, vector, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW));
          doc.add(new StoredField(ID_FIELD, i));
          iw.addDocument(doc);
        }
      }
      if (quiet == false) {
        System.out.println("Done indexing " + numDocs + " documents; now flush");
      }
    }
  }
  long elapsed = System.nanoTime() - start;
  if (quiet == false) {
    System.out.println("Indexed " + numDocs + " documents in " + elapsed / 1_000_000_000 + "s");
  }
  return (int) (elapsed / 1_000_000);
}
/** Prints the command-line synopsis and exits with a non-zero status. */
private static void usage() {
  // Fixed the program name: this tool is KnnGraphTester (see main), not TestKnnGraph
  String error = "Usage: KnnGraphTester -generate|-search|-stats|-check {datafile} [-beamWidth N]";
  System.err.println(error);
  System.exit(1);
}
}

View File

@ -0,0 +1,459 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/** Tests HNSW KNN graphs */
public class TestHnsw extends LuceneTestCase {
// test writing out and reading in a graph gives the same graph
public void testReadWrite() throws IOException {
  int dim = random().nextInt(100) + 1;
  int nDoc = random().nextInt(100) + 1;
  RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
  RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
  // build the graph directly first, then rebuild it during indexing with the same
  // random seed so the on-disk graph can be compared node-by-node with this one
  long seed = random().nextLong();
  HnswGraphBuilder.randSeed = seed;
  HnswGraph hnsw = HnswGraphBuilder.build((RandomAccessVectorValuesProducer) vectors);
  // Recreate the graph while indexing with the same random seed and write it out
  HnswGraphBuilder.randSeed = seed;
  try (Directory dir = newDirectory()) {
    int nVec = 0, indexedDoc = 0;
    // Don't merge randomly, create a single segment because we rely on the docid ordering for this test
    IndexWriterConfig iwc = new IndexWriterConfig()
        .setCodec(Codec.forName("Lucene90"));
    try (IndexWriter iw = new IndexWriter(dir, iwc)) {
      while (v2.nextDoc() != NO_MORE_DOCS) {
        while (indexedDoc < v2.docID()) {
          // increment docId in the index by adding empty documents
          iw.addDocument(new Document());
          indexedDoc++;
        }
        Document doc = new Document();
        doc.add(new VectorField("field", v2.vectorValue(), v2.searchStrategy));
        doc.add(new StoredField("id", v2.docID()));
        iw.addDocument(doc);
        nVec++;
        indexedDoc++;
      }
    }
    try (IndexReader reader = DirectoryReader.open(dir)) {
      for (LeafReaderContext ctx : reader.leaves()) {
        // verify the indexed vectors and the indexed graph both round-tripped intact
        VectorValues values = ctx.reader().getVectorValues("field");
        assertEquals(vectors.searchStrategy, values.searchStrategy());
        assertEquals(dim, values.dimension());
        assertEquals(nVec, values.size());
        assertEquals(indexedDoc, ctx.reader().maxDoc());
        assertEquals(indexedDoc, ctx.reader().numDocs());
        assertVectorsEqual(v3, values);
        KnnGraphValues graphValues = ((Lucene90VectorReader) ((CodecReader) ctx.reader()).getVectorReader()).getGraphValues("field");
        assertGraphEqual(hnsw.getGraphValues(), graphValues, nVec);
      }
    }
  }
}
// Make sure we actually approximately find the closest k elements. Mostly this is about
// ensuring that we have all the distance functions, comparators, priority queues and so on
// oriented in the right directions
public void testAknn() throws IOException {
  int nDoc = 100;
  RandomAccessVectorValuesProducer vectors = new CircularVectorValues(nDoc);
  HnswGraph hnsw = HnswGraphBuilder.build(vectors);
  // run some searches
  Neighbors nn = HnswGraph.search(new float[]{1, 0}, 10, 5, vectors.randomAccess(), hnsw.getGraphValues(), random());
  int sum = 0;
  for (Neighbor n : nn) {
    sum += n.node();
  }
  // We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) = 45
  // (the bound allows some slack over the ideal 45 since the search is approximate)
  assertTrue("sum(result docs)=" + sum, sum < 75);
}
public void testMaxConnections() throws Exception {
  // verify that maxConnections is observed, and that the retained arcs point to the best-scoring neighbors
  // dot-product graph with maxConn=1: a higher score wins, per the assertions below
  HnswGraph graph = new HnswGraph(1, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW);
  graph.connectNodes(0, 1, 1);
  assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
  assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
  // node 2 (score 2) displaces node 1 (score 1) as node 0's sole neighbor
  graph.connectNodes(0, 2, 2);
  assertArrayEquals(new int[]{2}, graph.getNeighbors(0));
  assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
  assertArrayEquals(new int[]{0}, graph.getNeighbors(2));
  // node 2 keeps its score-2 arc to node 0 over a score-1 arc to node 3
  graph.connectNodes(2, 3, 1);
  assertArrayEquals(new int[]{2}, graph.getNeighbors(0));
  assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
  assertArrayEquals(new int[]{0}, graph.getNeighbors(2));
  assertArrayEquals(new int[]{2}, graph.getNeighbors(3));
  // euclidean graph with maxConn=1: a lower score wins, per the assertions below
  graph = new HnswGraph(1, VectorValues.SearchStrategy.EUCLIDEAN_HNSW);
  graph.connectNodes(0, 1, 1);
  assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
  assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
  // node 1 (score 1) is retained over node 2 (score 2)
  graph.connectNodes(0, 2, 2);
  assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
  assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
  assertArrayEquals(new int[]{0}, graph.getNeighbors(2));
  // node 2's score-1 arc to node 3 displaces its score-2 arc to node 0
  graph.connectNodes(2, 3, 1);
  assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
  assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
  assertArrayEquals(new int[]{3}, graph.getNeighbors(2));
  assertArrayEquals(new int[]{2}, graph.getNeighbors(3));
}
/** Returns vectors evenly distributed around the unit circle.
 * NOTE(review): the angle is PI * ord / size, so the ordinals actually span a
 * half-circle — confirm whether full-circle coverage was intended.
 */
class CircularVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
  private final int size;
  // shared scratch array returned by vectorValue(); callers must copy if they retain it
  private final float[] value;

  // current iterator position; -1 before the first nextDoc()/advance()
  int doc = -1;

  CircularVectorValues(int size) {
    this.size = size;
    value = new float[2];
  }

  // returns a fresh, unpositioned iterator over the same values
  public CircularVectorValues copy() {
    return new CircularVectorValues(size);
  }

  @Override
  public SearchStrategy searchStrategy() {
    return SearchStrategy.DOT_PRODUCT_HNSW;
  }

  @Override
  public int dimension() {
    return 2;
  }

  @Override
  public int size() {
    return size;
  }

  @Override
  public float[] vectorValue() {
    // delegate to the random-access accessor at the current position
    return vectorValue(doc);
  }

  @Override
  public RandomAccessVectorValues randomAccess() {
    return new CircularVectorValues(size);
  }

  @Override
  public int docID() {
    return doc;
  }

  @Override
  public int nextDoc() {
    return advance(doc + 1);
  }

  @Override
  public int advance(int target) {
    // every ordinal in [0, size) is a valid doc; past the end we are exhausted
    if (target >= 0 && target < size) {
      doc = target;
    } else {
      doc = NO_MORE_DOCS;
    }
    return doc;
  }

  @Override
  public long cost() {
    return size;
  }

  @Override
  public float[] vectorValue(int ord) {
    // point on the circle at angle PI * ord / size; overwrites the shared scratch array
    value[0] = (float) Math.cos(Math.PI * ord / (double) size);
    value[1] = (float) Math.sin(Math.PI * ord / (double) size);
    return value;
  }

  @Override
  public BytesRef binaryValue(int ord) {
    return null;
  }

  @Override
  public TopDocs search(float[] target, int k, int fanout) {
    // unused in these tests
    return null;
  }
}
/** Asserts graphs {@code g} and {@code h} have identical neighbor sets for every node in [0, size). */
private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException {
  for (int node = 0; node < size; node++) {
    g.seek(node);
    h.seek(node);
    Set<Integer> expectedArcs = getNeighbors(g);
    Set<Integer> actualArcs = getNeighbors(h);
    assertEquals("arcs differ for node " + node, expectedArcs, actualArcs);
  }
}
/** Drains the current node's neighbor iterator into a set. */
private Set<Integer> getNeighbors(KnnGraphValues g) throws IOException {
  Set<Integer> neighbors = new HashSet<>();
  int n;
  while ((n = g.nextNeighbor()) != NO_MORE_DOCS) {
    neighbors.add(n);
  }
  return neighbors;
}
/** Asserts {@code u} and {@code v} iterate the same docs with approximately equal vector values. */
private void assertVectorsEqual(VectorValues u, VectorValues v) throws IOException {
  while (true) {
    int uDoc = u.nextDoc();
    int vDoc = v.nextDoc();
    // both iterators must advance in lockstep
    assertEquals(uDoc, vDoc);
    if (uDoc == NO_MORE_DOCS) {
      return;
    }
    assertArrayEquals("vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), 1e-4f);
  }
}
public void testNeighbors() {
  // make sure we have the sign correct
  // reversed=false: the queue retains the highest scores; inserting c (30) evicts a (10),
  // and pop() returns retained neighbors worst-score-first (20 before 30)
  Neighbors nn = Neighbors.create(2, false);
  Neighbor a = new Neighbor(1, 10);
  Neighbor b = new Neighbor(2, 20);
  Neighbor c = new Neighbor(3, 30);
  assertNull(nn.insertWithOverflow(b));
  assertNull(nn.insertWithOverflow(a));
  assertSame(a, nn.insertWithOverflow(c));
  assertEquals(20, (int) nn.top().score());
  assertEquals(20, (int) nn.pop().score());
  assertEquals(30, (int) nn.top().score());
  assertEquals(30, (int) nn.pop().score());
  // reversed=true: the queue retains the lowest scores; inserting c (30) overflows
  // immediately, and pop() returns retained neighbors worst-score-first (20 before 10)
  Neighbors fn = Neighbors.create(2, true);
  assertNull(fn.insertWithOverflow(b));
  assertNull(fn.insertWithOverflow(a));
  assertSame(c, fn.insertWithOverflow(c));
  assertEquals(20, (int) fn.top().score());
  assertEquals(20, (int) fn.pop().score());
  assertEquals(10, (int) fn.top().score());
  assertEquals(10, (int) fn.pop().score());
}
@SuppressWarnings("SelfComparison")
public void testNeighbor() {
  // Neighbor(node, score); ordering is primarily by score
  Neighbor a = new Neighbor(1, 10);
  Neighbor b = new Neighbor(2, 20);
  Neighbor c = new Neighbor(3, 20);
  assertEquals(0, a.compareTo(a));
  assertEquals(-1, a.compareTo(b));
  assertEquals(1, b.compareTo(a));
  // b and c tie on score (20) yet still compare unequal: ties are broken
  // deterministically — presumably by node id; see Neighbor.compareTo
  assertEquals(1, b.compareTo(c));
  assertEquals(-1, c.compareTo(b));
}
/** Returns a {@code dim}-component vector with uniform random components in [0, 1). */
private static float[] randomVector(Random random, int dim) {
  float[] vec = new float[dim];
  for (int component = 0; component < vec.length; component++) {
    vec[component] = random.nextFloat();
  }
  return vec;
}
/**
 * Produces random vectors and caches them for random-access.
 */
class RandomVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {

  private final int dimension;
  // vectors packed contiguously by ordinal, for random access
  private final float[][] denseValues;
  // vectors indexed by docid; null entries are docs without a vector
  private final float[][] values;
  // shared buffer sometimes returned by vectorValue(), mimicking codec reuse
  private final float[] scratch;
  private final SearchStrategy searchStrategy;

  final int numVectors;
  final int maxDoc;

  // current docid; -1 before iteration starts
  private int pos = -1;

  RandomVectorValues(int size, int dimension, Random random) {
    this.dimension = dimension;
    values = new float[size][];
    denseValues = new float[size][];
    scratch = new float[dimension];
    // populate docids sparsely, skipping 0-2 docs between consecutive vectors
    int sz = 0;
    int md = -1;
    for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
      values[offset] = randomVector(random, dimension);
      denseValues[sz++] = values[offset];
      md = offset;
    }
    numVectors = sz;
    maxDoc = md;
    // get a random SearchStrategy other than NONE (0)
    searchStrategy = SearchStrategy.values()[random.nextInt(SearchStrategy.values().length - 1) + 1];
  }

  private RandomVectorValues(int dimension, SearchStrategy searchStrategy, float[][] denseValues, float[][] values, int size) {
    this.dimension = dimension;
    this.searchStrategy = searchStrategy;
    this.values = values;
    this.denseValues = denseValues;
    scratch = new float[dimension];
    numVectors = size;
    maxDoc = values.length - 1;
  }

  /** Returns an instance sharing this one's vectors but with independent iteration state. */
  public RandomVectorValues copy() {
    return new RandomVectorValues(dimension, searchStrategy, denseValues, values, numVectors);
  }

  @Override
  public int size() {
    return numVectors;
  }

  @Override
  public SearchStrategy searchStrategy() {
    return searchStrategy;
  }

  @Override
  public int dimension() {
    return dimension;
  }

  @Override
  public float[] vectorValue() {
    if (random().nextBoolean()) {
      return values[pos];
    } else {
      // Sometimes use the same scratch array repeatedly, mimicing what the codec will do.
      // This should help us catch cases of aliasing where the same VectorValues source is used twice in a
      // single computation.
      System.arraycopy(values[pos], 0, scratch, 0, dimension);
      return scratch;
    }
  }

  @Override
  public RandomAccessVectorValues randomAccess() {
    return copy();
  }

  @Override
  public float[] vectorValue(int targetOrd) {
    return denseValues[targetOrd];
  }

  @Override
  public BytesRef binaryValue(int targetOrd) {
    // binary access is not exercised by these tests
    return null;
  }

  @Override
  public TopDocs search(float[] target, int k, int fanout) {
    // search is not exercised by these tests
    return null;
  }

  /** Positions the cursor on {@code target} if it is an in-range docid that has a vector. */
  private boolean seek(int target) {
    if (target >= 0 && target < values.length && values[target] != null) {
      pos = target;
      return true;
    } else {
      return false;
    }
  }

  @Override
  public int docID() {
    return pos;
  }

  @Override
  public int nextDoc() {
    return advance(pos + 1);
  }

  @Override
  public int advance(int target) {
    // Jump directly to target before scanning. The previous implementation ignored
    // target and scanned from the current position, violating the
    // DocIdSetIterator.advance contract (advance to the first doc >= target).
    // nextDoc() is unaffected since it always passes target == pos + 1.
    if (target > pos) {
      pos = target - 1;
    }
    while (++pos < values.length) {
      if (seek(pos)) {
        return pos;
      }
    }
    return NO_MORE_DOCS;
  }

  @Override
  public long cost() {
    return size();
  }
}
/** BoundsChecker.create(false) tracks a maximum: check(x) reports out-of-bounds only when x falls below the bound. */
public void testBoundsCheckerMax() {
  BoundsChecker max = BoundsChecker.create(false);
  float f = random().nextFloat() - 0.5f;
  // any float > -MAX_VALUE is in bounds
  assertFalse(max.check(f));
  // f is now the bound (minus some delta)
  max.update(f);
  assertFalse(max.check(f)); // f is not out of bounds
  assertFalse(max.check(f + 1)); // anything greater than f is in bounds
  assertTrue(max.check(f - 1e-5f)); // delta is zero initially
}
/** BoundsChecker.create(true) tracks a minimum: check(x) reports out-of-bounds only when x rises above the bound. */
public void testBoundsCheckerMin() {
  BoundsChecker min = BoundsChecker.create(true);
  float f = random().nextFloat() - 0.5f;
  // any float < MAX_VALUE is in bounds
  assertFalse(min.check(f));
  // f is now the bound (minus some delta)
  min.update(f);
  assertFalse(min.check(f)); // f is not out of bounds
  assertFalse(min.check(f - 1)); // anything less than f is in bounds
  assertTrue(min.check(f + 1e-5f)); // delta is zero initially
}
}

View File

@ -99,7 +99,7 @@ public final class IntervalQuery extends Query {
private IntervalQuery(String field, IntervalsSource intervalsSource, IntervalScoreFunction scoreFunction) { private IntervalQuery(String field, IntervalsSource intervalsSource, IntervalScoreFunction scoreFunction) {
Objects.requireNonNull(field, "null field aren't accepted"); Objects.requireNonNull(field, "null field aren't accepted");
Objects.requireNonNull(intervalsSource, "null intervalsSource aren't accepted"); Objects.requireNonNull(intervalsSource, "null intervalsSource aren't accepted");
Objects.requireNonNull(scoreFunction, "null scoreFunction aren't accepted"); Objects.requireNonNull(scoreFunction, "null searchStrategy aren't accepted");
this.field = field; this.field = field;
this.intervalsSource = intervalsSource; this.intervalsSource = intervalsSource;
this.scoreFunction = scoreFunction; this.scoreFunction = scoreFunction;