LUCENE-10577: enable quantization of HNSW vectors to 8 bits (#1054)
* LUCENE-10577: enable supplying, storing, and comparing HNSW vectors with 8 bit precision
This commit is contained in: commit a693fe819b (parent 59a0917e25)
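For orientation before the diff: the commit lets vector similarity be computed directly on byte-encoded vectors instead of always expanding them to float[]. A minimal sketch of the idea, with hypothetical values that the caller has already scaled into the signed-byte range, using the BytesRef overload that this commit's exhaustiveSearch helper relies on (the exact overload set is an assumption beyond what the diff shows):

    import org.apache.lucene.index.VectorSimilarityFunction;
    import org.apache.lucene.util.BytesRef;

    public class ByteVectorSimilaritySketch {
      public static void main(String[] args) {
        // Two hypothetical 4-dimensional vectors quantized to 8 bits per dimension.
        byte[] a = {12, -3, 77, 0};
        byte[] b = {11, -2, 80, 1};
        // Compare without widening to float[]: one byte per dimension instead of
        // four, which is the storage win this commit targets.
        float score =
            VectorSimilarityFunction.EUCLIDEAN.compare(new BytesRef(a), new BytesRef(b));
        System.out.println(score);
      }
    }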
@@ -30,6 +30,7 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataOutput;

@@ -214,6 +215,7 @@ public final class Lucene60FieldInfosFormat extends FieldInfosFormat {
pointIndexDimensionCount,
pointNumBytes,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
isSoftDeletesField);
} catch (IllegalStateException e) {
@@ -32,7 +32,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.TermVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene90;
package org.apache.lucene.backward_codecs.lucene90;

import java.io.IOException;
import java.util.Collections;

@@ -29,6 +29,7 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataOutput;

@@ -191,6 +192,7 @@ public final class Lucene90FieldInfosFormat extends FieldInfosFormat {
pointIndexDimensionCount,
pointNumBytes,
vectorDimension,
VectorEncoding.FLOAT32,
vectorDistFunc,
isSoftDeletesField);
infos[i].checkConsistency();
@@ -18,6 +18,7 @@
package org.apache.lucene.backward_codecs.lucene90;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;

import java.io.IOException;
import java.nio.ByteBuffer;

@@ -36,6 +37,7 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

@@ -277,6 +279,21 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}

@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}

VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);

return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
}

private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
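Each per-format reader in this diff gains the same searchExhaustively override. Per the new javadoc later in the commit, the exact path is meant for cases where few documents can match; a hypothetical caller-side dispatch between the two paths (the threshold is invented for illustration; the commit only adds the exhaustive path itself):

    // Sketch: pick the cheaper path based on how many docs survive the filter.
    TopDocs hits =
        acceptDocs.cost() < 1_000
            ? reader.searchExhaustively(field, target, k, acceptDocs) // exact scan
            : reader.search(field, target, k, bits, visitedLimit);    // approximate HNSW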
@@ -17,6 +17,7 @@
package org.apache.lucene.backward_codecs.lucene91;

import java.util.Objects;
import org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;

@@ -32,7 +33,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.TermVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
@@ -25,6 +25,7 @@ import java.util.Objects;
import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;

@@ -55,7 +56,7 @@ public final class Lucene91HnswGraphBuilder {
private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random;
private final Lucene91BoundsChecker bound;
private final HnswGraphSearcher graphSearcher;
private final HnswGraphSearcher<float[]> graphSearcher;

final Lucene91OnHeapHnswGraph hnsw;

@@ -101,7 +102,8 @@ public final class Lucene91HnswGraphBuilder {
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode);
this.graphSearcher =
new HnswGraphSearcher(
new HnswGraphSearcher<>(
VectorEncoding.FLOAT32,
similarityFunction,
new NeighborQueue(beamWidth, true),
new FixedBitSet(vectorValues.size()));
@@ -18,6 +18,7 @@
package org.apache.lucene.backward_codecs.lucene91;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;

import java.io.IOException;
import java.nio.ByteBuffer;

@@ -34,8 +35,10 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

@@ -244,6 +247,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
target,
k,
vectorValues,
VectorEncoding.FLOAT32,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry),

@@ -265,6 +269,21 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}

@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}

VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);

return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
}

private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
@@ -144,8 +144,8 @@
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
* contains metadata about a segment, such as the number of documents, what files it uses, and
* information about how the segment is sorted
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
* contains metadata about the set of named fields used in the index.
* <li>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Field names}.
* This contains metadata about the set of named fields used in the index.
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
* This contains, for each document, a list of attribute-value pairs, where the attributes are
* field names. These are used to store auxiliary information about the document, such as its

@@ -240,7 +240,7 @@
* systems that frequently run out of file handles.</td>
* </tr>
* <tr>
* <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
* <td>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
* <td>.fnm</td>
* <td>Stores information about the fields</td>
* </tr>
@@ -17,6 +17,7 @@
package org.apache.lucene.backward_codecs.lucene92;

import java.util.Objects;
import org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;

@@ -32,7 +33,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.TermVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
@@ -18,6 +18,7 @@
package org.apache.lucene.backward_codecs.lucene92;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;

import java.io.IOException;
import java.util.Arrays;

@@ -30,8 +31,10 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

@@ -237,6 +240,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
target,
k,
vectorValues,
VectorEncoding.FLOAT32,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs),

@@ -258,6 +262,21 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}

@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}

VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);

return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
}

/** Get knn graph values; used for testing */
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
@@ -144,8 +144,8 @@
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
* contains metadata about a segment, such as the number of documents, what files it uses, and
* information about how the segment is sorted
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
* contains metadata about the set of named fields used in the index.
* <li>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Field names}.
* This contains metadata about the set of named fields used in the index.
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
* This contains, for each document, a list of attribute-value pairs, where the attributes are
* field names. These are used to store auxiliary information about the document, such as its

@@ -240,7 +240,7 @@
* systems that frequently run out of file handles.</td>
* </tr>
* <tr>
* <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
* <td>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
* <td>.fnm</td>
* <td>Stores information about the fields</td>
* </tr>
@@ -31,6 +31,7 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;

@@ -148,7 +149,8 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
OnHeapHnswGraph graph =
offHeapVectors.size() == 0
? null
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
: writeGraph(
offHeapVectors, VectorEncoding.FLOAT32, fieldInfo.getVectorSimilarityFunction());
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldInfo,

@@ -266,13 +268,20 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
}

private OnHeapHnswGraph writeGraph(
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
RandomAccessVectorValuesProducer vectorValues,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {

// build graph
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(
vectorValues, similarityFunction, M, beamWidth, HnswGraphBuilder.randSeed);
HnswGraphBuilder<?> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
vectorEncoding,
similarityFunction,
M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
@@ -28,6 +28,7 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.Directory;

@@ -68,7 +69,8 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
static final BytesRef INDEX_DIM_COUNT = new BytesRef(" index dimensional count ");
static final BytesRef DIM_NUM_BYTES = new BytesRef(" dimensional num bytes ");
static final BytesRef VECTOR_NUM_DIMS = new BytesRef(" vector number of dimensions ");
static final BytesRef VECTOR_SEARCH_STRATEGY = new BytesRef(" vector search strategy ");
static final BytesRef VECTOR_ENCODING = new BytesRef(" vector encoding ");
static final BytesRef VECTOR_SIMILARITY = new BytesRef(" vector similarity ");
static final BytesRef SOFT_DELETES = new BytesRef(" soft-deletes ");

@Override

@@ -156,8 +158,13 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
int vectorNumDimensions = Integer.parseInt(readString(VECTOR_NUM_DIMS.length, scratch));

SimpleTextUtil.readLine(input, scratch);
assert StringHelper.startsWith(scratch.get(), VECTOR_SEARCH_STRATEGY);
String scoreFunction = readString(VECTOR_SEARCH_STRATEGY.length, scratch);
assert StringHelper.startsWith(scratch.get(), VECTOR_ENCODING);
String encoding = readString(VECTOR_ENCODING.length, scratch);
VectorEncoding vectorEncoding = vectorEncoding(encoding);

SimpleTextUtil.readLine(input, scratch);
assert StringHelper.startsWith(scratch.get(), VECTOR_SIMILARITY);
String scoreFunction = readString(VECTOR_SIMILARITY.length, scratch);
VectorSimilarityFunction vectorDistFunc = distanceFunction(scoreFunction);

SimpleTextUtil.readLine(input, scratch);

@@ -179,6 +186,7 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
indexDimensionalCount,
dimensionalNumBytes,
vectorNumDimensions,
vectorEncoding,
vectorDistFunc,
isSoftDeletesField);
}

@@ -201,6 +209,10 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
return DocValuesType.valueOf(dvType);
}

public VectorEncoding vectorEncoding(String vectorEncoding) {
return VectorEncoding.valueOf(vectorEncoding);
}

public VectorSimilarityFunction distanceFunction(String scoreFunction) {
return VectorSimilarityFunction.valueOf(scoreFunction);
}

@@ -297,7 +309,11 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
SimpleTextUtil.write(out, Integer.toString(fi.getVectorDimension()), scratch);
SimpleTextUtil.writeNewline(out);

SimpleTextUtil.write(out, VECTOR_SEARCH_STRATEGY);
SimpleTextUtil.write(out, VECTOR_ENCODING);
SimpleTextUtil.write(out, fi.getVectorEncoding().name(), scratch);
SimpleTextUtil.writeNewline(out);

SimpleTextUtil.write(out, VECTOR_SIMILARITY);
SimpleTextUtil.write(out, fi.getVectorSimilarityFunction().name(), scratch);
SimpleTextUtil.writeNewline(out);
@@ -42,6 +42,7 @@ import org.apache.lucene.store.BufferedChecksumIndexInput;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;

@@ -181,6 +182,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
}

@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
int numDocs = (int) acceptDocs.cost();
return search(field, target, k, BitSet.of(acceptDocs, numDocs), Integer.MAX_VALUE);
}

@Override
public void checkIntegrity() throws IOException {
IndexInput clone = dataIn.clone();
@@ -23,6 +23,7 @@ import org.apache.lucene.codecs.lucene90.tests.MockTermStateFactory;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDataOutput;
import org.apache.lucene.store.ByteBuffersIndexOutput;

@@ -116,6 +117,7 @@ public class TestBlockWriter extends LuceneTestCase {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
true);
}
@@ -41,6 +41,7 @@ import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.DataInput;

@@ -203,6 +204,7 @@ public class TestSTBlockReader extends LuceneTestCase {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
}
@@ -20,8 +20,12 @@ package org.apache.lucene.codecs;
import java.io.IOException;
import org.apache.lucene.util.Accountable;

/** Vectors' writer for a field */
public abstract class KnnFieldVectorsWriter implements Accountable {
/**
* Vectors' writer for a field
*
* @param <T> an array type; the type of vectors to be written
*/
public abstract class KnnFieldVectorsWriter<T> implements Accountable {

/** Sole constructor */
protected KnnFieldVectorsWriter() {}

@@ -30,5 +34,13 @@ public abstract class KnnFieldVectorsWriter implements Accountable {
* Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
* increasing order.
*/
public abstract void addValue(int docID, float[] vectorValue) throws IOException;
public abstract void addValue(int docID, Object vectorValue) throws IOException;

/**
* Used to copy values being indexed to internal storage.
*
* @param vectorValue an array containing the vector value to add
* @return a copy of the value; a new array
*/
public abstract T copyValue(T vectorValue);
}
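KnnFieldVectorsWriter becomes generic in the array type it buffers, and copyValue is the hook that makes buffering safe: indexing code may reuse the array it passes to addValue, so a writer must snapshot it first. A minimal, hypothetical subclass illustrating the contract (names and the clone-based copy are this sketch's choices, not part of the commit):

    import java.util.ArrayList;
    import java.util.List;
    import org.apache.lucene.codecs.KnnFieldVectorsWriter;

    class BufferingFloatWriterSketch extends KnnFieldVectorsWriter<float[]> {
      private final List<float[]> buffered = new ArrayList<>();

      @Override
      public void addValue(int docID, Object vectorValue) {
        // Snapshot before buffering: the caller may overwrite its array.
        buffered.add(copyValue((float[]) vectorValue));
      }

      @Override
      public float[] copyValue(float[] vectorValue) {
        return vectorValue.clone(); // "a copy of the value; a new array"
      }

      @Override
      public long ramBytesUsed() {
        return 0; // accounting omitted in this sketch
      }
    }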
@@ -18,9 +18,11 @@
package org.apache.lucene.codecs;

import java.io.IOException;
import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.util.Bits;

@@ -76,6 +78,15 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
public abstract KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException;

/**
* Returns the current KnnVectorsFormat version number. Indexes written using the format will be
* "stamped" with this version.
*/
public int currentVersion() {
// return the version supported by older codecs that did not override this method
return Lucene94HnswVectorsFormat.VERSION_START;
}

/**
* EMPTY throws an exception when written. It acts as a sentinel indicating a Codec that does not
* support vectors.

@@ -104,6 +115,12 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
return TopDocsCollector.EMPTY_TOPDOCS;
}

@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return TopDocsCollector.EMPTY_TOPDOCS;
}

@Override
public void close() {}
@@ -20,12 +20,16 @@ package org.apache.lucene.codecs;
import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;

/** Reads vectors from an index. */
public abstract class KnnVectorsReader implements Closeable, Accountable {

@@ -75,11 +79,39 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
* if they are all allowed to match.
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
*/
public abstract TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;

/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document
* is derived from the vector similarity in a way that ensures scores are positive and that a
* larger score corresponds to a higher ranking.
*
* <p>The search is exact, guaranteeing the true k closest neighbors will be returned. Typically
* this requires an exhaustive scan of the entire index. It is intended to be used when the number
* of potential matches is limited.
*
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
* order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
* contains the number of documents visited during the search. If the search stopped early because
* it hit {@code visitedLimit}, it is indicated through the relation {@code
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
*
* <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
* FieldInfo}. The return value is never {@code null}.
*
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match.
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
*/
public abstract TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;

/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread
* that called {@link #getMergeInstance()}.

@@ -89,4 +121,67 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
public KnnVectorsReader getMergeInstance() {
return this;
}

/** {@link #searchExhaustively} */
protected static TopDocs exhaustiveSearch(
VectorValues vectorValues,
DocIdSetIterator acceptDocs,
VectorSimilarityFunction similarityFunction,
float[] target,
int k)
throws IOException {
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float score = similarityFunction.compare(vectorValues.vectorValue(), target);
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
return topDocsFromHitQueue(queue, acceptDocs.cost());
}

/** {@link #searchExhaustively} */
protected static TopDocs exhaustiveSearch(
VectorValues vectorValues,
DocIdSetIterator acceptDocs,
VectorSimilarityFunction similarityFunction,
BytesRef target,
int k)
throws IOException {
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float score = similarityFunction.compare(vectorValues.binaryValue(), target);
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
return topDocsFromHitQueue(queue, acceptDocs.cost());
}

private static TopDocs topDocsFromHitQueue(HitQueue queue, long numHits) {
// Remove any remaining sentinel values
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}

ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}

TotalHits totalHits = new TotalHits(numHits, TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
}
}
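The two exhaustiveSearch helpers above keep the k best hits in a HitQueue pre-filled with k sentinel entries (the `true` constructor flag), so a single `score >= topDoc.score` test both fills the heap and maintains it; surviving sentinels are stripped afterwards in topDocsFromHitQueue. The same bounded top-k idea in self-contained form, using a hypothetical score array and plain java.util (a sketch, not the Lucene implementation):

    import java.util.PriorityQueue;

    class TopKSketch {
      static int[] topK(float[] scores, int k) {
        // Min-heap ordered by score: the root is always the current k-th best.
        PriorityQueue<int[]> heap =
            new PriorityQueue<>(k, (x, y) -> Float.compare(scores[x[0]], scores[y[0]]));
        for (int doc = 0; doc < scores.length; doc++) {
          if (heap.size() < k) {
            heap.add(new int[] {doc});
          } else if (scores[doc] > scores[heap.peek()[0]]) {
            heap.poll(); // evict the current minimum, like updateTop above
            heap.add(new int[] {doc});
          }
        }
        int[] docs = new int[heap.size()];
        for (int i = docs.length - 1; i >= 0; i--) {
          docs[i] = heap.poll()[0]; // pop ascending, fill best-first
        }
        return docs;
      }
    }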
@@ -37,14 +37,15 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
protected KnnVectorsWriter() {}

/** Add new field for indexing */
public abstract KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException;
public abstract KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException;

/** Flush all buffered data on disk * */
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;

/** Write field for merging */
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
KnnFieldVectorsWriter writer = addField(fieldInfo);
@SuppressWarnings("unchecked")
public <T> void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
KnnFieldVectorsWriter<T> writer = (KnnFieldVectorsWriter<T>) addField(fieldInfo);
VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
for (int doc = mergedValues.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.codecs.lucene94;

import java.io.IOException;
import org.apache.lucene.index.FilterVectorValues;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;

/** reads from byte-encoded data */
public class ExpandingVectorValues extends FilterVectorValues {

private final float[] value;

/** @param in the wrapped values */
protected ExpandingVectorValues(VectorValues in) {
super(in);
value = new float[in.dimension()];
}

@Override
public float[] vectorValue() throws IOException {
BytesRef binaryValue = binaryValue();
byte[] bytes = binaryValue.bytes;
for (int i = 0, j = binaryValue.offset; i < value.length; i++, j++) {
value[i] = bytes[j];
}
return value;
}
}
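ExpandingVectorValues widens each stored byte straight to a float with no rescaling, so the caller's quantization must already fit values into the signed-byte range. A tiny round trip illustrating this (the scale factor is hypothetical; the commit leaves scaling to the application):

    // Hypothetical quantization on write, and the widening read-back above.
    float original = 0.6f;
    byte stored = (byte) (original * 127); // caller-chosen scale into [-128, 127]
    float expanded = stored;               // widening cast: 76 -> 76.0f, not 0.6f
    // Similarity is thus computed in the scaled byte domain; rankings are
    // preserved as long as all vectors of a field share the same scale.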
@@ -32,7 +32,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.TermVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;

@@ -69,7 +68,7 @@ public class Lucene94Codec extends Codec {
}

private final TermVectorsFormat vectorsFormat = new Lucene90TermVectorsFormat();
private final FieldInfosFormat fieldInfosFormat = new Lucene90FieldInfosFormat();
private final FieldInfosFormat fieldInfosFormat = new Lucene94FieldInfosFormat();
private final SegmentInfoFormat segmentInfosFormat = new Lucene90SegmentInfoFormat();
private final LiveDocsFormat liveDocsFormat = new Lucene90LiveDocsFormat();
private final CompoundFormat compoundFormat = new Lucene90CompoundFormat();

@@ -100,6 +99,11 @@ public class Lucene94Codec extends Codec {
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return Lucene94Codec.this.getKnnVectorsFormatForField(field);
}

@Override
public int currentVersion() {
return Lucene94HnswVectorsFormat.VERSION_CURRENT;
}
};

private final StoredFieldsFormat storedFieldsFormat;
@@ -0,0 +1,385 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene94;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FieldInfosFormat;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;

/**
* Lucene 9.0 Field Infos format.
*
* <p>Field names are stored in the field info file, with suffix <code>.fnm</code>.
*
* <p>FieldInfos (.fnm) --> Header,FieldsCount, <FieldName,FieldNumber,
* FieldBits,DocValuesBits,DocValuesGen,Attributes,DimensionCount,DimensionNumBytes>
* <sup>FieldsCount</sup>,Footer
*
* <p>Data types:
*
* <ul>
* <li>Header --> {@link CodecUtil#checkIndexHeader IndexHeader}
* <li>FieldsCount --> {@link DataOutput#writeVInt VInt}
* <li>FieldName --> {@link DataOutput#writeString String}
* <li>FieldBits, IndexOptions, DocValuesBits --> {@link DataOutput#writeByte Byte}
* <li>FieldNumber, DimensionCount, DimensionNumBytes --> {@link DataOutput#writeInt VInt}
* <li>Attributes --> {@link DataOutput#writeMapOfStrings Map<String,String>}
* <li>DocValuesGen --> {@link DataOutput#writeLong(long) Int64}
* <li>Footer --> {@link CodecUtil#writeFooter CodecFooter}
* </ul>
*
* Field Descriptions:
*
* <ul>
* <li>FieldsCount: the number of fields in this file.
* <li>FieldName: name of the field as a UTF-8 String.
* <li>FieldNumber: the field's number. Note that unlike previous versions of Lucene, the fields
* are not numbered implicitly by their order in the file, instead explicitly.
* <li>FieldBits: a byte containing field options.
* <ul>
* <li>The low order bit (0x1) is one for fields that have term vectors stored, and zero for
* fields without term vectors.
* <li>If the second lowest order-bit is set (0x2), norms are omitted for the indexed field.
* <li>If the third lowest-order bit is set (0x4), payloads are stored for the indexed
* field.
* </ul>
* <li>IndexOptions: a byte containing index options.
* <ul>
* <li>0: not indexed
* <li>1: indexed as DOCS_ONLY
* <li>2: indexed as DOCS_AND_FREQS
* <li>3: indexed as DOCS_AND_FREQS_AND_POSITIONS
* <li>4: indexed as DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS
* </ul>
* <li>DocValuesBits: a byte containing per-document value types. The type recorded as two
* four-bit integers, with the high-order bits representing <code>norms</code> options, and
* the low-order bits representing {@code DocValues} options. Each four-bit integer can be
* decoded as such:
* <ul>
* <li>0: no DocValues for this field.
* <li>1: NumericDocValues. ({@link DocValuesType#NUMERIC})
* <li>2: BinaryDocValues. ({@code DocValuesType#BINARY})
* <li>3: SortedDocValues. ({@code DocValuesType#SORTED})
* </ul>
* <li>DocValuesGen is the generation count of the field's DocValues. If this is -1, there are no
* DocValues updates to that field. Anything above zero means there are updates stored by
* {@link DocValuesFormat}.
* <li>Attributes: a key-value map of codec-private attributes.
* <li>PointDimensionCount, PointNumBytes: these are non-zero only if the field is indexed as
* points, e.g. using {@link org.apache.lucene.document.LongPoint}
* <li>VectorDimension: it is non-zero if the field is indexed as vectors.
* <li>VectorEncoding: a byte containing the encoding of vector values:
* <ul>
* <li>0: BYTE. Samples are stored as signed bytes
* <li>1: FLOAT32. Samples are stored in IEEE 32-bit floating point format.
* </ul>
* <li>VectorSimilarityFunction: a byte containing distance function used for similarity
* calculation.
* <ul>
* <li>0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
* <li>1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
* <li>2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
* </ul>
* </ul>
*
* @lucene.experimental
*/
public final class Lucene94FieldInfosFormat extends FieldInfosFormat {

/** Sole constructor. */
public Lucene94FieldInfosFormat() {}

@Override
public FieldInfos read(
Directory directory, SegmentInfo segmentInfo, String segmentSuffix, IOContext context)
throws IOException {
final String fileName =
IndexFileNames.segmentFileName(segmentInfo.name, segmentSuffix, EXTENSION);
try (ChecksumIndexInput input = directory.openChecksumInput(fileName, context)) {
Throwable priorE = null;
FieldInfo[] infos = null;
try {
CodecUtil.checkIndexHeader(
input,
Lucene94FieldInfosFormat.CODEC_NAME,
Lucene94FieldInfosFormat.FORMAT_START,
Lucene94FieldInfosFormat.FORMAT_CURRENT,
segmentInfo.getId(),
segmentSuffix);

final int size = input.readVInt(); // read in the size
infos = new FieldInfo[size];

// previous field's attribute map, we share when possible:
Map<String, String> lastAttributes = Collections.emptyMap();

for (int i = 0; i < size; i++) {
String name = input.readString();
final int fieldNumber = input.readVInt();
if (fieldNumber < 0) {
throw new CorruptIndexException(
"invalid field number for field: " + name + ", fieldNumber=" + fieldNumber, input);
}
byte bits = input.readByte();
boolean storeTermVector = (bits & STORE_TERMVECTOR) != 0;
boolean omitNorms = (bits & OMIT_NORMS) != 0;
boolean storePayloads = (bits & STORE_PAYLOADS) != 0;
boolean isSoftDeletesField = (bits & SOFT_DELETES_FIELD) != 0;

final IndexOptions indexOptions = getIndexOptions(input, input.readByte());

// DV Types are packed in one byte
final DocValuesType docValuesType = getDocValuesType(input, input.readByte());
final long dvGen = input.readLong();
Map<String, String> attributes = input.readMapOfStrings();
// just use the last field's map if its the same
if (attributes.equals(lastAttributes)) {
attributes = lastAttributes;
}
lastAttributes = attributes;
int pointDataDimensionCount = input.readVInt();
int pointNumBytes;
int pointIndexDimensionCount = pointDataDimensionCount;
if (pointDataDimensionCount != 0) {
pointIndexDimensionCount = input.readVInt();
pointNumBytes = input.readVInt();
} else {
pointNumBytes = 0;
}
final int vectorDimension = input.readVInt();
final VectorEncoding vectorEncoding = getVectorEncoding(input, input.readByte());
final VectorSimilarityFunction vectorDistFunc = getDistFunc(input, input.readByte());

try {
infos[i] =
new FieldInfo(
name,
fieldNumber,
storeTermVector,
omitNorms,
storePayloads,
indexOptions,
docValuesType,
dvGen,
attributes,
pointDataDimensionCount,
pointIndexDimensionCount,
pointNumBytes,
vectorDimension,
vectorEncoding,
vectorDistFunc,
isSoftDeletesField);
infos[i].checkConsistency();
} catch (IllegalStateException e) {
throw new CorruptIndexException(
"invalid fieldinfo for field: " + name + ", fieldNumber=" + fieldNumber, input, e);
}
}
} catch (Throwable exception) {
priorE = exception;
} finally {
CodecUtil.checkFooter(input, priorE);
}
return new FieldInfos(infos);
}
}

static {
// We "mirror" DocValues enum values with the constants below; let's try to ensure if we add a
// new DocValuesType while this format is
// still used for writing, we remember to fix this encoding:
assert DocValuesType.values().length == 6;
}

private static byte docValuesByte(DocValuesType type) {
switch (type) {
case NONE:
return 0;
case NUMERIC:
return 1;
case BINARY:
return 2;
case SORTED:
return 3;
case SORTED_SET:
return 4;
case SORTED_NUMERIC:
return 5;
default:
// BUG
throw new AssertionError("unhandled DocValuesType: " + type);
}
}

private static DocValuesType getDocValuesType(IndexInput input, byte b) throws IOException {
switch (b) {
case 0:
return DocValuesType.NONE;
case 1:
return DocValuesType.NUMERIC;
case 2:
return DocValuesType.BINARY;
case 3:
return DocValuesType.SORTED;
case 4:
return DocValuesType.SORTED_SET;
case 5:
return DocValuesType.SORTED_NUMERIC;
default:
throw new CorruptIndexException("invalid docvalues byte: " + b, input);
}
}

private static VectorEncoding getVectorEncoding(IndexInput input, byte b) throws IOException {
if (b < 0 || b >= VectorEncoding.values().length) {
throw new CorruptIndexException("invalid vector encoding: " + b, input);
}
return VectorEncoding.values()[b];
}

private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException {
if (b < 0 || b >= VectorSimilarityFunction.values().length) {
throw new CorruptIndexException("invalid distance function: " + b, input);
}
return VectorSimilarityFunction.values()[b];
}

static {
// We "mirror" IndexOptions enum values with the constants below; let's try to ensure if we add
// a new IndexOption while this format is
// still used for writing, we remember to fix this encoding:
assert IndexOptions.values().length == 5;
}

private static byte indexOptionsByte(IndexOptions indexOptions) {
switch (indexOptions) {
case NONE:
return 0;
case DOCS:
return 1;
case DOCS_AND_FREQS:
return 2;
case DOCS_AND_FREQS_AND_POSITIONS:
return 3;
case DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS:
return 4;
default:
// BUG:
throw new AssertionError("unhandled IndexOptions: " + indexOptions);
}
}

private static IndexOptions getIndexOptions(IndexInput input, byte b) throws IOException {
switch (b) {
case 0:
return IndexOptions.NONE;
case 1:
return IndexOptions.DOCS;
case 2:
return IndexOptions.DOCS_AND_FREQS;
case 3:
return IndexOptions.DOCS_AND_FREQS_AND_POSITIONS;
case 4:
return IndexOptions.DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS;
default:
// BUG
throw new CorruptIndexException("invalid IndexOptions byte: " + b, input);
}
}

@Override
public void write(
Directory directory,
SegmentInfo segmentInfo,
String segmentSuffix,
FieldInfos infos,
IOContext context)
throws IOException {
final String fileName =
IndexFileNames.segmentFileName(segmentInfo.name, segmentSuffix, EXTENSION);
try (IndexOutput output = directory.createOutput(fileName, context)) {
CodecUtil.writeIndexHeader(
output,
Lucene94FieldInfosFormat.CODEC_NAME,
Lucene94FieldInfosFormat.FORMAT_CURRENT,
segmentInfo.getId(),
segmentSuffix);
output.writeVInt(infos.size());
for (FieldInfo fi : infos) {
fi.checkConsistency();

output.writeString(fi.name);
output.writeVInt(fi.number);

byte bits = 0x0;
if (fi.hasVectors()) bits |= STORE_TERMVECTOR;
if (fi.omitsNorms()) bits |= OMIT_NORMS;
if (fi.hasPayloads()) bits |= STORE_PAYLOADS;
if (fi.isSoftDeletesField()) bits |= SOFT_DELETES_FIELD;
output.writeByte(bits);

output.writeByte(indexOptionsByte(fi.getIndexOptions()));

// pack the DV type and hasNorms in one byte
output.writeByte(docValuesByte(fi.getDocValuesType()));
output.writeLong(fi.getDocValuesGen());
output.writeMapOfStrings(fi.attributes());
output.writeVInt(fi.getPointDimensionCount());
if (fi.getPointDimensionCount() != 0) {
output.writeVInt(fi.getPointIndexDimensionCount());
output.writeVInt(fi.getPointNumBytes());
}
output.writeVInt(fi.getVectorDimension());
output.writeByte((byte) fi.getVectorEncoding().ordinal());
output.writeByte((byte) fi.getVectorSimilarityFunction().ordinal());
}
CodecUtil.writeFooter(output);
}
}

/** Extension of field infos */
static final String EXTENSION = "fnm";

// Codec header
static final String CODEC_NAME = "Lucene90FieldInfos";
static final int FORMAT_START = 0;
static final int FORMAT_CURRENT = FORMAT_START;

// Field flags
static final byte STORE_TERMVECTOR = 0x1;
static final byte OMIT_NORMS = 0x2;
static final byte STORE_PAYLOADS = 0x4;
static final byte SOFT_DELETES_FIELD = 0x8;
}
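A worked example of the FieldBits byte this format defines, using the flag constants at the bottom of the file:

    byte bits = 0;
    bits |= STORE_TERMVECTOR;    // 0x1
    bits |= SOFT_DELETES_FIELD;  // 0x8
    // bits == 0x9: term vectors stored, norms present, no payloads,
    // and the field is a soft-deletes field.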
@@ -38,8 +38,9 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* <p>For each field:
*
* <ul>
* <li>Floating-point vector data ordered by field, document ordinal, and vector dimension. The
* floats are stored in little-endian byte order
* <li>Vector data ordered by field, document ordinal, and vector dimension. When the
* vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each
* sample is stored as an IEEE float in little-endian byte order.
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
* note that only in sparse case
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note

@@ -89,7 +90,7 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* <ul>
* <li><b>[int]</b> the number of nodes on this level
* <li><b>array[int]</b> for levels greater than 0 list of nodes on this level, stored as
* the the level 0th nodes ordinals.
* the level 0th nodes' ordinals.
* </ul>
* </ul>
*

@@ -104,8 +105,8 @@ public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
static final String VECTOR_DATA_EXTENSION = "vec";
static final String VECTOR_INDEX_EXTENSION = "vex";

static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START;
public static final int VERSION_START = 0;
public static final int VERSION_CURRENT = 1;

/** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16;

@@ -156,6 +157,11 @@ public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
return new Lucene94HnswVectorsReader(state);
}

@Override
public int currentVersion() {
return VERSION_CURRENT;
}

@Override
public String toString() {
return "Lucene94HnswVectorsFormat(name=Lucene94HnswVectorsFormat, maxConn="
@@ -18,6 +18,8 @@
package org.apache.lucene.codecs.lucene94;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;

import java.io.IOException;
import java.util.Arrays;

@@ -30,8 +32,10 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

@@ -169,16 +173,23 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
+ fieldEntry.dimension);
}

long numBytes = (long) fieldEntry.size() * dimension * Float.BYTES;
int byteSize =
switch (info.getVectorEncoding()) {
case BYTE -> Byte.BYTES;
case FLOAT32 -> Float.BYTES;
};
int numBytes = fieldEntry.size * dimension * byteSize;
if (numBytes != fieldEntry.vectorDataLength) {
throw new IllegalStateException(
"Vector data length "
+ fieldEntry.vectorDataLength
+ " not matching size="
+ fieldEntry.size()
+ fieldEntry.size
+ " * dim="
+ dimension
+ " * 4 = "
+ " * byteSize="
+ byteSize
+ " = "
+ numBytes);
}
}
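A worked instance of the size check above, with hypothetical numbers:

    // For a FLOAT32 field: 100,000 vectors * 768 dims * 4 bytes = 307,200,000 bytes
    // For a BYTE field:    100,000 vectors * 768 dims * 1 byte  =  76,800,000 bytes
    // The 4x difference is the on-disk saving that 8-bit quantization buys.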
@@ -193,9 +204,18 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
return VectorSimilarityFunction.values()[similarityFunctionId];
}

private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
int encodingId = input.readInt();
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
}
return VectorEncoding.values()[encodingId];
}

private FieldEntry readField(IndexInput input) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, similarityFunction);
return new FieldEntry(input, vectorEncoding, similarityFunction);
}

@Override

@@ -216,7 +236,12 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
@Override
public VectorValues getVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
return OffHeapVectorValues.load(fieldEntry, vectorData);
VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData);
if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) {
return new ExpandingVectorValues(values);
} else {
return values;
}
}

@Override

@@ -237,6 +262,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
target,
k,
vectorValues,
fieldEntry.vectorEncoding,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs),

@@ -258,6 +284,25 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}

@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}

VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);

return switch (fieldEntry.vectorEncoding) {
case BYTE -> exhaustiveSearch(
vectorValues, acceptDocs, similarityFunction, toBytesRef(target), k);
case FLOAT32 -> exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
};
}

/** Get knn graph values; used for testing */
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);

@@ -286,6 +331,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
static class FieldEntry {

final VectorSimilarityFunction similarityFunction;
final VectorEncoding vectorEncoding;
final long vectorDataOffset;
final long vectorDataLength;
final long vectorIndexOffset;

@@ -315,8 +361,13 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
final DirectMonotonicReader.Meta meta;
final long addressesLength;

FieldEntry(IndexInput input, VectorSimilarityFunction similarityFunction) throws IOException {
FieldEntry(
IndexInput input,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {
this.similarityFunction = similarityFunction;
this.vectorEncoding = vectorEncoding;
vectorDataOffset = input.readVLong();
vectorDataLength = input.readVLong();
vectorIndexOffset = input.readVLong();
@ -65,7 +65,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
private final int M;
|
||||
private final int beamWidth;
|
||||
|
||||
private final List<FieldWriter> fields = new ArrayList<>();
|
||||
private final List<FieldWriter<?>> fields = new ArrayList<>();
|
||||
private boolean finished;
|
||||
|
||||
Lucene94HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException {
|
||||
@ -121,15 +121,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
||||
FieldWriter newField = new FieldWriter(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
|
||||
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||
FieldWriter<?> newField =
|
||||
FieldWriter.create(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
|
||||
fields.add(newField);
|
||||
return newField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||
for (FieldWriter field : fields) {
|
||||
for (FieldWriter<?> field : fields) {
|
||||
if (sortMap == null) {
|
||||
writeField(field, maxDoc);
|
||||
} else {
|
||||
@ -159,22 +160,20 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long total = 0;
|
||||
for (FieldWriter field : fields) {
|
||||
for (FieldWriter<?> field : fields) {
|
||||
total += field.ramBytesUsed();
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
private void writeField(FieldWriter fieldData, int maxDoc) throws IOException {
|
||||
private void writeField(FieldWriter<?> fieldData, int maxDoc) throws IOException {
|
||||
// write vector values
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
final ByteBuffer buffer =
|
||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
final BytesRef binaryValue = new BytesRef(buffer.array());
|
||||
for (float[] vector : fieldData.vectors) {
|
||||
buffer.asFloatBuffer().put(vector);
|
||||
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> writeByteVectors(fieldData);
|
||||
case FLOAT32 -> writeFloat32Vectors(fieldData);
|
||||
    }
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
|
||||
// write graph
|
||||
@ -194,7 +193,24 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
graph);
|
||||
}
|
||||
|
||||
private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
|
||||
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
|
||||
final ByteBuffer buffer =
|
||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
final BytesRef binaryValue = new BytesRef(buffer.array());
|
||||
for (Object v : fieldData.vectors) {
|
||||
buffer.asFloatBuffer().put((float[]) v);
|
||||
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
}
|
||||
}
|
||||
|
||||
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
|
||||
for (Object v : fieldData.vectors) {
|
||||
BytesRef vector = (BytesRef) v;
|
||||
vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
|
||||
}
|
||||
}
|
||||
|
||||
private void writeSortingField(FieldWriter<?> fieldData, int maxDoc, Sorter.DocMap sortMap)
|
||||
throws IOException {
|
||||
final int[] docIdOffsets = new int[sortMap.size()];
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
@ -221,15 +237,11 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
}
|
||||
|
||||
// write vector values
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
final ByteBuffer buffer =
|
||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
final BytesRef binaryValue = new BytesRef(buffer.array());
|
||||
for (int ordinal : ordMap) {
|
||||
float[] vector = fieldData.vectors.get(ordinal);
|
||||
buffer.asFloatBuffer().put(vector);
|
||||
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
}
|
||||
long vectorDataOffset =
|
||||
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
|
||||
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
|
||||
};
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
|
||||
// write graph
|
||||
@ -249,6 +261,29 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
mockGraph);
|
||||
}
|
||||
|
||||
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
|
||||
throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
final ByteBuffer buffer =
|
||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
final BytesRef binaryValue = new BytesRef(buffer.array());
|
||||
for (int ordinal : ordMap) {
|
||||
float[] vector = (float[]) fieldData.vectors.get(ordinal);
|
||||
buffer.asFloatBuffer().put(vector);
|
||||
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
}
|
||||
return vectorDataOffset;
|
||||
}
|
||||
|
||||
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
for (int ordinal : ordMap) {
|
||||
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
|
||||
vectorData.writeBytes(vector, 0, vector.length);
|
||||
}
|
||||
return vectorDataOffset;
|
||||
}
|
||||
|
||||
// reconstruct graph substituting old ordinals with new ordinals
|
||||
private HnswGraph reconstructAndWriteGraph(
|
||||
OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap) throws IOException {
|
||||
@ -354,7 +389,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
boolean success = false;
|
||||
try {
|
||||
// write the vector data to a temporary file
|
||||
DocsWithFieldSet docsWithField = writeVectorData(tempVectorData, vectors);
|
||||
DocsWithFieldSet docsWithField =
|
||||
writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize);
|
||||
CodecUtil.writeFooter(tempVectorData);
|
||||
IOUtils.close(tempVectorData);
|
||||
|
||||
@ -365,21 +401,22 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
|
||||
CodecUtil.retrieveChecksum(vectorDataInput);
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
|
||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||
// build the graph using the temporary vector data
|
||||
// we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
||||
// doesn't need to know docIds
|
||||
// TODO: separate random access vector values from DocIdSetIterator?
|
||||
int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize;
|
||||
OffHeapVectorValues offHeapVectors =
|
||||
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
||||
vectors.dimension(), docsWithField.cardinality(), vectorDataInput);
|
||||
vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
|
||||
OnHeapHnswGraph graph = null;
|
||||
if (offHeapVectors.size() != 0) {
|
||||
// build graph
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
new HnswGraphBuilder(
|
||||
HnswGraphBuilder<?> hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
offHeapVectors,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
@ -451,6 +488,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
      HnswGraph graph)
      throws IOException {
    meta.writeInt(field.number);
    meta.writeInt(field.getVectorEncoding().ordinal());
    meta.writeInt(field.getVectorSimilarityFunction().ordinal());
    meta.writeVLong(vectorDataOffset);
    meta.writeVLong(vectorDataLength);
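The reader consumes this metadata in the same order it is written here: readField above first calls readVectorEncoding, then readSimilarityFunction, matching the two writeInt calls. A simplified sketch of the per-field metadata layout this method implies (trailing fields elided):

    // field number        : int
    // vector encoding     : int   (VectorEncoding ordinal, new in this change)
    // similarity function : int   (VectorSimilarityFunction ordinal)
    // vector data offset  : vlong
    // vector data length  : vlong
    // ...                 (index offset/length, dimension, doc-id data, graph metadata)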
|
||||
@ -520,13 +558,13 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
/**
|
||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||
*/
|
||||
private static DocsWithFieldSet writeVectorData(IndexOutput output, VectorValues vectors)
|
||||
throws IOException {
|
||||
private static DocsWithFieldSet writeVectorData(
|
||||
IndexOutput output, VectorValues vectors, int scalarSize) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
||||
// write vector
|
||||
BytesRef binaryValue = vectors.binaryValue();
|
||||
assert binaryValue.length == vectors.dimension() * Float.BYTES;
|
||||
assert binaryValue.length == vectors.dimension() * scalarSize;
|
||||
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
}
|
||||
@ -538,54 +576,69 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
IOUtils.close(meta, vectorData, vectorIndex);
|
||||
}
|
||||
|
||||
private static class FieldWriter extends KnnFieldVectorsWriter {
|
||||
private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
|
||||
private final FieldInfo fieldInfo;
|
||||
private final int dim;
|
||||
private final DocsWithFieldSet docsWithField;
|
||||
private final List<float[]> vectors;
|
||||
private final RAVectorValues raVectorValues;
|
||||
private final HnswGraphBuilder hnswGraphBuilder;
|
||||
private final List<T> vectors;
|
||||
private final RAVectorValues<T> raVectorValues;
|
||||
private final HnswGraphBuilder<T> hnswGraphBuilder;
|
||||
|
||||
private int lastDocID = -1;
|
||||
private int node = 0;
|
||||
|
||||
static FieldWriter<?> create(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
||||
throws IOException {
|
||||
int dim = fieldInfo.getVectorDimension();
|
||||
return switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
|
||||
@Override
|
||||
public BytesRef copyValue(BytesRef value) {
|
||||
return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim));
|
||||
}
|
||||
};
|
||||
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {
|
||||
@Override
|
||||
public float[] copyValue(float[] value) {
|
||||
return ArrayUtil.copyOfSubArray(value, 0, dim);
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
||||
throws IOException {
|
||||
this.fieldInfo = fieldInfo;
|
||||
this.dim = fieldInfo.getVectorDimension();
|
||||
this.docsWithField = new DocsWithFieldSet();
|
||||
vectors = new ArrayList<>();
|
||||
raVectorValues = new RAVectorValues(vectors, dim);
|
||||
raVectorValues = new RAVectorValues<>(vectors, dim);
|
||||
hnswGraphBuilder =
|
||||
new HnswGraphBuilder(
|
||||
() -> raVectorValues,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
(HnswGraphBuilder<T>)
|
||||
HnswGraphBuilder.create(
|
||||
() -> raVectorValues,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(infoStream);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addValue(int docID, float[] vectorValue) throws IOException {
|
||||
@SuppressWarnings("unchecked")
|
||||
public void addValue(int docID, Object value) throws IOException {
|
||||
if (docID == lastDocID) {
|
||||
throw new IllegalArgumentException(
|
||||
"VectorValuesField \""
|
||||
+ fieldInfo.name
|
||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||
}
|
||||
if (vectorValue.length != dim) {
|
||||
throw new IllegalArgumentException(
|
||||
"Attempt to index a vector of dimension "
|
||||
+ vectorValue.length
|
||||
+ " but \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has dimension "
|
||||
+ dim);
|
||||
}
|
||||
T vectorValue = (T) value;
|
||||
assert docID > lastDocID;
|
||||
docsWithField.add(docID);
|
||||
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
|
||||
vectors.add(copyValue(vectorValue));
|
||||
if (node > 0) {
|
||||
// start at node 1! node 0 is added implicitly, in the constructor
|
||||
hnswGraphBuilder.addGraphNode(node, vectorValue);
|
||||
@ -608,16 +661,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
return docsWithField.ramBytesUsed()
|
||||
+ vectors.size()
|
||||
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||
+ vectors.size() * vectors.get(0).length * Float.BYTES
|
||||
+ vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize
|
||||
+ hnswGraphBuilder.getGraph().ramBytesUsed();
|
||||
}
|
||||
}
|
||||
|
||||
private static class RAVectorValues implements RandomAccessVectorValues {
|
||||
private final List<float[]> vectors;
|
||||
private static class RAVectorValues<T> implements RandomAccessVectorValues {
|
||||
private final List<T> vectors;
|
||||
private final int dim;
|
||||
|
||||
RAVectorValues(List<float[]> vectors, int dim) {
|
||||
RAVectorValues(List<T> vectors, int dim) {
|
||||
this.vectors = vectors;
|
||||
this.dim = dim;
|
||||
}
|
||||
@ -634,12 +687,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) throws IOException {
|
||||
return vectors.get(targetOrd);
|
||||
return (float[]) vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
return (BytesRef) vectors.get(targetOrd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -41,11 +41,11 @@ abstract class OffHeapVectorValues extends VectorValues
|
||||
protected final int byteSize;
|
||||
protected final float[] value;
|
||||
|
||||
OffHeapVectorValues(int dimension, int size, IndexInput slice) {
|
||||
OffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.slice = slice;
|
||||
byteSize = Float.BYTES * dimension;
|
||||
this.byteSize = byteSize;
|
||||
byteBuffer = ByteBuffer.allocate(byteSize);
|
||||
value = new float[dimension];
|
||||
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
@ -93,10 +93,16 @@ abstract class OffHeapVectorValues extends VectorValues
|
||||
}
|
||||
    IndexInput bytesSlice =
        vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
    int byteSize =
        switch (fieldEntry.vectorEncoding) {
          case BYTE -> fieldEntry.dimension;
          case FLOAT32 -> fieldEntry.dimension * Float.BYTES;
        };
    if (fieldEntry.docsWithFieldOffset == -1) {
      return new DenseOffHeapVectorValues(fieldEntry.dimension, fieldEntry.size, bytesSlice);
      return new DenseOffHeapVectorValues(
          fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
    } else {
      return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice);
      return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
    }
  }
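Passing byteSize down makes each vector a fixed-width record in the data slice, so ordinal access is a single seek. A hedged sketch of the access pattern (method name and scratch buffer are illustrative assumptions):

    // Sketch: read the vector at targetOrd from a fixed-stride slice.
    // For dimension 128: BYTE is 128 bytes per vector, FLOAT32 is 512.
    void readVector(IndexInput slice, int targetOrd, int byteSize, byte[] scratch)
        throws IOException {
      slice.seek((long) targetOrd * byteSize);
      slice.readBytes(scratch, 0, byteSize);
    }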
|
||||
|
||||
@ -106,8 +112,8 @@ abstract class OffHeapVectorValues extends VectorValues
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) {
|
||||
super(dimension, size, slice);
|
||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
super(dimension, size, slice, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -145,7 +151,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -167,10 +173,13 @@ abstract class OffHeapVectorValues extends VectorValues
|
||||
private final Lucene94HnswVectorsReader.FieldEntry fieldEntry;
|
||||
|
||||
public SparseOffHeapVectorValues(
|
||||
Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice)
|
||||
Lucene94HnswVectorsReader.FieldEntry fieldEntry,
|
||||
IndexInput dataIn,
|
||||
IndexInput slice,
|
||||
int byteSize)
|
||||
throws IOException {
|
||||
|
||||
super(fieldEntry.dimension, fieldEntry.size, slice);
|
||||
super(fieldEntry.dimension, fieldEntry.size, slice, byteSize);
|
||||
this.fieldEntry = fieldEntry;
|
||||
final RandomAccessInput addressesData =
|
||||
dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
|
||||
@ -218,7 +227,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -248,7 +257,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
||||
private static class EmptyOffHeapVectorValues extends OffHeapVectorValues {
|
||||
|
||||
public EmptyOffHeapVectorValues(int dimension) {
|
||||
super(dimension, 0, null);
|
||||
super(dimension, 0, null, 0);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
@ -144,7 +144,7 @@
|
||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
|
||||
* contains metadata about a segment, such as the number of documents, what files it uses, and
|
||||
* information about how the segment is sorted
|
||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
|
||||
* <li>{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Field names}. This
|
||||
* contains metadata about the set of named fields used in the index.
|
||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
|
||||
* This contains, for each document, a list of attribute-value pairs, where the attributes are
|
||||
@ -240,7 +240,7 @@
|
||||
* systems that frequently run out of file handles.</td>
|
||||
* </tr>
|
||||
* <tr>
|
||||
* <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
|
||||
* <td>{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Fields}</td>
|
||||
* <td>.fnm</td>
|
||||
* <td>Stores information about the fields</td>
|
||||
* </tr>
|
||||
|
@ -33,6 +33,7 @@ import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
@ -101,7 +102,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
||||
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||
KnnVectorsWriter writer = getInstance(fieldInfo);
|
||||
return writer.addField(fieldInfo);
|
||||
}
|
||||
@ -267,6 +268,17 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
KnnVectorsReader knnVectorsReader = fields.get(field);
|
||||
if (knnVectorsReader == null) {
|
||||
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
|
||||
} else {
|
||||
return knnVectorsReader.searchExhaustively(field, target, k, acceptDocs);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
IOUtils.close(fields.values());
|
||||
|
@ -24,6 +24,7 @@ import org.apache.lucene.index.DocValuesType;
|
||||
import org.apache.lucene.index.IndexOptions;
|
||||
import org.apache.lucene.index.IndexableFieldType;
|
||||
import org.apache.lucene.index.PointValues;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
|
||||
@ -44,6 +45,7 @@ public class FieldType implements IndexableFieldType {
|
||||
private int indexDimensionCount;
|
||||
private int dimensionNumBytes;
|
||||
private int vectorDimension;
|
||||
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
|
||||
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||
private Map<String, String> attributes;
|
||||
|
||||
@ -62,6 +64,7 @@ public class FieldType implements IndexableFieldType {
|
||||
this.indexDimensionCount = ref.pointIndexDimensionCount();
|
||||
this.dimensionNumBytes = ref.pointNumBytes();
|
||||
this.vectorDimension = ref.vectorDimension();
|
||||
this.vectorEncoding = ref.vectorEncoding();
|
||||
this.vectorSimilarityFunction = ref.vectorSimilarityFunction();
|
||||
if (ref.getAttributes() != null) {
|
||||
this.attributes = new HashMap<>(ref.getAttributes());
|
||||
@ -371,8 +374,8 @@ public class FieldType implements IndexableFieldType {
|
||||
}
|
||||
|
||||
/** Enable vector indexing, with the specified number of dimensions and distance function. */
|
||||
public void setVectorDimensionsAndSimilarityFunction(
|
||||
int numDimensions, VectorSimilarityFunction distFunc) {
|
||||
public void setVectorAttributes(
|
||||
int numDimensions, VectorEncoding encoding, VectorSimilarityFunction similarity) {
|
||||
checkIfFrozen();
|
||||
if (numDimensions <= 0) {
|
||||
throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions);
|
||||
@ -385,7 +388,8 @@ public class FieldType implements IndexableFieldType {
|
||||
+ numDimensions);
|
||||
}
|
||||
this.vectorDimension = numDimensions;
|
||||
this.vectorSimilarityFunction = Objects.requireNonNull(distFunc);
|
||||
this.vectorSimilarityFunction = Objects.requireNonNull(similarity);
|
||||
this.vectorEncoding = Objects.requireNonNull(encoding);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -393,6 +397,11 @@ public class FieldType implements IndexableFieldType {
|
||||
return vectorDimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorEncoding vectorEncoding() {
|
||||
return vectorEncoding;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorSimilarityFunction vectorSimilarityFunction() {
|
||||
return vectorSimilarityFunction;
|
||||
|
@ -17,8 +17,10 @@
|
||||
|
||||
package org.apache.lucene.document;
|
||||
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
/**
|
||||
@ -39,7 +41,18 @@ public class KnnVectorField extends Field {
|
||||
if (v == null) {
|
||||
throw new IllegalArgumentException("vector value must not be null");
|
||||
}
|
||||
int dimension = v.length;
|
||||
return createType(v.length, VectorEncoding.FLOAT32, similarityFunction);
|
||||
}
|
||||
|
||||
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
|
||||
if (v == null) {
|
||||
throw new IllegalArgumentException("vector value must not be null");
|
||||
}
|
||||
return createType(v.length, VectorEncoding.BYTE, similarityFunction);
|
||||
}
|
||||
|
||||
private static FieldType createType(
|
||||
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
|
||||
if (dimension == 0) {
|
||||
throw new IllegalArgumentException("cannot index an empty vector");
|
||||
}
|
||||
@ -51,13 +64,13 @@ public class KnnVectorField extends Field {
|
||||
throw new IllegalArgumentException("similarity function must not be null");
|
||||
}
|
||||
FieldType type = new FieldType();
|
||||
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
|
||||
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
|
||||
type.freeze();
|
||||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* A convenience method for creating a vector field type.
|
||||
* A convenience method for creating a vector field type with the default FLOAT32 encoding.
|
||||
*
|
||||
* @param dimension dimension of vectors
|
||||
* @param similarityFunction a function defining vector proximity.
|
||||
@ -65,8 +78,21 @@ public class KnnVectorField extends Field {
|
||||
*/
|
||||
public static FieldType createFieldType(
|
||||
int dimension, VectorSimilarityFunction similarityFunction) {
|
||||
return createFieldType(dimension, VectorEncoding.FLOAT32, similarityFunction);
|
||||
}
|
||||
|
||||
/**
|
||||
* A convenience method for creating a vector field type.
|
||||
*
|
||||
* @param dimension dimension of vectors
|
||||
* @param vectorEncoding the encoding of the scalar values
|
||||
* @param similarityFunction a function defining vector proximity.
|
||||
* @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
|
||||
*/
|
||||
  public static FieldType createFieldType(
      int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
    FieldType type = new FieldType();
    type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
    type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
    type.freeze();
    return type;
  }
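A hedged usage sketch of the new overload (field name and values invented; Document/IndexWriter plumbing elided):

    // Sketch: index one byte-encoded vector. {3, 4, 0, 0} and {0, 0, 3, 4}
    // both have norm 5, satisfying the equal-norm contract of VectorEncoding.BYTE.
    FieldType byteVectorType =
        KnnVectorField.createFieldType(
            4, VectorEncoding.BYTE, VectorSimilarityFunction.DOT_PRODUCT);
    Document doc = new Document();
    doc.add(new KnnVectorField("vector", new BytesRef(new byte[] {3, 4, 0, 0}), byteVectorType));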
|
||||
@ -74,8 +100,8 @@ public class KnnVectorField extends Field {
|
||||
/**
|
||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||
* no value. Vectors of a single field share the same dimension and similarity function. Note that
|
||||
* some strategies (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to be
|
||||
* unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
|
||||
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
|
||||
* be unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
|
||||
*
|
||||
* @param name field name
|
||||
* @param vector value
|
||||
@ -88,6 +114,23 @@ public class KnnVectorField extends Field {
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||
* no value. Vectors of a single field share the same dimension and similarity function. Note that
|
||||
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
|
||||
* have a constant norm, i.e. the same vector length for every value in the field.
|
||||
*
|
||||
* @param name field name
|
||||
* @param vector value
|
||||
* @param similarityFunction a function defining vector proximity.
|
||||
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||
* dimension > 1024.
|
||||
*/
|
||||
public KnnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
|
||||
super(name, createType(vector, similarityFunction));
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are
|
||||
* single-valued: each document has either one value or no value. Vectors of a single field share
|
||||
@ -117,6 +160,21 @@ public class KnnVectorField extends Field {
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||
* no value. Vectors of a single field share the same dimension and similarity function.
|
||||
*
|
||||
* @param name field name
|
||||
* @param vector value
|
||||
* @param fieldType field type
|
||||
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||
* dimension > 1024.
|
||||
*/
|
||||
public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
|
||||
super(name, fieldType);
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/** Return the vector value of this field */
|
||||
public float[] vectorValue() {
|
||||
return (float[]) fieldsData;
|
||||
|
@ -45,7 +45,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||
protected BufferingKnnVectorsWriter() {}
|
||||
|
||||
@Override
|
||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
||||
public KnnFieldVectorsWriter<float[]> addField(FieldInfo fieldInfo) throws IOException {
|
||||
FieldWriter newField = new FieldWriter(fieldInfo);
|
||||
fields.add(newField);
|
||||
return newField;
|
||||
@ -88,6 +88,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
|
||||
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
|
||||
@ -122,6 +128,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
@ -137,7 +149,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||
protected abstract void writeField(
|
||||
FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc) throws IOException;
|
||||
|
||||
private static class FieldWriter extends KnnFieldVectorsWriter {
|
||||
private static class FieldWriter extends KnnFieldVectorsWriter<float[]> {
|
||||
private final FieldInfo fieldInfo;
|
||||
private final int dim;
|
||||
private final DocsWithFieldSet docsWithField;
|
||||
@ -153,35 +165,45 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addValue(int docID, float[] vectorValue) {
|
||||
public void addValue(int docID, Object value) {
|
||||
if (docID == lastDocID) {
|
||||
throw new IllegalArgumentException(
|
||||
"VectorValuesField \""
|
||||
+ fieldInfo.name
|
||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||
}
|
||||
if (vectorValue.length != dim) {
|
||||
throw new IllegalArgumentException(
|
||||
"Attempt to index a vector of dimension "
|
||||
+ vectorValue.length
|
||||
+ " but \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has dimension "
|
||||
+ dim);
|
||||
}
|
||||
assert docID > lastDocID;
|
||||
float[] vectorValue =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case FLOAT32 -> (float[]) value;
|
||||
case BYTE -> bytesToFloats((BytesRef) value);
|
||||
};
|
||||
docsWithField.add(docID);
|
||||
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
|
||||
vectors.add(copyValue(vectorValue));
|
||||
lastDocID = docID;
|
||||
}
|
||||
|
||||
private float[] bytesToFloats(BytesRef b) {
|
||||
// This is used only by SimpleTextKnnVectorsWriter
|
||||
float[] floats = new float[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
floats[i] = b.bytes[i + b.offset];
|
||||
}
|
||||
return floats;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] copyValue(float[] vectorValue) {
|
||||
return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
if (vectors.size() == 0) return 0;
|
||||
return docsWithField.ramBytesUsed()
|
||||
+ vectors.size()
|
||||
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||
+ vectors.size() * vectors.get(0).length * Float.BYTES;
|
||||
+ vectors.size() * dim * Float.BYTES;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -25,6 +25,7 @@ import org.apache.lucene.codecs.NormsProducer;
|
||||
import org.apache.lucene.codecs.PointsReader;
|
||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||
import org.apache.lucene.codecs.TermVectorsReader;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
@ -235,6 +236,19 @@ public abstract class CodecReader extends LeafReader {
|
||||
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
ensureOpen();
|
||||
FieldInfo fi = getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorDimension() == 0) {
|
||||
// Field does not exist or does not index vectors
|
||||
return null;
|
||||
}
|
||||
|
||||
return getVectorReader().searchExhaustively(field, target, k, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doClose() throws IOException {}
|
||||
|
||||
|
@ -18,6 +18,7 @@
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
@ -58,6 +59,12 @@ abstract class DocValuesLeafReader extends LeafReader {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void checkIntegrity() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
|
@ -56,6 +56,7 @@ public final class FieldInfo {
|
||||
|
||||
// if it is a positive value, it means this field indexes vectors
|
||||
private final int vectorDimension;
|
||||
private final VectorEncoding vectorEncoding;
|
||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
|
||||
// whether this field is used as the soft-deletes field
|
||||
@ -80,6 +81,7 @@ public final class FieldInfo {
|
||||
int pointIndexDimensionCount,
|
||||
int pointNumBytes,
|
||||
int vectorDimension,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
boolean softDeletesField) {
|
||||
this.name = Objects.requireNonNull(name);
|
||||
@ -105,6 +107,7 @@ public final class FieldInfo {
|
||||
this.pointIndexDimensionCount = pointIndexDimensionCount;
|
||||
this.pointNumBytes = pointNumBytes;
|
||||
this.vectorDimension = vectorDimension;
|
||||
this.vectorEncoding = vectorEncoding;
|
||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||
this.softDeletesField = softDeletesField;
|
||||
this.checkConsistency();
|
||||
@ -229,8 +232,10 @@ public final class FieldInfo {
|
||||
verifySameVectorOptions(
|
||||
fieldName,
|
||||
this.vectorDimension,
|
||||
this.vectorEncoding,
|
||||
this.vectorSimilarityFunction,
|
||||
o.vectorDimension,
|
||||
o.vectorEncoding,
|
||||
o.vectorSimilarityFunction);
|
||||
}
|
||||
|
||||
@ -347,19 +352,25 @@ public final class FieldInfo {
|
||||
static void verifySameVectorOptions(
|
||||
String fieldName,
|
||||
int vd1,
|
||||
VectorEncoding ve1,
|
||||
VectorSimilarityFunction vsf1,
|
||||
int vd2,
|
||||
VectorEncoding ve2,
|
||||
VectorSimilarityFunction vsf2) {
|
||||
if (vd1 != vd2 || vsf1 != vsf2) {
|
||||
if (vd1 != vd2 || vsf1 != vsf2 || ve1 != ve2) {
|
||||
throw new IllegalArgumentException(
|
||||
"cannot change field \""
|
||||
+ fieldName
|
||||
+ "\" from vector dimension="
|
||||
+ vd1
|
||||
+ ", vector encoding="
|
||||
+ ve1
|
||||
+ ", vector similarity function="
|
||||
+ vsf1
|
||||
+ " to inconsistent vector dimension="
|
||||
+ vd2
|
||||
+ ", vector encoding="
|
||||
+ ve2
|
||||
+ ", vector similarity function="
|
||||
+ vsf2);
|
||||
}
|
||||
@ -470,6 +481,11 @@ public final class FieldInfo {
|
||||
return vectorDimension;
|
||||
}
|
||||
|
||||
/** Returns the {@link VectorEncoding} of the field's vector values */
|
||||
public VectorEncoding getVectorEncoding() {
|
||||
return vectorEncoding;
|
||||
}
|
||||
|
||||
/** Returns {@link VectorSimilarityFunction} for the field */
|
||||
public VectorSimilarityFunction getVectorSimilarityFunction() {
|
||||
return vectorSimilarityFunction;
|
||||
|
@ -308,10 +308,15 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
||||
|
||||
static final class FieldVectorProperties {
|
||||
final int numDimensions;
|
||||
final VectorEncoding vectorEncoding;
|
||||
final VectorSimilarityFunction similarityFunction;
|
||||
|
||||
FieldVectorProperties(int numDimensions, VectorSimilarityFunction similarityFunction) {
|
||||
FieldVectorProperties(
|
||||
int numDimensions,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
this.numDimensions = numDimensions;
|
||||
this.vectorEncoding = vectorEncoding;
|
||||
this.similarityFunction = similarityFunction;
|
||||
}
|
||||
}
|
||||
@ -401,7 +406,8 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
||||
fi.getPointNumBytes()));
|
||||
vectorProps.put(
|
||||
fieldName,
|
||||
new FieldVectorProperties(fi.getVectorDimension(), fi.getVectorSimilarityFunction()));
|
||||
new FieldVectorProperties(
|
||||
fi.getVectorDimension(), fi.getVectorEncoding(), fi.getVectorSimilarityFunction()));
|
||||
}
|
||||
return fieldNumber.intValue();
|
||||
}
|
||||
@ -459,8 +465,10 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
||||
verifySameVectorOptions(
|
||||
fieldName,
|
||||
props.numDimensions,
|
||||
props.vectorEncoding,
|
||||
props.similarityFunction,
|
||||
fi.getVectorDimension(),
|
||||
fi.getVectorEncoding(),
|
||||
fi.getVectorSimilarityFunction());
|
||||
}
|
||||
|
||||
@ -503,6 +511,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
VectorEncoding.FLOAT32,
|
||||
VectorSimilarityFunction.EUCLIDEAN,
|
||||
(softDeletesFieldName != null && softDeletesFieldName.equals(fieldName)));
|
||||
addOrGet(fi);
|
||||
@ -584,6 +593,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
VectorEncoding.FLOAT32,
|
||||
VectorSimilarityFunction.EUCLIDEAN,
|
||||
isSoftDeletesField);
|
||||
}
|
||||
@ -698,6 +708,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
||||
fi.getPointIndexDimensionCount(),
|
||||
fi.getPointNumBytes(),
|
||||
fi.getVectorDimension(),
|
||||
fi.getVectorEncoding(),
|
||||
fi.getVectorSimilarityFunction(),
|
||||
fi.isSoftDeletesField());
|
||||
byName.put(fiNew.getName(), fiNew);
|
||||
|
@ -18,6 +18,7 @@ package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Iterator;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.AttributeSource;
|
||||
import org.apache.lucene.util.Bits;
|
||||
@ -357,6 +358,12 @@ public abstract class FilterLeafReader extends LeafReader {
|
||||
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Fields getTermVectors(int docID) throws IOException {
|
||||
ensureOpen();
|
||||
|
@ -101,6 +101,9 @@ public interface IndexableFieldType {
|
||||
/** The number of dimensions of the field's vector value */
|
||||
int vectorDimension();
|
||||
|
||||
/** The {@link VectorEncoding} of the field's vector value */
|
||||
VectorEncoding vectorEncoding();
|
||||
|
||||
/** The {@link VectorSimilarityFunction} of the field's vector value */
|
||||
VectorSimilarityFunction vectorSimilarityFunction();
|
||||
|
||||
|
@ -628,6 +628,7 @@ final class IndexingChain implements Accountable {
|
||||
s.pointIndexDimensionCount,
|
||||
s.pointNumBytes,
|
||||
s.vectorDimension,
|
||||
s.vectorEncoding,
|
||||
s.vectorSimilarityFunction,
|
||||
pf.fieldName.equals(fieldInfos.getSoftDeletesFieldName())));
|
||||
pf.setFieldInfo(fi);
|
||||
@ -712,7 +713,11 @@ final class IndexingChain implements Accountable {
|
||||
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
|
||||
}
|
||||
if (fieldType.vectorDimension() != 0) {
|
||||
pf.knnFieldVectorsWriter.addValue(docID, ((KnnVectorField) field).vectorValue());
|
||||
switch (fieldType.vectorEncoding()) {
|
||||
case BYTE -> pf.knnFieldVectorsWriter.addValue(docID, field.binaryValue());
|
||||
case FLOAT32 -> pf.knnFieldVectorsWriter.addValue(
|
||||
docID, ((KnnVectorField) field).vectorValue());
|
||||
}
|
||||
}
|
||||
return indexedField;
|
||||
}
|
||||
@ -776,7 +781,10 @@ final class IndexingChain implements Accountable {
|
||||
fieldType.pointNumBytes());
|
||||
}
|
||||
if (fieldType.vectorDimension() != 0) {
|
||||
schema.setVectors(fieldType.vectorSimilarityFunction(), fieldType.vectorDimension());
|
||||
schema.setVectors(
|
||||
fieldType.vectorEncoding(),
|
||||
fieldType.vectorSimilarityFunction(),
|
||||
fieldType.vectorDimension());
|
||||
}
|
||||
if (fieldType.getAttributes() != null && fieldType.getAttributes().isEmpty() == false) {
|
||||
schema.updateAttributes(fieldType.getAttributes());
|
||||
@ -988,7 +996,7 @@ final class IndexingChain implements Accountable {
|
||||
PointValuesWriter pointValuesWriter;
|
||||
|
||||
// Non-null if this field had vectors in this segment
|
||||
KnnFieldVectorsWriter knnFieldVectorsWriter;
|
||||
KnnFieldVectorsWriter<?> knnFieldVectorsWriter;
|
||||
|
||||
/** We use this to know when a PerField is seen for the first time in the current document. */
|
||||
long fieldGen = -1;
|
||||
@ -1281,6 +1289,7 @@ final class IndexingChain implements Accountable {
|
||||
private int pointIndexDimensionCount = 0;
|
||||
private int pointNumBytes = 0;
|
||||
private int vectorDimension = 0;
|
||||
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
|
||||
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||
|
||||
private static String errMsg =
|
||||
@ -1361,11 +1370,14 @@ final class IndexingChain implements Accountable {
|
||||
}
|
||||
}
|
||||
|
||||
void setVectors(VectorSimilarityFunction similarityFunction, int dimension) {
|
||||
void setVectors(
|
||||
VectorEncoding encoding, VectorSimilarityFunction similarityFunction, int dimension) {
|
||||
if (vectorDimension == 0) {
|
||||
this.vectorDimension = dimension;
|
||||
this.vectorEncoding = encoding;
|
||||
this.vectorSimilarityFunction = similarityFunction;
|
||||
this.vectorDimension = dimension;
|
||||
} else {
|
||||
assertSame("vector encoding", vectorEncoding, encoding);
|
||||
assertSame("vector similarity function", vectorSimilarityFunction, similarityFunction);
|
||||
assertSame("vector dimension", vectorDimension, dimension);
|
||||
}
|
||||
@ -1381,6 +1393,7 @@ final class IndexingChain implements Accountable {
|
||||
pointIndexDimensionCount = 0;
|
||||
pointNumBytes = 0;
|
||||
vectorDimension = 0;
|
||||
vectorEncoding = VectorEncoding.FLOAT32;
|
||||
vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||
}
|
||||
|
||||
@ -1391,6 +1404,7 @@ final class IndexingChain implements Accountable {
|
||||
assertSame("doc values type", fi.getDocValuesType(), docValuesType);
|
||||
assertSame(
|
||||
"vector similarity function", fi.getVectorSimilarityFunction(), vectorSimilarityFunction);
|
||||
assertSame("vector encoding", fi.getVectorEncoding(), vectorEncoding);
|
||||
assertSame("vector dimension", fi.getVectorDimension(), vectorDimension);
|
||||
assertSame("point dimension", fi.getPointDimensionCount(), pointDimensionCount);
|
||||
assertSame(
|
||||
|
@ -17,6 +17,7 @@
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
@ -235,6 +236,30 @@ public abstract class LeafReader extends IndexReader {
|
||||
public abstract TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
|
||||
|
||||
  /**
   * Return the k nearest neighbor documents as determined by comparison of their vector values for
   * this field, to the given vector, by the field's similarity function. The score of each document
   * is derived from the vector similarity in a way that ensures scores are positive and that a
   * larger score corresponds to a higher ranking.
   *
   * <p>The search is exact, meaning the results are guaranteed to be the true k closest neighbors.
   * This typically requires an exhaustive scan of all candidate documents.
   *
   * <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
   * sorted in order of their similarity to the query vector (decreasing scores). The {@link
   * TotalHits} contains the number of documents visited during the search.
   *
   * @param field the vector field to search
   * @param target the vector-valued query
   * @param k the number of docs to return
   * @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match, or
   *     {@code null} if they are all allowed to match.
   * @return the k nearest neighbor documents, along with their similarity-specific scores.
   * @lucene.experimental
   */
  public abstract TopDocs searchNearestVectorsExhaustively(
      String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
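This is the hook KnnVectorQuery uses further below when the filter matches at most k documents, or when the approximate search visits too many nodes. A hedged caller sketch (the reader, field name, and query vector are assumptions for illustration):

    // Sketch: exact kNN over the documents allowed by a filter bit set.
    BitSetIterator filterIterator = new BitSetIterator(acceptDocs, acceptDocs.cardinality());
    TopDocs exact =
        leafReader.searchNearestVectorsExhaustively("vector", queryVector, 10, filterIterator);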
|
||||
|
||||
/**
|
||||
* Get the {@link FieldInfos} describing all fields in this reader.
|
||||
*
|
||||
|
@ -26,6 +26,7 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.SortedMap;
|
||||
import java.util.TreeMap;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
@ -403,6 +404,16 @@ public class ParallelLeafReader extends LeafReader {
|
||||
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectorsExhaustively(
|
||||
String fieldName, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
ensureOpen();
|
||||
LeafReader reader = fieldToReader.get(fieldName);
|
||||
return reader == null
|
||||
? null
|
||||
: reader.searchNearestVectorsExhaustively(fieldName, target, k, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() throws IOException {
|
||||
ensureOpen();
|
||||
|
@ -722,6 +722,7 @@ final class ReadersAndUpdates {
|
||||
fi.getPointIndexDimensionCount(),
|
||||
fi.getPointNumBytes(),
|
||||
fi.getVectorDimension(),
|
||||
fi.getVectorEncoding(),
|
||||
fi.getVectorSimilarityFunction(),
|
||||
fi.isSoftDeletesField());
|
||||
}
|
||||
|
@ -27,6 +27,7 @@ import org.apache.lucene.codecs.NormsProducer;
|
||||
import org.apache.lucene.codecs.PointsReader;
|
||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||
import org.apache.lucene.codecs.TermVectorsReader;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
@ -172,6 +173,12 @@ public final class SlowCodecReaderWrapper {
|
||||
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||
return reader.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() {
|
||||
// We already checkIntegrity the entire reader up front
|
||||
|
@ -31,6 +31,7 @@ import org.apache.lucene.codecs.NormsProducer;
|
||||
import org.apache.lucene.codecs.PointsReader;
|
||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||
import org.apache.lucene.codecs.TermVectorsReader;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
@ -389,6 +390,12 @@ public final class SortingCodecReader extends FilterCodecReader {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchExhaustively(
|
||||
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
delegate.close();
|
||||
|
@ -0,0 +1,45 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.index;

/** The numeric datatype of the vector values. */
public enum VectorEncoding {

  /**
   * Encodes vectors using 8 bits of precision per sample. Use only with DOT_PRODUCT similarity.
   * NOTE: this can enable significant storage savings and faster searches, at the cost of some
   * possible loss of precision. In order to use it, all vectors must have the same norm, as
   * measured by the sum of the squares of the scalar values, and those values must be in the range
   * [-128, 127]. This applies to both document and query vectors. Using nonconforming vectors can
   * result in errors or poor search results.
   */
  BYTE(1),

  /** Encodes vectors using 32 bits of precision per sample in IEEE floating point format. */
  FLOAT32(4);

  /**
   * The number of bytes required to encode a scalar in this format. A vector requires dimension *
   * byteSize bytes.
   */
  public final int byteSize;

  VectorEncoding(int byteSize) {
    this.byteSize = byteSize;
  }
}
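To make the BYTE contract concrete, a hedged quantization sketch (not part of this patch; after rounding, norms are only approximately equal, which production code would need to account for):

    // Sketch: quantize a unit-length float vector to signed bytes.
    // Scaling every unit vector by the same factor (here 127, the largest
    // value that stays within [-128, 127]) keeps the norms equal.
    static byte[] quantize(float[] unitVector) {
      byte[] quantized = new byte[unitVector.length];
      for (int i = 0; i < unitVector.length; i++) {
        quantized[i] = (byte) Math.round(127 * unitVector[i]);
      }
      return quantized;
    }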
|
@ -18,6 +18,8 @@ package org.apache.lucene.index;

import static org.apache.lucene.util.VectorUtil.*;

import org.apache.lucene.util.BytesRef;

/**
 * Vector similarity function; used in search to return top K most similar vectors to a target
 * vector. This is a label describing the method used during indexing and searching of the vectors
@ -31,6 +33,11 @@ public enum VectorSimilarityFunction {
    public float compare(float[] v1, float[] v2) {
      return 1 / (1 + squareDistance(v1, v2));
    }

    @Override
    public float compare(BytesRef v1, BytesRef v2) {
      return 1 / (1 + squareDistance(v1, v2));
    }
  },

  /**
@ -44,6 +51,11 @@ public enum VectorSimilarityFunction {
    public float compare(float[] v1, float[] v2) {
      return (1 + dotProduct(v1, v2)) / 2;
    }

    @Override
    public float compare(BytesRef v1, BytesRef v2) {
      return dotProductScore(v1, v2);
    }
  },

  /**
@ -57,6 +69,11 @@ public enum VectorSimilarityFunction {
    public float compare(float[] v1, float[] v2) {
      return (1 + cosine(v1, v2)) / 2;
    }

    @Override
    public float compare(BytesRef v1, BytesRef v2) {
      return (1 + cosine(v1, v2)) / 2;
    }
  };

  /**
@ -68,4 +85,15 @@ public enum VectorSimilarityFunction {
   * @return the value of the similarity function applied to the two vectors
   */
  public abstract float compare(float[] v1, float[] v2);

  /**
   * Calculates a similarity score between the two vectors with a specified function. Higher
   * similarity scores correspond to closer vectors. The offsets and lengths of the BytesRefs
   * determine the vector data that is compared. Each (signed) byte represents a vector dimension.
   *
   * @param v1 a vector
   * @param v2 another vector, of the same dimension
   * @return the value of the similarity function applied to the two vectors
   */
  public abstract float compare(BytesRef v1, BytesRef v2);
}
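A minimal sketch of the new BytesRef overload in use (not part of the patch); each signed byte is one dimension, and for DOT_PRODUCT the score is the dotProductScore scaling shown later in VectorUtil:

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;

public class ByteSimilarityDemo {
  public static void main(String[] args) {
    // two 3-dimensional byte vectors; values must lie in [-128, 127]
    BytesRef v1 = new BytesRef(new byte[] {1, 2, 3});
    BytesRef v2 = new BytesRef(new byte[] {-10, 0, 5});
    // raw dot product is -10 + 0 + 15 = 5; compare() returns the scaled score
    float score = VectorSimilarityFunction.DOT_PRODUCT.compare(v1, v2);
    System.out.println("score=" + score); // (1 + 5) / (3 * 2^15)
  }
}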
@ -65,7 +65,7 @@ class VectorValuesConsumer {
    }
  }

  public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
  public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
    initKnnVectorsWriter(fieldInfo.name);
    return writer.addField(fieldInfo);
  }
@ -24,11 +24,8 @@ import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
@ -133,22 +130,21 @@ public class KnnVectorQuery extends Query {
      return NO_RESULTS;
    }

    BitSet bitSet = createBitSet(scorer.iterator(), liveDocs, maxDoc);
    BitSetIterator filterIterator = new BitSetIterator(bitSet, bitSet.cardinality());
    BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);

    if (filterIterator.cost() <= k) {
    if (acceptDocs.cardinality() <= k) {
      // If there are <= k possible matches, short-circuit and perform exact search, since HNSW
      // must always visit at least k documents
      return exactSearch(ctx, filterIterator);
      return exactSearch(ctx, new BitSetIterator(acceptDocs, acceptDocs.cardinality()));
    }

    // Perform the approximate kNN search
    TopDocs results = approximateSearch(ctx, bitSet, (int) filterIterator.cost());
    TopDocs results = approximateSearch(ctx, acceptDocs, acceptDocs.cardinality());
    if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
      return results;
    } else {
      // We stopped the kNN search because it visited too many nodes, so fall back to exact search
      return exactSearch(ctx, filterIterator);
      return exactSearch(ctx, new BitSetIterator(acceptDocs, acceptDocs.cardinality()));
    }
  }

@ -178,45 +174,9 @@ public class KnnVectorQuery extends Query {
  }

  // We allow this to be overridden so that tests can check what search strategy is used
  protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
  protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptDocs)
      throws IOException {
    FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
    if (fi == null || fi.getVectorDimension() == 0) {
      // The field does not exist or does not index vectors
      return NO_RESULTS;
    }

    VectorSimilarityFunction similarityFunction = fi.getVectorSimilarityFunction();
    VectorValues vectorValues = context.reader().getVectorValues(field);

    HitQueue queue = new HitQueue(k, true);
    ScoreDoc topDoc = queue.top();
    int doc;
    while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
      int vectorDoc = vectorValues.advance(doc);
      assert vectorDoc == doc;
      float[] vector = vectorValues.vectorValue();

      float score = similarityFunction.compare(vector, target);
      if (score >= topDoc.score) {
        topDoc.score = score;
        topDoc.doc = doc;
        topDoc = queue.updateTop();
      }
    }

    // Remove any remaining sentinel values
    while (queue.size() > 0 && queue.top().score < 0) {
      queue.pop();
    }

    ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
    for (int i = topScoreDocs.length - 1; i >= 0; i--) {
      topScoreDocs[i] = queue.pop();
    }

    TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
    return new TopDocs(totalHits, topScoreDocs);
    return context.reader().searchNearestVectorsExhaustively(field, target, k, acceptDocs);
  }

  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
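For orientation, a usage sketch of the rewritten filtering logic (assuming the filter-taking KnnVectorQuery constructor available on this branch; field names and paths are hypothetical): when the filter matches at most k documents the rewrite short-circuits to exact search, and otherwise exact search is only the over-budget fallback.

import java.nio.file.Paths;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

public class FilteredKnnDemo {
  public static void main(String[] args) throws Exception {
    try (Directory dir = FSDirectory.open(Paths.get("/tmp/index")); // hypothetical index
        IndexReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      Query filter = new TermQuery(new Term("color", "blue")); // hypothetical filter field
      // rewrite() decides between exact and approximate search per segment
      TopDocs hits =
          searcher.search(new KnnVectorQuery("vector", new float[] {0.1f, 0.2f}, 10, filter), 10);
      System.out.println(hits.totalHits);
    }
  }
}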
@ -121,6 +121,24 @@ public final class VectorUtil {
    return (float) (sum / Math.sqrt(norm1 * norm2));
  }

  /** Returns the cosine similarity between the two vectors. */
  public static float cosine(BytesRef a, BytesRef b) {
    // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
    int sum = 0;
    int norm1 = 0;
    int norm2 = 0;
    int aOffset = a.offset, bOffset = b.offset;

    for (int i = 0; i < a.length; i++) {
      byte elem1 = a.bytes[aOffset++];
      byte elem2 = b.bytes[bOffset++];
      sum += elem1 * elem2;
      norm1 += elem1 * elem1;
      norm2 += elem2 * elem2;
    }
    return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
  }

  /**
   * Returns the sum of squared differences of the two vectors.
   *
@ -135,7 +153,7 @@ public final class VectorUtil {
    int dim = v1.length;
    int i;
    for (i = 0; i + 8 <= dim; i += 8) {
      squareSum += squareDistanceUnrolled8(v1, v2, i);
      squareSum += squareDistanceUnrolled(v1, v2, i);
    }
    for (; i < dim; i++) {
      float diff = v1[i] - v2[i];
@ -144,7 +162,7 @@ public final class VectorUtil {
    return squareSum;
  }

  private static float squareDistanceUnrolled8(float[] v1, float[] v2, int index) {
  private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
    float diff0 = v1[index + 0] - v2[index + 0];
    float diff1 = v1[index + 1] - v2[index + 1];
    float diff2 = v1[index + 2] - v2[index + 2];
@ -163,6 +181,18 @@ public final class VectorUtil {
        + diff7 * diff7;
  }

  /** Returns the sum of squared differences of the two vectors. */
  public static float squareDistance(BytesRef a, BytesRef b) {
    // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
    int squareSum = 0;
    int aOffset = a.offset, bOffset = b.offset;
    for (int i = 0; i < a.length; i++) {
      int diff = a.bytes[aOffset++] - b.bytes[bOffset++];
      squareSum += diff * diff;
    }
    return squareSum;
  }

  /**
   * Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
   * thrown for zero vectors.
@ -213,4 +243,48 @@ public final class VectorUtil {
      u[i] += v[i];
    }
  }

  /**
   * Dot product computed over signed bytes.
   *
   * @param a bytes containing a vector
   * @param b bytes containing another vector, of the same dimension
   * @return the value of the dot product of the two vectors
   */
  public static float dotProduct(BytesRef a, BytesRef b) {
    assert a.length == b.length;
    int total = 0;
    int aOffset = a.offset, bOffset = b.offset;
    for (int i = 0; i < a.length; i++) {
      total += a.bytes[aOffset++] * b.bytes[bOffset++];
    }
    return total;
  }

  /**
   * Dot product score computed over signed bytes, scaled to be in [0, 1].
   *
   * @param a bytes containing a vector
   * @param b bytes containing another vector, of the same dimension
   * @return the value of the similarity function applied to the two vectors
   */
  public static float dotProductScore(BytesRef a, BytesRef b) {
    // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
    return (1 + dotProduct(a, b)) / (float) (a.length * (1 << 15));
  }

  /**
   * Convert a floating point vector to an array of bytes using casting; the vector values should
   * be in [-128,127]
   *
   * @param vector a vector
   * @return a new BytesRef containing the vector's values cast to byte.
   */
  public static BytesRef toBytesRef(float[] vector) {
    BytesRef b = new BytesRef(new byte[vector.length]);
    for (int i = 0; i < vector.length; i++) {
      b.bytes[i] = (byte) vector[i];
    }
    return b;
  }
}
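A worked example of the dotProductScore scaling (mirroring testBasicDotProductBytes further down): for a = (1, 2, 3) and b = (-10, 0, 5), the dot product is -10 + 0 + 15 = 5, and the score is (1 + 5) / (3 * 2^15) = 6 / 98304. The divisor is the vector length times 2^15, i.e. twice the maximum magnitude (2^14) of a product of two signed bytes.

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;

public class DotProductScoreDemo {
  public static void main(String[] args) {
    BytesRef a = new BytesRef(new byte[] {1, 2, 3});
    BytesRef b = new BytesRef(new byte[] {-10, 0, 5});
    System.out.println(VectorUtil.dotProduct(a, b));      // 5.0
    System.out.println(VectorUtil.dotProductScore(a, b)); // (1 + 5) / (3 * 32768)
  }
}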
@ -25,15 +25,19 @@ import java.util.Objects;
import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;

/**
 * Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
 * hyperparameters.
 *
 * @param <T> the type of vector
 */
public final class HnswGraphBuilder {
public final class HnswGraphBuilder<T> {

  /** Default random seed for level generation * */
  private static final long DEFAULT_RAND_SEED = 42;
@ -49,9 +53,10 @@ public final class HnswGraphBuilder {
  private final NeighborArray scratch;

  private final VectorSimilarityFunction similarityFunction;
  private final VectorEncoding vectorEncoding;
  private final RandomAccessVectorValues vectorValues;
  private final SplittableRandom random;
  private final HnswGraphSearcher graphSearcher;
  private final HnswGraphSearcher<T> graphSearcher;

  final OnHeapHnswGraph hnsw;

@ -59,7 +64,18 @@ public final class HnswGraphBuilder {

  // we need two sources of vectors in order to perform diversity check comparisons without
  // colliding
  private RandomAccessVectorValues buildVectors;
  private final RandomAccessVectorValues buildVectors;

  public static HnswGraphBuilder<?> create(
      RandomAccessVectorValuesProducer vectors,
      VectorEncoding vectorEncoding,
      VectorSimilarityFunction similarityFunction,
      int M,
      int beamWidth,
      long seed)
      throws IOException {
    return new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
  }

  /**
   * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
@ -73,8 +89,9 @@ public final class HnswGraphBuilder {
   * @param seed the seed for a random number generator used during graph construction. Provide this
   *     to ensure repeatable construction.
   */
  public HnswGraphBuilder(
  private HnswGraphBuilder(
      RandomAccessVectorValuesProducer vectors,
      VectorEncoding vectorEncoding,
      VectorSimilarityFunction similarityFunction,
      int M,
      int beamWidth,
@ -82,6 +99,7 @@ public final class HnswGraphBuilder {
      throws IOException {
    vectorValues = vectors.randomAccess();
    buildVectors = vectors.randomAccess();
    this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
    this.similarityFunction = Objects.requireNonNull(similarityFunction);
    if (M <= 0) {
      throw new IllegalArgumentException("maxConn must be positive");
@ -97,7 +115,8 @@ public final class HnswGraphBuilder {
    int levelOfFirstNode = getRandomGraphLevel(ml, random);
    this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
    this.graphSearcher =
        new HnswGraphSearcher(
        new HnswGraphSearcher<>(
            vectorEncoding,
            similarityFunction,
            new NeighborQueue(beamWidth, true),
            new FixedBitSet(vectorValues.size()));
@ -110,7 +129,7 @@ public final class HnswGraphBuilder {
   * enables efficient retrieval without extra data copying, while avoiding collision of the
   * returned values.
   *
   * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
   * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
   *     accessor for the vectors
   */
  public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
@ -121,15 +140,19 @@ public final class HnswGraphBuilder {
    if (infoStream.isEnabled(HNSW_COMPONENT)) {
      infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
    }
    addVectors(vectors);
    return hnsw;
  }

  private void addVectors(RandomAccessVectorValues vectors) throws IOException {
    long start = System.nanoTime(), t = start;
    // start at node 1! node 0 is added implicitly, in the constructor
    for (int node = 1; node < vectors.size(); node++) {
      addGraphNode(node, vectors.vectorValue(node));
      addGraphNode(node, vectors);
      if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
        t = printGraphBuildStatus(node, start, t);
      }
    }
    return hnsw;
  }

  /** Set info-stream to output debugging information * */
@ -142,7 +165,7 @@ public final class HnswGraphBuilder {
  }

  /** Inserts a doc with vector value to the graph */
  public void addGraphNode(int node, float[] value) throws IOException {
  public void addGraphNode(int node, T value) throws IOException {
    NeighborQueue candidates;
    final int nodeLevel = getRandomGraphLevel(ml, random);
    int curMaxLevel = hnsw.numLevels() - 1;
@ -167,6 +190,18 @@ public final class HnswGraphBuilder {
    }
  }

  public void addGraphNode(int node, RandomAccessVectorValues values) throws IOException {
    addGraphNode(node, getValue(node, values));
  }

  @SuppressWarnings("unchecked")
  private T getValue(int node, RandomAccessVectorValues values) throws IOException {
    return switch (vectorEncoding) {
      case BYTE -> (T) values.binaryValue(node);
      case FLOAT32 -> (T) values.vectorValue(node);
    };
  }

  private long printGraphBuildStatus(int node, long start, long t) {
    long now = System.nanoTime();
    infoStream.message(
@ -215,7 +250,7 @@ public final class HnswGraphBuilder {
      int cNode = candidates.node[i];
      float cScore = candidates.score[i];
      assert cNode < hnsw.size();
      if (diversityCheck(vectorValues.vectorValue(cNode), cScore, neighbors, buildVectors)) {
      if (diversityCheck(cNode, cScore, neighbors)) {
        neighbors.add(cNode, cScore);
      }
    }
@ -237,19 +272,38 @@ public final class HnswGraphBuilder {
   * @param score the score of the new candidate and node n, to be compared with scores of the
   *     candidate and n's neighbors
   * @param neighbors the neighbors selected so far
   * @param vectorValues source of values used for making comparisons between candidate and existing
   *     neighbors
   * @return whether the candidate is diverse given the existing neighbors
   */
  private boolean diversityCheck(
      float[] candidate,
      float score,
      NeighborArray neighbors,
      RandomAccessVectorValues vectorValues)
  private boolean diversityCheck(int candidate, float score, NeighborArray neighbors)
      throws IOException {
    return isDiverse(candidate, neighbors, score);
  }

  private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
      throws IOException {
    return switch (vectorEncoding) {
      case BYTE -> isDiverse(vectorValues.binaryValue(candidate), neighbors, score);
      case FLOAT32 -> isDiverse(vectorValues.vectorValue(candidate), neighbors, score);
    };
  }

  private boolean isDiverse(float[] candidate, NeighborArray neighbors, float score)
      throws IOException {
    for (int i = 0; i < neighbors.size(); i++) {
      float neighborSimilarity =
          similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
          similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
      if (neighborSimilarity >= score) {
        return false;
      }
    }
    return true;
  }

  private boolean isDiverse(BytesRef candidate, NeighborArray neighbors, float score)
      throws IOException {
    for (int i = 0; i < neighbors.size(); i++) {
      float neighborSimilarity =
          similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
      if (neighborSimilarity >= score) {
        return false;
      }
@ -262,24 +316,52 @@ public final class HnswGraphBuilder {
   *     neighbours
   */
  private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
    float minAcceptedSimilarity;
    for (int i = neighbors.size() - 1; i > 0; i--) {
      int cNode = neighbors.node[i];
      float[] cVector = vectorValues.vectorValue(cNode);
      minAcceptedSimilarity = neighbors.score[i];
      // check the candidate against its better-scoring neighbors
      for (int j = i - 1; j >= 0; j--) {
        float neighborSimilarity =
            similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
        // node i is too similar to node j given its score relative to the base node
        if (neighborSimilarity >= minAcceptedSimilarity) {
          return i;
        }
      if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
        return i;
      }
    }
    return neighbors.size() - 1;
  }

  private boolean isWorstNonDiverse(
      int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
    return switch (vectorEncoding) {
      case BYTE -> isWorstNonDiverse(
          candidate, vectorValues.binaryValue(candidate), neighbors, minAcceptedSimilarity);
      case FLOAT32 -> isWorstNonDiverse(
          candidate, vectorValues.vectorValue(candidate), neighbors, minAcceptedSimilarity);
    };
  }

  private boolean isWorstNonDiverse(
      int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
      throws IOException {
    for (int i = candidateIndex - 1; i > -0; i--) {
      float neighborSimilarity =
          similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
      // node i is too similar to node j given its score relative to the base node
      if (neighborSimilarity >= minAcceptedSimilarity) {
        return false;
      }
    }
    return true;
  }

  private boolean isWorstNonDiverse(
      int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
      throws IOException {
    for (int i = candidateIndex - 1; i > -0; i--) {
      float neighborSimilarity =
          similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
      // node i is too similar to node j given its score relative to the base node
      if (neighborSimilarity >= minAcceptedSimilarity) {
        return false;
      }
    }
    return true;
  }

  private static int getRandomGraphLevel(double ml, SplittableRandom random) {
    double randDouble;
    do {
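A sketch of the new factory entry point (the constructor is now private): the signature is taken from the hunk above, while the values source and the M/beamWidth/seed values here are hypothetical.

import java.io.IOException;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;

public class ByteGraphExample {
  /** Builds an HNSW graph over 8-bit vectors; the producer is supplied by the caller. */
  static OnHeapHnswGraph buildByteGraph(RandomAccessVectorValuesProducer vectors)
      throws IOException {
    HnswGraphBuilder<?> builder =
        HnswGraphBuilder.create(
            vectors,
            VectorEncoding.BYTE,
            VectorSimilarityFunction.DOT_PRODUCT, // BYTE is documented for DOT_PRODUCT only
            16, // M: max connections per graph node
            100, // beamWidth: candidate queue size during construction
            42L); // fixed seed for repeatable construction
    return builder.build(vectors.randomAccess());
  }
}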
@ -18,21 +18,28 @@
package org.apache.lucene.util.hnsw;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;

import java.io.IOException;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.SparseFixedBitSet;

/**
 * Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the
 * search algorithm, see {@link HnswGraph}.
 *
 * @param <T> the type of query vector
 */
public final class HnswGraphSearcher {
public class HnswGraphSearcher<T> {
  private final VectorSimilarityFunction similarityFunction;
  private final VectorEncoding vectorEncoding;

  /**
   * Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive
   * to allocate, so they're cleared and reused across calls.
@ -49,7 +56,11 @@ public final class HnswGraphSearcher {
   * @param visited bit set that will track nodes that have already been visited
   */
  public HnswGraphSearcher(
      VectorSimilarityFunction similarityFunction, NeighborQueue candidates, BitSet visited) {
      VectorEncoding vectorEncoding,
      VectorSimilarityFunction similarityFunction,
      NeighborQueue candidates,
      BitSet visited) {
    this.vectorEncoding = vectorEncoding;
    this.similarityFunction = similarityFunction;
    this.candidates = candidates;
    this.visited = visited;
@ -73,13 +84,68 @@ public final class HnswGraphSearcher {
      float[] query,
      int topK,
      RandomAccessVectorValues vectors,
      VectorEncoding vectorEncoding,
      VectorSimilarityFunction similarityFunction,
      HnswGraph graph,
      Bits acceptOrds,
      int visitedLimit)
      throws IOException {
    HnswGraphSearcher graphSearcher =
        new HnswGraphSearcher(
    if (query.length != vectors.dimension()) {
      throw new IllegalArgumentException(
          "vector query dimension: "
              + query.length
              + " differs from field dimension: "
              + vectors.dimension());
    }
    if (vectorEncoding == VectorEncoding.BYTE) {
      return search(
          toBytesRef(query),
          topK,
          vectors,
          vectorEncoding,
          similarityFunction,
          graph,
          acceptOrds,
          visitedLimit);
    }
    HnswGraphSearcher<float[]> graphSearcher =
        new HnswGraphSearcher<>(
            vectorEncoding,
            similarityFunction,
            new NeighborQueue(topK, true),
            new SparseFixedBitSet(vectors.size()));
    NeighborQueue results;
    int[] eps = new int[] {graph.entryNode()};
    int numVisited = 0;
    for (int level = graph.numLevels() - 1; level >= 1; level--) {
      results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
      numVisited += results.visitedCount();
      visitedLimit -= results.visitedCount();
      if (results.incomplete()) {
        results.setVisitedCount(numVisited);
        return results;
      }
      eps[0] = results.pop();
    }
    results =
        graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
    results.setVisitedCount(results.visitedCount() + numVisited);
    return results;
  }

  private static NeighborQueue search(
      BytesRef query,
      int topK,
      RandomAccessVectorValues vectors,
      VectorEncoding vectorEncoding,
      VectorSimilarityFunction similarityFunction,
      HnswGraph graph,
      Bits acceptOrds,
      int visitedLimit)
      throws IOException {
    HnswGraphSearcher<BytesRef> graphSearcher =
        new HnswGraphSearcher<>(
            vectorEncoding,
            similarityFunction,
            new NeighborQueue(topK, true),
            new SparseFixedBitSet(vectors.size()));
@ -119,7 +185,8 @@ public final class HnswGraphSearcher {
   * @return a priority queue holding the closest neighbors found
   */
  public NeighborQueue searchLevel(
      float[] query,
      // Note: this is only public because Lucene91HnswGraphBuilder needs it
      T query,
      int topK,
      int level,
      final int[] eps,
@ -130,7 +197,7 @@ public final class HnswGraphSearcher {
  }

  private NeighborQueue searchLevel(
      float[] query,
      T query,
      int topK,
      int level,
      final int[] eps,
@ -150,7 +217,7 @@ public final class HnswGraphSearcher {
        results.markIncomplete();
        break;
      }
      float score = similarityFunction.compare(query, vectors.vectorValue(ep));
      float score = compare(query, vectors, ep);
      numVisited++;
      candidates.add(ep, score);
      if (acceptOrds == null || acceptOrds.get(ep)) {
@ -185,7 +252,7 @@ public final class HnswGraphSearcher {
        results.markIncomplete();
        break;
      }
      float friendSimilarity = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
      float friendSimilarity = compare(query, vectors, friendOrd);
      numVisited++;
      if (friendSimilarity >= minAcceptedSimilarity) {
        candidates.add(friendOrd, friendSimilarity);
@ -204,6 +271,14 @@ public final class HnswGraphSearcher {
    return results;
  }

  private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {
    if (vectorEncoding == VectorEncoding.BYTE) {
      return similarityFunction.compare((BytesRef) query, vectors.binaryValue(ord));
    } else {
      return similarityFunction.compare((float[]) query, vectors.vectorValue(ord));
    }
  }

  private void prepareScratchState(int capacity) {
    candidates.clear();
    if (visited.length() < capacity) {
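The float-to-byte handoff above is easy to miss: a float[] query against a BYTE-encoded field is converted per dimension by VectorUtil.toBytesRef before the graph is walked. A minimal illustration of that quantization (a standalone sketch; values outside [-128, 127] violate the encoding's contract and are not clamped):

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;

public class QuantizeQueryDemo {
  public static void main(String[] args) {
    float[] query = new float[] {12.7f, -3.2f, 100f};
    BytesRef quantized = VectorUtil.toBytesRef(query);
    // casting truncates toward zero: 12.7 -> 12, -3.2 -> -3, 100.0 -> 100
    for (int i = 0; i < quantized.length; i++) {
      System.out.println(query[i] + " -> " + quantized.bytes[quantized.offset + i]);
    }
  }
}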
@ -178,7 +178,7 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
    return new KnnVectorsWriter() {

      @Override
      public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
      public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
        fieldsWritten.add(fieldInfo.name);
        return writer.addField(fieldInfo);
      }
@ -112,6 +112,7 @@ public class TestCodecs extends LuceneTestCase {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.EUCLIDEAN,
            false));
  }
@ -260,6 +260,7 @@ public class TestFieldInfos extends LuceneTestCase {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.EUCLIDEAN,
            false));
  }
@ -279,6 +280,7 @@ public class TestFieldInfos extends LuceneTestCase {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.EUCLIDEAN,
            false));
    assertEquals("Field numbers 0 through 9 were allocated", 10, idx);
@ -300,6 +302,7 @@ public class TestFieldInfos extends LuceneTestCase {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.EUCLIDEAN,
            false));
    assertEquals("Field numbers should reset after clear()", 0, idx);
@ -64,6 +64,7 @@ public class TestFieldsReader extends LuceneTestCase {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.EUCLIDEAN,
            field.name().equals(softDeletesFieldName)));
  }
@ -113,6 +113,11 @@ public class TestIndexableField extends LuceneTestCase {
      return 0;
    }

    @Override
    public VectorEncoding vectorEncoding() {
      return VectorEncoding.FLOAT32;
    }

    @Override
    public VectorSimilarityFunction vectorSimilarityFunction() {
      return VectorSimilarityFunction.EUCLIDEAN;
@ -67,6 +67,8 @@ public class TestKnnGraph extends LuceneTestCase {
  private static int M = Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN;

  private Codec codec;
  private Codec float32Codec;
  private VectorEncoding vectorEncoding;
  private VectorSimilarityFunction similarityFunction;

  @Before
@ -86,6 +88,31 @@ public class TestKnnGraph extends LuceneTestCase {

    int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
    similarityFunction = VectorSimilarityFunction.values()[similarity];
    vectorEncoding = randomVectorEncoding();

    codec =
        new Lucene94Codec() {
          @Override
          public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
            return new Lucene94HnswVectorsFormat(M, Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
          }
        };

    if (vectorEncoding == VectorEncoding.FLOAT32) {
      float32Codec = codec;
    } else {
      float32Codec =
          new Lucene94Codec() {
            @Override
            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
              return new Lucene94HnswVectorsFormat(M, Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
            }
          };
    }
  }

  private VectorEncoding randomVectorEncoding() {
    return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
  }

  @After
@ -102,10 +129,7 @@ public class TestKnnGraph extends LuceneTestCase {
      float[][] values = new float[numDoc][];
      for (int i = 0; i < numDoc; i++) {
        if (random().nextBoolean()) {
          values[i] = new float[dimension];
          for (int j = 0; j < dimension; j++) {
            values[i][j] = random().nextFloat();
          }
          values[i] = randomVector(dimension);
        }
        add(iw, i, values[i]);
      }
@ -117,6 +141,14 @@ public class TestKnnGraph extends LuceneTestCase {
    try (Directory dir = newDirectory();
        IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) {
      float[][] values = new float[][] {new float[] {0, 1, 2}};
      if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
        VectorUtil.l2normalize(values[0]);
      }
      if (vectorEncoding == VectorEncoding.BYTE) {
        for (int i = 0; i < 3; i++) {
          values[0][i] = (float) Math.floor(values[0][i] * 127);
        }
      }
      add(iw, 0, values[0]);
      assertConsistentGraph(iw, values);
      iw.commit();
@ -133,11 +165,7 @@ public class TestKnnGraph extends LuceneTestCase {
      float[][] values = randomVectors(numDoc, dimension);
      for (int i = 0; i < numDoc; i++) {
        if (random().nextBoolean()) {
          values[i] = new float[dimension];
          for (int j = 0; j < dimension; j++) {
            values[i][j] = random().nextFloat();
          }
          VectorUtil.l2normalize(values[i]);
          values[i] = randomVector(dimension);
        }
        add(iw, i, values[i]);
        if (random().nextInt(10) == 3) {
@ -249,16 +277,26 @@ public class TestKnnGraph extends LuceneTestCase {
    float[][] values = new float[numDoc][];
    for (int i = 0; i < numDoc; i++) {
      if (random().nextBoolean()) {
        values[i] = new float[dimension];
        for (int j = 0; j < dimension; j++) {
          values[i][j] = random().nextFloat();
        }
        VectorUtil.l2normalize(values[i]);
        values[i] = randomVector(dimension);
      }
    }
    return values;
  }

  private float[] randomVector(int dimension) {
    float[] value = new float[dimension];
    for (int j = 0; j < dimension; j++) {
      value[j] = random().nextFloat();
    }
    VectorUtil.l2normalize(value);
    if (vectorEncoding == VectorEncoding.BYTE) {
      for (int j = 0; j < dimension; j++) {
        value[j] = (byte) (value[j] * 127);
      }
    }
    return value;
  }

  int[][][] copyGraph(HnswGraph graphValues) throws IOException {
    int[][][] graph = new int[graphValues.numLevels()][][];
    int size = graphValues.size();
@ -285,7 +323,7 @@ public class TestKnnGraph extends LuceneTestCase {
    // We can't use dot product here since the vectors are laid out on a grid, not a sphere.
    similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
    IndexWriterConfig config = newIndexWriterConfig();
    config.setCodec(codec); // test is not compatible with simpletext
    config.setCodec(float32Codec);
    try (Directory dir = newDirectory();
        IndexWriter iw = new IndexWriter(dir, config)) {
      indexData(iw);
@ -341,7 +379,7 @@ public class TestKnnGraph extends LuceneTestCase {
  public void testMultiThreadedSearch() throws Exception {
    similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
    IndexWriterConfig config = newIndexWriterConfig();
    config.setCodec(codec);
    config.setCodec(float32Codec);
    Directory dir = newDirectory();
    IndexWriter iw = new IndexWriter(dir, config);
    indexData(iw);
@ -468,7 +506,7 @@ public class TestKnnGraph extends LuceneTestCase {
            "vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
            values[id],
            scratch,
            0f);
            0);
        numDocsWithVectors++;
      }
      // if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc()
@ -196,6 +196,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.EUCLIDEAN,
            true);
    List<Integer> docsDeleted = Arrays.asList(1, 3, 7, 8, DocIdSetIterator.NO_MORE_DOCS);
@ -233,6 +234,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.EUCLIDEAN,
            true);
    for (DocValuesFieldUpdates update : updates) {
@ -295,6 +297,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.EUCLIDEAN,
            true);
    List<Integer> docsDeleted = Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS);
@ -362,6 +365,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.EUCLIDEAN,
            true);
    List<DocValuesFieldUpdates> updates =
@ -398,6 +402,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.EUCLIDEAN,
            true);
    updates = Arrays.asList(singleUpdate(Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS), 3, true));
@ -25,6 +25,7 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.document.Document;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
@ -117,6 +118,12 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
      return null;
    }

    @Override
    public TopDocs searchNearestVectorsExhaustively(
        String field, float[] target, int k, DocIdSetIterator acceptDocs) {
      return null;
    }

    @Override
    protected void doClose() {}
@ -19,6 +19,7 @@ package org.apache.lucene.search;
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.TestVectorUtil.randomVector;

@ -40,7 +41,7 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
@ -48,6 +49,7 @@ import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil;

@ -174,7 +176,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
      KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
      IllegalArgumentException e =
          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
      assertEquals("vector dimensions differ: 1!=2", e.getMessage());
      assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
    }
  }

@ -239,43 +241,38 @@ public class TestKnnVectorQuery extends LuceneTestCase {
  }

  public void testScoreEuclidean() throws IOException {
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 0; j < 5; j++) {
          Document doc = new Document();
          doc.add(
              new KnnVectorField("field", new float[] {j, j}, VectorSimilarityFunction.EUCLIDEAN));
          w.addDocument(doc);
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        assertEquals(1, reader.leaves().size());
        IndexSearcher searcher = new IndexSearcher(reader);
        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
        Query rewritten = query.rewrite(reader);
        Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
        Scorer scorer = weight.scorer(reader.leaves().get(0));
    float[][] vectors = new float[5][];
    for (int j = 0; j < 5; j++) {
      vectors[j] = new float[] {j, j};
    }
    try (Directory d = getIndexStore("field", vectors);
        IndexReader reader = DirectoryReader.open(d)) {
      assertEquals(1, reader.leaves().size());
      IndexSearcher searcher = new IndexSearcher(reader);
      KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
      Query rewritten = query.rewrite(reader);
      Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
      Scorer scorer = weight.scorer(reader.leaves().get(0));

        // prior to advancing, score is 0
        assertEquals(-1, scorer.docID());
        expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
      // prior to advancing, score is 0
      assertEquals(-1, scorer.docID());
      expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);

        // test getMaxScore
        assertEquals(0, scorer.getMaxScore(-1), 0);
        assertEquals(0, scorer.getMaxScore(0), 0);
        // This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
        assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
        assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
      // test getMaxScore
      assertEquals(0, scorer.getMaxScore(-1), 0);
      assertEquals(0, scorer.getMaxScore(0), 0);
      // This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
      assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
      assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);

        DocIdSetIterator it = scorer.iterator();
        assertEquals(3, it.cost());
        assertEquals(1, it.nextDoc());
        assertEquals(1 / 6f, scorer.score(), 0);
        assertEquals(3, it.advance(3));
        assertEquals(1 / 2f, scorer.score(), 0);
        assertEquals(NO_MORE_DOCS, it.advance(4));
        expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
      }
      DocIdSetIterator it = scorer.iterator();
      assertEquals(3, it.cost());
      assertEquals(1, it.nextDoc());
      assertEquals(1 / 6f, scorer.score(), 0);
      assertEquals(3, it.advance(3));
      assertEquals(1 / 2f, scorer.score(), 0);
      assertEquals(NO_MORE_DOCS, it.advance(4));
      expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
    }
  }

@ -764,9 +761,18 @@ public class TestKnnVectorQuery extends LuceneTestCase {
  private Directory getIndexStore(String field, float[]... contents) throws IOException {
    Directory indexStore = newDirectory();
    RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
    VectorEncoding encoding = randomVectorEncoding();
    for (int i = 0; i < contents.length; ++i) {
      Document doc = new Document();
      doc.add(new KnnVectorField(field, contents[i]));
      if (encoding == VectorEncoding.BYTE) {
        BytesRef v = new BytesRef(new byte[contents[i].length]);
        for (int j = 0; j < v.length; j++) {
          v.bytes[j] = (byte) contents[i][j];
        }
        doc.add(new KnnVectorField(field, v, EUCLIDEAN));
      } else {
        doc.add(new KnnVectorField(field, contents[i]));
      }
      doc.add(new StringField("id", "id" + i, Field.Store.YES));
      writer.addDocument(doc);
    }
@ -908,4 +914,8 @@ public class TestKnnVectorQuery extends LuceneTestCase {
      return 31 * classHash() + docs.hashCode();
    }
  }

  private VectorEncoding randomVectorEncoding() {
    return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
  }
}
@ -50,6 +50,7 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.SortField.Type;
import org.apache.lucene.store.Directory;
@ -1126,6 +1127,7 @@ public class TestSortOptimization extends LuceneTestCase {
            0,
            0,
            0,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.DOT_PRODUCT,
            fi.isSoftDeletesField());
    newInfos[i] = noIndexFI;
@ -18,6 +18,7 @@ package org.apache.lucene.util;

import java.util.Random;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;

public class TestVectorUtil extends LuceneTestCase {

@ -130,6 +131,23 @@ public class TestVectorUtil extends LuceneTestCase {
    return u;
  }

  private static BytesRef negative(BytesRef v) {
    BytesRef u = new BytesRef(new byte[v.length]);
    for (int i = 0; i < v.length; i++) {
      // what is (byte) -(-128)? 127?
      u.bytes[i] = (byte) -v.bytes[i];
    }
    return u;
  }

  private static float l2(BytesRef v) {
    float l2 = 0;
    for (int i = v.offset; i < v.offset + v.length; i++) {
      l2 += v.bytes[i] * v.bytes[i];
    }
    return l2;
  }

  private static float[] randomVector() {
    return randomVector(random().nextInt(100) + 1);
  }
@ -142,4 +160,88 @@ public class TestVectorUtil extends LuceneTestCase {
    }
    return v;
  }

  private static BytesRef randomVectorBytes() {
    BytesRef v = TestUtil.randomBinaryTerm(random(), TestUtil.nextInt(random(), 1, 100));
    // clip at -127 to avoid overflow
    for (int i = v.offset; i < v.offset + v.length; i++) {
      if (v.bytes[i] == -128) {
        v.bytes[i] = -127;
      }
    }
    return v;
  }

  public void testBasicDotProductBytes() {
    BytesRef a = new BytesRef(new byte[] {1, 2, 3});
    BytesRef b = new BytesRef(new byte[] {-10, 0, 5});
    assertEquals(5, VectorUtil.dotProduct(a, b), 0);
    assertEquals(5 / (3f * (1 << 15)), VectorUtil.dotProductScore(a, b), DELTA);
  }

  public void testSelfDotProductBytes() {
    // the dot product of a vector with itself is equal to the sum of the squares of its components
    BytesRef v = randomVectorBytes();
    assertEquals(l2(v), VectorUtil.dotProduct(v, v), DELTA);
  }

  public void testOrthogonalDotProductBytes() {
    // the dot product of two perpendicular vectors is 0
    byte[] v = new byte[4];
    v[0] = (byte) random().nextInt(100);
    v[1] = (byte) random().nextInt(100);
    v[2] = v[1];
    v[3] = (byte) -v[0];
    // also test computing using BytesRef with nonzero offset
    assertEquals(0, VectorUtil.dotProduct(new BytesRef(v, 0, 2), new BytesRef(v, 2, 2)), DELTA);
  }

  public void testSelfSquareDistanceBytes() {
    // the l2 distance of a vector with itself is zero
    BytesRef v = randomVectorBytes();
    assertEquals(0, VectorUtil.squareDistance(v, v), DELTA);
  }

  public void testBasicSquareDistanceBytes() {
    assertEquals(
        12,
        VectorUtil.squareDistance(
            new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-1, 0, 5})),
        0);
  }

  public void testRandomSquareDistanceBytes() {
    // the square distance of a vector with its inverse is equal to four times the sum of squares of
    // its components
    BytesRef v = randomVectorBytes();
    BytesRef u = negative(v);
    assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA);
  }

  public void testBasicCosineBytes() {
    assertEquals(
        0.11952f,
        VectorUtil.cosine(new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-10, 0, 5})),
        DELTA);
  }

  public void testSelfCosineBytes() {
    // the cosine of a vector with itself is always equal to 1
    BytesRef v = randomVectorBytes();
    // ensure the vector is non-zero so that cosine is defined
    v.bytes[0] = (byte) (random().nextInt(126) + 1);
    assertEquals(1.0f, VectorUtil.cosine(v, v), DELTA);
  }

  public void testOrthogonalCosineBytes() {
    // the cosine of two perpendicular vectors is 0
    float[] v = new float[2];
    v[0] = random().nextInt(100);
    // ensure the vector is non-zero so that cosine is defined
    v[1] = random().nextInt(1, 100);
    float[] u = new float[2];
    u[0] = v[1];
    u[1] = -v[0];
    assertEquals(0, VectorUtil.cosine(u, v), DELTA);
  }
}
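An aside on the question posed in negative() above: in two's complement, negating Byte.MIN_VALUE overflows back to itself rather than yielding 127, which is why randomVectorBytes clips -128 to -127. A standalone snippet (not part of the patch) demonstrating this:

public class ByteNegationDemo {
  public static void main(String[] args) {
    byte b = -128;
    // unary minus promotes to int (128), then the byte cast wraps back to -128
    System.out.println((byte) -b); // prints -128, not 127
  }
}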
@ -56,6 +56,7 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
@ -101,6 +102,7 @@ public class KnnGraphTester {
  private int beamWidth;
  private int maxConn;
  private VectorSimilarityFunction similarityFunction;
  private VectorEncoding vectorEncoding;
  private FixedBitSet matchDocs;
  private float selectivity;
  private boolean prefilter;
@ -113,6 +115,7 @@ public class KnnGraphTester {
    topK = 100;
    fanout = topK;
    similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
    vectorEncoding = VectorEncoding.FLOAT32;
    selectivity = 1f;
    prefilter = false;
  }
@ -195,12 +198,30 @@ public class KnnGraphTester {
        case "-docs":
          docVectorsPath = Paths.get(args[++iarg]);
          break;
        case "-encoding":
          String encoding = args[++iarg];
          switch (encoding) {
            case "byte":
              vectorEncoding = VectorEncoding.BYTE;
              break;
            case "float32":
              vectorEncoding = VectorEncoding.FLOAT32;
              break;
            default:
              throw new IllegalArgumentException("-encoding can be 'byte' or 'float32' only");
          }
          break;
        case "-metric":
          String metric = args[++iarg];
          if (metric.equals("euclidean")) {
            similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
          } else if (metric.equals("angular") == false) {
            throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
          switch (metric) {
            case "euclidean":
              similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
              break;
            case "angular":
              similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
              break;
            default:
              throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
          }
          break;
        case "-forceMerge":
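The new -encoding flag threads the chosen VectorEncoding through indexing, search, and the brute-force baseline. A hypothetical invocation (file name and classpath omitted; only the flags shown here appear in the argument parser above):

java KnnGraphTester -reindex -docs docs.bin -encoding byte -metric angular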
@ -229,7 +250,7 @@ public class KnnGraphTester {
|
||||
if (operation == null && reindex == false) {
|
||||
usage();
|
||||
}
|
||||
if (prefilter == true && selectivity == 1f) {
|
||||
if (prefilter && selectivity == 1f) {
|
||||
throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
|
||||
}
|
||||
indexPath = Paths.get(formatIndexPath(docVectorsPath));
|
||||
@ -248,7 +269,9 @@ public class KnnGraphTester {
|
||||
if (docVectorsPath == null) {
|
||||
throw new IllegalArgumentException("missing -docs arg");
|
||||
}
|
||||
matchDocs = generateRandomBitSet(numDocs, selectivity);
|
||||
if (selectivity < 1) {
|
||||
matchDocs = generateRandomBitSet(numDocs, selectivity);
|
||||
}
|
||||
if (outputPath != null) {
|
||||
testSearch(indexPath, queryPath, outputPath, null);
|
||||
} else {
|
||||
@ -285,14 +308,17 @@ public class KnnGraphTester {
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private void dumpGraph(Path docsPath) throws IOException {
|
||||
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
|
||||
RandomAccessVectorValues values = vectors.randomAccess();
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, 0);
|
||||
HnswGraphBuilder<float[]> builder =
|
||||
(HnswGraphBuilder<float[]>)
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, maxConn, beamWidth, 0);
|
||||
// start at node 1
|
||||
for (int i = 1; i < numDocs; i++) {
|
||||
builder.addGraphNode(i, values.vectorValue(i));
|
||||
builder.addGraphNode(i, values);
|
||||
System.out.println("\nITERATION " + i);
|
||||
dumpGraph(builder.hnsw);
|
||||
}
|
||||
@ -375,13 +401,8 @@ public class KnnGraphTester {
|
||||
throws IOException {
|
||||
TopDocs[] results = new TopDocs[numIters];
|
||||
long elapsed, totalCpuTime, totalVisited = 0;
|
||||
try (FileChannel q = FileChannel.open(queryPath)) {
|
||||
int bufferSize = numIters * dim * Float.BYTES;
|
||||
FloatBuffer targets =
|
||||
q.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize)
|
||||
.order(ByteOrder.LITTLE_ENDIAN)
|
||||
.asFloatBuffer();
|
||||
float[] target = new float[dim];
|
||||
try (FileChannel input = FileChannel.open(queryPath)) {
|
||||
VectorReader targetReader = VectorReader.create(input, dim, vectorEncoding, numIters);
|
||||
if (quiet == false) {
|
||||
System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
|
||||
}
|
||||
@ -392,21 +413,21 @@ public class KnnGraphTester {
|
||||
DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
numDocs = reader.maxDoc();
|
||||
Query bitSetQuery = new BitSetQuery(matchDocs);
|
||||
Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null;
|
||||
for (int i = 0; i < numIters; i++) {
|
||||
// warm up
|
||||
targets.get(target);
|
||||
float[] target = targetReader.next();
|
||||
if (prefilter) {
|
||||
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
|
||||
} else {
|
||||
doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
|
||||
}
|
||||
}
|
||||
targets.position(0);
|
||||
targetReader.reset();
|
||||
start = System.nanoTime();
|
||||
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
|
||||
for (int i = 0; i < numIters; i++) {
|
||||
targets.get(target);
|
||||
float[] target = targetReader.next();
|
||||
if (prefilter) {
|
||||
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
|
||||
} else {
|
||||
@ -414,10 +435,12 @@ public class KnnGraphTester {
|
||||
doKnnVectorQuery(
|
||||
searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
|
||||
|
||||
results[i].scoreDocs =
|
||||
Arrays.stream(results[i].scoreDocs)
|
||||
.filter(scoreDoc -> matchDocs == null || matchDocs.get(scoreDoc.doc))
|
||||
.toArray(ScoreDoc[]::new);
|
||||
if (matchDocs != null) {
|
||||
results[i].scoreDocs =
|
||||
Arrays.stream(results[i].scoreDocs)
|
||||
.filter(scoreDoc -> matchDocs.get(scoreDoc.doc))
|
||||
.toArray(ScoreDoc[]::new);
|
||||
}
|
||||
}
|
||||
}
|
||||
totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
|
||||
@ -425,7 +448,14 @@ public class KnnGraphTester {
|
||||
for (int i = 0; i < numIters; i++) {
|
||||
totalVisited += results[i].totalHits.value;
|
||||
for (ScoreDoc doc : results[i].scoreDocs) {
|
||||
doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
|
||||
if (doc.doc != NO_MORE_DOCS) {
|
||||
// there is a bug somewhere that can result in doc=NO_MORE_DOCS! I think it happens
|
||||
// in some degenerate case (like input query has NaN in it?) that causes no results to
|
||||
// be returned from HNSW search?
|
||||
doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
|
||||
} else {
|
||||
System.out.println("NO_MORE_DOCS!");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -477,6 +507,78 @@ public class KnnGraphTester {
|
||||
}
}

private abstract static class VectorReader {
final float[] target;
final ByteBuffer bytes;

static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int n)
throws IOException {
int bufferSize = n * dim * vectorEncoding.byteSize;
return switch (vectorEncoding) {
case BYTE -> new VectorReaderByte(input, dim, bufferSize);
case FLOAT32 -> new VectorReaderFloat32(input, dim, bufferSize);
};
}

VectorReader(FileChannel input, int dim, int bufferSize) throws IOException {
bytes =
input.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize).order(ByteOrder.LITTLE_ENDIAN);
target = new float[dim];
}

void reset() {
bytes.position(0);
}

abstract float[] next();
}

private static class VectorReaderFloat32 extends VectorReader {
private final FloatBuffer floats;

VectorReaderFloat32(FileChannel input, int dim, int bufferSize) throws IOException {
super(input, dim, bufferSize);
floats = bytes.asFloatBuffer();
}

@Override
void reset() {
super.reset();
floats.position(0);
}

@Override
float[] next() {
floats.get(target);
return target;
}
}

private static class VectorReaderByte extends VectorReader {
private byte[] scratch;
private BytesRef bytesRef;

VectorReaderByte(FileChannel input, int dim, int bufferSize) throws IOException {
super(input, dim, bufferSize);
scratch = new byte[dim];
bytesRef = new BytesRef(scratch);
}

@Override
float[] next() {
bytes.get(scratch);
for (int i = 0; i < scratch.length; i++) {
target[i] = scratch[i];
}
return target;
}

BytesRef nextBytes() {
bytes.get(scratch);
return bytesRef;
}
}
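Note on the byte path: VectorReaderByte assumes the vector file already stores one signed byte per dimension, i.e. that vectors were quantized before being written. A minimal sketch of that assumed pre-quantization step (quantize is an illustrative helper, not part of this change), for values in [-1, 1]:

  // Scale each unit-range float into a signed byte, matching what
  // VectorReaderByte.next() later widens back to float.
  static byte[] quantize(float[] v) {
    byte[] b = new byte[v.length];
    for (int i = 0; i < v.length; i++) {
      b[i] = (byte) (v[i] * 127);
    }
    return b;
  }

next() widens each byte back to float so float-based similarity code still works, while nextBytes() hands the raw bytes to the index without an extra conversion.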
private static TopDocs doKnnVectorQuery(
IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
throws IOException {

@ -529,7 +631,9 @@ public class KnnGraphTester {
if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) {
return readNN(nnPath);
} else {
int[][] nn = computeNN(docPath, queryPath);
// TODO: enable computing NN from high precision vectors when
// checking low-precision recall
int[][] nn = computeNN(docPath, queryPath, vectorEncoding);
if (selectivity == 1f) {
writeNN(nn, nnPath);
}
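// Note: exact neighbors are only cached to nnPath when selectivity == 1f,
// since a filtered run computes ground truth against the random matchDocs
// set, which differs from run to run.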
@ -589,52 +693,37 @@ public class KnnGraphTester {
return bitSet;
}

private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
private int[][] computeNN(Path docPath, Path queryPath, VectorEncoding encoding)
throws IOException {
int[][] result = new int[numIters][];
if (quiet == false) {
System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
}
try (FileChannel in = FileChannel.open(docPath);
FileChannel qIn = FileChannel.open(queryPath)) {
FloatBuffer queries =
qIn.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
float[] vector = new float[dim];
float[] query = new float[dim];
VectorReader docReader = VectorReader.create(in, dim, encoding, numDocs);
VectorReader queryReader = VectorReader.create(qIn, dim, encoding, numIters);
for (int i = 0; i < numIters; i++) {
queries.get(query);
long totalBytes = (long) numDocs * dim * Float.BYTES;
final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
int offset = 0;
int j = 0;
// System.out.println("totalBytes=" + totalBytes);
while (j < numDocs) {
int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
FloatBuffer vectors =
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
offset += blockSize;
NeighborQueue queue = new NeighborQueue(topK, false);
for (; j < numDocs && vectors.hasRemaining(); j++) {
vectors.get(vector);
float d = similarityFunction.compare(query, vector);
if (matchDocs == null || matchDocs.get(j)) {
queue.insertWithOverflow(j, d);
}
}
result[i] = new int[topK];
for (int k = topK - 1; k >= 0; k--) {
result[i][k] = queue.topNode();
queue.pop();
// System.out.print(" " + n);
}
if (quiet == false && (i + 1) % 10 == 0) {
System.out.print(" " + (i + 1));
System.out.flush();
float[] query = queryReader.next();
NeighborQueue queue = new NeighborQueue(topK, false);
for (int j = 0; j < numDocs; j++) {
float[] doc = docReader.next();
float d = similarityFunction.compare(query, doc);
if (matchDocs == null || matchDocs.get(j)) {
queue.insertWithOverflow(j, d);
}
}
docReader.reset();
result[i] = new int[topK];
for (int k = topK - 1; k >= 0; k--) {
result[i][k] = queue.topNode();
queue.pop();
// System.out.print(" " + n);
}
if (quiet == false && (i + 1) % 10 == 0) {
System.out.print(" " + (i + 1));
System.out.flush();
}
}
}
return result;
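// Note: NeighborQueue(topK, false) behaves as a bounded min-heap on score, so
// topNode() is always the weakest retained hit; draining it into result[i]
// from index topK - 1 down to 0 yields a best-first row without extra sorting.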
@ -651,37 +740,29 @@ public class KnnGraphTester {
});
// iwc.setMergePolicy(NoMergePolicy.INSTANCE);
iwc.setRAMBufferSizeMB(1994d);
iwc.setUseCompoundFile(false);
// iwc.setMaxBufferedDocs(10000);

FieldType fieldType = KnnVectorField.createFieldType(dim, similarityFunction);
FieldType fieldType = KnnVectorField.createFieldType(dim, vectorEncoding, similarityFunction);
if (quiet == false) {
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
System.out.println("creating index in " + indexPath);
}
long start = System.nanoTime();
long totalBytes = (long) numDocs * dim * Float.BYTES, offset = 0;
final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);

try (FSDirectory dir = FSDirectory.open(indexPath);
IndexWriter iw = new IndexWriter(dir, iwc)) {
float[] vector = new float[dim];
try (FileChannel in = FileChannel.open(docsPath)) {
int i = 0;
while (i < numDocs) {
int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
FloatBuffer vectors =
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
offset += blockSize;
for (; vectors.hasRemaining() && i < numDocs; i++) {
vectors.get(vector);
Document doc = new Document();
// System.out.println("vector=" + vector[0] + "," + vector[1] + "...");
doc.add(new KnnVectorField(KNN_FIELD, vector, fieldType));
doc.add(new StoredField(ID_FIELD, i));
iw.addDocument(doc);
VectorReader vectorReader = VectorReader.create(in, dim, vectorEncoding, numDocs);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
switch (vectorEncoding) {
case BYTE -> doc.add(
new KnnVectorField(
KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
}
doc.add(new StoredField(ID_FIELD, i));
iw.addDocument(doc);
}
if (quiet == false) {
System.out.println("Done indexing " + numDocs + " documents; now flush");
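// Note: on the BYTE path the tester passes the raw BytesRef from nextBytes()
// straight to KnnVectorField, skipping the byte-to-float widening done by
// next(), so no per-vector conversion happens while indexing.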
@ -31,6 +31,7 @@ class MockVectorValues extends VectorValues
protected final float[][] denseValues;
protected final float[][] values;
private final int numVectors;
private final BytesRef binaryValue;

private int pos = -1;

@ -47,6 +48,9 @@ class MockVectorValues extends VectorValues
}
numVectors = count;
scratch = new float[dimension];
// used by tests that build a graph from bytes rather than floats
binaryValue = new BytesRef(dimension);
binaryValue.length = dimension;
}

public MockVectorValues copy() {

@ -89,7 +93,11 @@ class MockVectorValues extends VectorValues

@Override
public BytesRef binaryValue(int targetOrd) {
return null;
float[] value = vectorValue(targetOrd);
for (int i = 0; i < value.length; i++) {
binaryValue.bytes[i] = (byte) value[i];
}
return binaryValue;
}

private boolean seek(int target) {
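// Note: binaryValue() casts each stored float directly to a byte with no
// scaling, so it only round-trips cleanly when a test has already stored
// byte-range values (as testDiversity does below by scaling unit vectors by 127).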
@ -18,6 +18,7 @@
package org.apache.lucene.util.hnsw;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;

import java.io.IOException;
import java.util.ArrayList;

@ -43,6 +44,7 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.IndexSearcher;

@ -60,25 +62,38 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.junit.Before;

/** Tests HNSW KNN graphs */
public class TestHnswGraph extends LuceneTestCase {

VectorSimilarityFunction similarityFunction;
VectorEncoding vectorEncoding;

@Before
public void setup() {
similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
vectorEncoding =
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length - 1) + 1];
} else {
vectorEncoding = VectorEncoding.FLOAT32;
}
}
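// Note: the encoding is only randomized when the similarity is DOT_PRODUCT;
// for every other similarity function this setup pins FLOAT32, so the BYTE
// code paths in these tests are exercised against dot-product scoring only.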
// test writing out and reading in a graph gives the expected graph
public void testReadWrite() throws IOException {
int dim = random().nextInt(100) + 1;
int nDoc = random().nextInt(100) + 1;
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();

int M = random().nextInt(10) + 5;
int M = random().nextInt(4) + 2;
int beamWidth = random().nextInt(10) + 5;
long seed = random().nextLong();
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, M, beamWidth, seed);
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, vectorEncoding, random());
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors);

// Recreate the graph while indexing with the same random seed and write it out

@ -131,6 +146,10 @@ public class TestHnswGraph extends LuceneTestCase {
}
}

private VectorEncoding randomVectorEncoding() {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
}
// test that sorted index returns the same search results as unsorted
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
int dim = random().nextInt(10) + 3;

@ -250,24 +269,27 @@ public class TestHnswGraph extends LuceneTestCase {
// oriented in the right directions
public void testAknnDiverse() throws IOException {
int nDoc = 100;
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 10, 100, random().nextInt());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// run some searches
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
getTargetVector(),
10,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
vectorEncoding,
similarityFunction,
hnsw,
null,
Integer.MAX_VALUE);

int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0;
for (int node : nodes) {
sum += node;

@ -289,23 +311,26 @@ public class TestHnswGraph extends LuceneTestCase {
public void testSearchWithAcceptOrds() throws IOException {
int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
vectorEncoding = randomVectorEncoding();
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
getTargetVector(),
10,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0;
for (int node : nodes) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));

@ -319,9 +344,11 @@ public class TestHnswGraph extends LuceneTestCase {
public void testSearchWithSelectiveAcceptOrds() throws IOException {
int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// Only mark a few vectors as accepted
BitSet acceptOrds = new FixedBitSet(vectors.size);

@ -333,10 +360,11 @@ public class TestHnswGraph extends LuceneTestCase {
int numAccepted = acceptOrds.cardinality();
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
getTargetVector(),
numAccepted,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);

@ -347,12 +375,17 @@ public class TestHnswGraph extends LuceneTestCase {
}
}

private float[] getTargetVector() {
return new float[] {1, 0};
}

public void testSearchWithSkewedAcceptOrds() throws IOException {
int nDoc = 1000;
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.EUCLIDEAN, 16, 100, random().nextInt());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, VectorEncoding.FLOAT32, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);

// Skip over half of the documents that are closest to the query vector

@ -362,15 +395,16 @@ public class TestHnswGraph extends LuceneTestCase {
}
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
getTargetVector(),
10,
vectors.randomAccess(),
VectorSimilarityFunction.EUCLIDEAN,
VectorEncoding.FLOAT32,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0;
for (int node : nodes) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));

@ -383,20 +417,23 @@ public class TestHnswGraph extends LuceneTestCase {

public void testVisitedLimit() throws IOException {
int nDoc = 500;
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);

int topK = 50;
int visitedLimit = topK + random().nextInt(5);
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
getTargetVector(),
topK,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
vectorEncoding,
similarityFunction,
hnsw,
createRandomAcceptOrds(0, vectors.size),
visitedLimit);

@ -406,54 +443,68 @@ public class TestHnswGraph extends LuceneTestCase {
}

public void testHnswGraphBuilderInvalid() {
expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, null, 0, 0, 0));
expectThrows(
NullPointerException.class, () -> HnswGraphBuilder.create(null, null, null, 0, 0, 0));
// M must be > 0
expectThrows(
IllegalArgumentException.class,
() ->
new HnswGraphBuilder(
HnswGraphBuilder.create(
new RandomVectorValues(1, 1, random()),
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
0,
10,
0));
// beamWidth must be > 0
expectThrows(
IllegalArgumentException.class,
() ->
new HnswGraphBuilder(
HnswGraphBuilder.create(
new RandomVectorValues(1, 1, random()),
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
10,
0,
0));
}

@SuppressWarnings("unchecked")
public void testDiversity() throws IOException {
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
// Some carefully checked test cases with simple 2d vectors on the unit circle:
MockVectorValues vectors =
new MockVectorValues(
new float[][] {
unitVector2d(0.5),
unitVector2d(0.75),
unitVector2d(0.2),
unitVector2d(0.9),
unitVector2d(0.8),
unitVector2d(0.77),
});
float[][] values = {
unitVector2d(0.5),
unitVector2d(0.75),
unitVector2d(0.2),
unitVector2d(0.9),
unitVector2d(0.8),
unitVector2d(0.77),
};
if (vectorEncoding == VectorEncoding.BYTE) {
for (float[] v : values) {
for (int i = 0; i < v.length; i++) {
v[i] *= 127;
}
}
}
MockVectorValues vectors = new MockVectorValues(values);
// First add nodes until everybody gets a full neighbor list
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 2, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
builder.addGraphNode(1, vectors.vectorValue(1));
builder.addGraphNode(2, vectors.vectorValue(2));
builder.addGraphNode(1, vectors);
builder.addGraphNode(2, vectors);
// now every node has tried to attach every other node as a neighbor, but
// some were excluded based on diversity check.
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0);
assertLevel0Neighbors(builder.hnsw, 2, 0);

builder.addGraphNode(3, vectors.vectorValue(3));
builder.addGraphNode(3, vectors);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// we added 3 here
assertLevel0Neighbors(builder.hnsw, 1, 0, 3);

@ -461,7 +512,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 3, 1);

// supplant an existing neighbor
builder.addGraphNode(4, vectors.vectorValue(4));
builder.addGraphNode(4, vectors);
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4);

@ -470,7 +521,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
assertLevel0Neighbors(builder.hnsw, 4, 1, 3);

builder.addGraphNode(5, vectors.vectorValue(5));
builder.addGraphNode(5, vectors);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5);
assertLevel0Neighbors(builder.hnsw, 2, 0);

@ -494,29 +545,46 @@ public class TestHnswGraph extends LuceneTestCase {
public void testRandom() throws IOException {
int size = atLeast(100);
int dim = atLeast(10);
RandomVectorValues vectors = new RandomVectorValues(size, dim, random());
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
RandomVectorValues vectors = new RandomVectorValues(size, dim, vectorEncoding, random());
int topK = 5;
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors);
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);

int totalMatches = 0;
for (int i = 0; i < 100; i++) {
float[] query = randomVector(random(), dim);
NeighborQueue actual =
NeighborQueue actual;
float[] query;
BytesRef bQuery = null;
if (vectorEncoding == VectorEncoding.BYTE) {
query = randomVector8(random(), dim);
bQuery = toBytesRef(query);
} else {
query = randomVector(random(), dim);
}
actual =
HnswGraphSearcher.search(
query, 100, vectors, similarityFunction, hnsw, acceptOrds, Integer.MAX_VALUE);
query,
100,
vectors,
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
while (actual.size() > topK) {
actual.pop();
}
NeighborQueue expected = new NeighborQueue(topK, false);
for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
if (vectorEncoding == VectorEncoding.BYTE) {
expected.add(j, similarityFunction.compare(bQuery, vectors.binaryValue(j)));
} else {
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
}
if (expected.size() > topK) {
expected.pop();
}

@ -553,12 +621,14 @@ public class TestHnswGraph extends LuceneTestCase {
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
private final int size;
private final float[] value;
private final BytesRef binaryValue;

int doc = -1;

CircularVectorValues(int size) {
this.size = size;
value = new float[2];
binaryValue = new BytesRef(new byte[2]);
}

public CircularVectorValues copy() {

@ -617,7 +687,11 @@ public class TestHnswGraph extends LuceneTestCase {

@Override
public BytesRef binaryValue(int ord) {
return null;
float[] vectorValue = vectorValue(ord);
for (int i = 0; i < vectorValue.length; i++) {
binaryValue.bytes[i] = (byte) (vectorValue[i] * 127);
}
return binaryValue;
}
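// Note: unlike MockVectorValues.binaryValue, which casts stored values as-is,
// this override scales by 127 because CircularVectorValues produces
// unit-circle floats in [-1, 1] rather than pre-scaled byte-range values.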
}

@ -648,8 +722,9 @@ public class TestHnswGraph extends LuceneTestCase {
if (uDoc == NO_MORE_DOCS) {
break;
}
float delta = vectorEncoding == VectorEncoding.BYTE ? 1 : 1e-4f;
assertArrayEquals(
"vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), 1e-4f);
"vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), delta);
}
}

@ -657,7 +732,11 @@ public class TestHnswGraph extends LuceneTestCase {
static class RandomVectorValues extends MockVectorValues {

RandomVectorValues(int size, int dimension, Random random) {
super(createRandomVectors(size, dimension, random));
super(createRandomVectors(size, dimension, null, random));
}

RandomVectorValues(int size, int dimension, VectorEncoding vectorEncoding, Random random) {
super(createRandomVectors(size, dimension, vectorEncoding, random));
}

RandomVectorValues(RandomVectorValues other) {

@ -669,11 +748,21 @@ public class TestHnswGraph extends LuceneTestCase {
return new RandomVectorValues(this);
}

private static float[][] createRandomVectors(int size, int dimension, Random random) {
private static float[][] createRandomVectors(
int size, int dimension, VectorEncoding vectorEncoding, Random random) {
float[][] vectors = new float[size][];
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
vectors[offset] = randomVector(random, dimension);
}
if (vectorEncoding == VectorEncoding.BYTE) {
for (float[] vector : vectors) {
if (vector != null) {
for (int i = 0; i < vector.length; i++) {
vector[i] = (byte) (127 * vector[i]);
}
}
}
}
return vectors;
}
}

@ -701,8 +790,19 @@ public class TestHnswGraph extends LuceneTestCase {
float[] vec = new float[dim];
for (int i = 0; i < dim; i++) {
vec[i] = random.nextFloat();
if (random.nextBoolean()) {
vec[i] = -vec[i];
}
}
VectorUtil.l2normalize(vec);
return vec;
}

private static float[] randomVector8(Random random, int dim) {
float[] fvec = randomVector(random, dim);
for (int i = 0; i < dim; i++) {
fvec[i] *= 127;
}
return fvec;
}
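// Note: randomVector8 scales a unit vector by 127 but still returns float[];
// testRandom derives the byte-encoded query from it via toBytesRef(query), so
// the same values feed both the BYTE and FLOAT32 query representations.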
}
@ -34,8 +34,10 @@ import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.Version;

@ -97,6 +99,7 @@ public class TermVectorLeafReader extends LeafReader {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
fieldInfos = new FieldInfos(new FieldInfo[] {fieldInfo});

@ -166,6 +169,12 @@ public class TermVectorLeafReader extends LeafReader {
return null;
}

@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}

@Override
public void checkIntegrity() throws IOException {}

@ -37,6 +37,7 @@ import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.*;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;

@ -514,6 +515,7 @@ public class MemoryIndex {
fieldType.pointIndexDimensionCount(),
fieldType.pointNumBytes(),
fieldType.vectorDimension(),
fieldType.vectorEncoding(),
fieldType.vectorSimilarityFunction(),
false);
}

@ -546,6 +548,7 @@ public class MemoryIndex {
info.fieldInfo.getPointIndexDimensionCount(),
info.fieldInfo.getPointNumBytes(),
info.fieldInfo.getVectorDimension(),
info.fieldInfo.getVectorEncoding(),
info.fieldInfo.getVectorSimilarityFunction(),
info.fieldInfo.isSoftDeletesField());
} else if (existingDocValuesType != docValuesType) {

@ -1371,6 +1374,12 @@ public class MemoryIndex {
return null;
}

@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}

@Override
public void checkIntegrity() throws IOException {
// no-op
@ -29,6 +29,7 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;

@ -61,7 +62,7 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
}

@Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return delegate.addField(fieldInfo);
}

@ -131,6 +132,18 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
return hits;
}

@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldInfo fi = fis.fieldInfo(field);
assert fi != null && fi.getVectorDimension() > 0;
assert acceptDocs != null;
TopDocs hits = delegate.searchExhaustively(field, target, k, acceptDocs);
assert hits != null;
assert hits.scoreDocs.length <= k;
return hits;
}

@Override
public void close() throws IOException {
delegate.close();
@ -36,6 +36,7 @@ import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.internal.tests.IndexPackageAccess;

@ -305,6 +306,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
fieldType.pointIndexDimensionCount(),
fieldType.pointNumBytes(),
fieldType.vectorDimension(),
fieldType.vectorEncoding(),
fieldType.vectorSimilarityFunction(),
field.equals(softDeletesField));
addAttributes(fi);

@ -353,7 +355,8 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
int dimension = 1 + r.nextInt(VectorValues.MAX_DIMENSIONS);
VectorSimilarityFunction similarityFunction =
RandomPicks.randomFrom(r, VectorSimilarityFunction.values());
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
VectorEncoding encoding = RandomPicks.randomFrom(r, VectorEncoding.values());
type.setVectorAttributes(dimension, encoding, similarityFunction);
}

return type;

@ -422,6 +425,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
}

@ -360,6 +360,7 @@ abstract class BaseIndexFileFormatTestCase extends LuceneTestCase {
proto.getPointIndexDimensionCount(),
proto.getPointNumBytes(),
proto.getVectorDimension(),
proto.getVectorEncoding(),
proto.getVectorSimilarityFunction(),
proto.isSoftDeletesField());
@ -41,6 +41,7 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.Sort;

@ -51,8 +52,10 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;
import org.junit.Before;

/**
* Base class aiming at testing {@link KnnVectorsFormat vectors formats}. To test a new format, all

@ -63,9 +66,21 @@ import org.apache.lucene.util.VectorUtil;
*/
public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {

private VectorEncoding vectorEncoding;
private VectorSimilarityFunction similarityFunction;

@Before
public void init() {
vectorEncoding = randomVectorEncoding();
similarityFunction = randomSimilarity();
}

@Override
protected void addRandomFields(Document doc) {
doc.add(new KnnVectorField("v2", randomVector(30), VectorSimilarityFunction.EUCLIDEAN));
switch (vectorEncoding) {
case BYTE -> doc.add(new KnnVectorField("v2", randomVector8(30), similarityFunction));
case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
}
}

public void testFieldConstructor() {

@ -133,8 +148,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=3, vector similarity function=DOT_PRODUCT";
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=3, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT";
assertEquals(errMsg, expected.getMessage());
}
}

@ -170,8 +185,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector similarity function=EUCLIDEAN";
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN";
assertEquals(errMsg, expected.getMessage());
}
}

@ -190,8 +205,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=1, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=1, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}

@ -211,8 +226,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector similarity function=EUCLIDEAN",
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN",
expected.getMessage());
}
}

@ -311,8 +326,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
expectThrows(
IllegalArgumentException.class, () -> w2.addIndexes(new Directory[] {dir}));
assertEquals(
"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}

@ -333,8 +348,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addIndexes(dir));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}

@ -358,8 +373,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException.class,
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
assertEquals(
"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}

@ -384,8 +399,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException.class,
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}

@ -408,8 +423,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
assertEquals(
"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}

@ -432,8 +447,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}

@ -596,12 +611,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int[] fieldDocCounts = new int[numFields];
double[] fieldTotals = new double[numFields];
int[] fieldDims = new int[numFields];
VectorSimilarityFunction[] fieldSearchStrategies = new VectorSimilarityFunction[numFields];
VectorSimilarityFunction[] fieldSimilarityFunctions = new VectorSimilarityFunction[numFields];
VectorEncoding[] fieldVectorEncodings = new VectorEncoding[numFields];
for (int i = 0; i < numFields; i++) {
fieldDims[i] = random().nextInt(20) + 1;
fieldSearchStrategies[i] =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length)];
fieldSimilarityFunctions[i] = randomSimilarity();
fieldVectorEncodings[i] = randomVectorEncoding();
}
try (Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {

@ -610,15 +625,23 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
for (int field = 0; field < numFields; field++) {
String fieldName = "int" + field;
if (random().nextInt(100) == 17) {
float[] v = randomVector(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, v, fieldSearchStrategies[field]));
switch (fieldVectorEncodings[field]) {
case BYTE -> {
BytesRef b = randomVector8(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, b, fieldSimilarityFunctions[field]));
fieldTotals[field] += b.bytes[b.offset];
}
case FLOAT32 -> {
float[] v = randomVector(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, v, fieldSimilarityFunctions[field]));
fieldTotals[field] += v[0];
}
}
fieldDocCounts[field]++;
fieldTotals[field] += v[0];
}
}
w.addDocument(doc);
}

try (IndexReader r = w.getReader()) {
for (int field = 0; field < numFields; field++) {
int docCount = 0;

@ -634,12 +657,29 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
assertEquals(fieldDocCounts[field], docCount);
assertEquals(fieldTotals[field], checksum, 1e-5);
// Account for quantization done when indexing fields w/BYTE encoding
double delta = fieldVectorEncodings[field] == VectorEncoding.BYTE ? numDocs * 0.01 : 1e-5;
assertEquals(fieldTotals[field], checksum, delta);
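// Note: per the comment above, BYTE-encoded fields are quantized while
// indexing, so read-back values need not match the write-time totals exactly;
// the tolerance therefore grows with numDocs instead of staying a fixed 1e-5.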
}
}
}
}

private VectorSimilarityFunction randomSimilarity() {
return VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length)];
}

private VectorEncoding randomVectorEncoding() {
Codec codec = getCodec();
if (codec.knnVectorsFormat().currentVersion()
>= Codec.forName("Lucene94").knnVectorsFormat().currentVersion()) {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
} else {
return VectorEncoding.FLOAT32;
}
}

public void testIndexedValueNotAliased() throws Exception {
// We copy indexed values (as for BinaryDocValues) so the input float[] can be reused across
// calls to IndexWriter.addDocument.

@ -742,7 +782,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertEquals(3, vectorValues3.dimension());
assertEquals(1, vectorValues3.size());
vectorValues3.nextDoc();
assertEquals(1f, vectorValues3.vectorValue()[0], 0);
assertEquals(1f, vectorValues3.vectorValue()[0], 0.1);
assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc());
}
}

@ -775,9 +815,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
if (random().nextBoolean() && values[i] != null) {
// sometimes use a shared scratch array
System.arraycopy(values[i], 0, scratch, 0, scratch.length);
add(iw, fieldName, i, scratch, VectorSimilarityFunction.EUCLIDEAN);
add(iw, fieldName, i, scratch, similarityFunction);
} else {
add(iw, fieldName, i, values[i], VectorSimilarityFunction.EUCLIDEAN);
add(iw, fieldName, i, values[i], similarityFunction);
}
if (random().nextInt(10) == 2) {
// sometimes delete a random document

@ -898,7 +938,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int numDoc = atLeast(100);
int dimension = atLeast(10);
float[][] id2value = new float[numDoc][];
int[] id2ord = new int[numDoc];
for (int i = 0; i < numDoc; i++) {
int id = random().nextInt(numDoc);
float[] value;

@ -909,7 +948,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
value = null;
}
id2value[id] = value;
id2ord[id] = i;
add(iw, fieldName, id, value, VectorSimilarityFunction.EUCLIDEAN);
}
try (IndexReader reader = DirectoryReader.open(iw)) {

@ -1007,6 +1045,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
return v;
}

private BytesRef randomVector8(int dim) {
float[] v = randomVector(dim);
byte[] b = new byte[dim];
for (int i = 0; i < dim; i++) {
b[i] = (byte) (v[i] * 127);
}
return new BytesRef(b);
}

public void testCheckIndexIncludesVectors() throws Exception {
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {

@ -1041,6 +1088,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertEquals(3, VectorSimilarityFunction.values().length);
}

public void testVectorEncodingOrdinals() {
// make sure we don't accidentally mess up vector encoding identifiers by re-ordering their
// enumerators
assertEquals(0, VectorEncoding.BYTE.ordinal());
assertEquals(1, VectorEncoding.FLOAT32.ordinal());
assertEquals(2, VectorEncoding.values().length);
}
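// Note: these ordinals are a compatibility guard: the encoding is written
// into the index (see the VectorEncoding values threaded through the
// field-infos formats in this change), so re-ordering the enum constants
// would silently change what existing indexes decode to.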
public void testAdvance() throws Exception {
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {

@ -1091,10 +1146,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
public void testVectorValuesReportCorrectDocs() throws Exception {
final int numDocs = atLeast(1000);
final int dim = random().nextInt(20) + 1;
final VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length)];

double fieldValuesCheckSum = 0;
int fieldDocCount = 0;
long fieldSumDocIDs = 0;

@ -1106,9 +1157,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int docID = random().nextInt(numDocs);
doc.add(new StoredField("id", docID));
if (random().nextInt(4) == 3) {
float[] vector = randomVector(dim);
doc.add(new KnnVectorField("knn_vector", vector, similarityFunction));
fieldValuesCheckSum += vector[0];
switch (vectorEncoding) {
case BYTE -> {
BytesRef b = randomVector8(dim);
fieldValuesCheckSum += b.bytes[b.offset];
doc.add(new KnnVectorField("knn_vector", b, similarityFunction));
}
case FLOAT32 -> {
float[] v = randomVector(dim);
fieldValuesCheckSum += v[0];
doc.add(new KnnVectorField("knn_vector", v, similarityFunction));
}
}
fieldDocCount++;
fieldSumDocIDs += docID;
}

@ -1134,7 +1194,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
}
assertEquals(fieldValuesCheckSum, checksum, 1e-3);
assertEquals(
fieldValuesCheckSum,
checksum,
vectorEncoding == VectorEncoding.BYTE ? numDocs * 0.2 : 1e-5);
assertEquals(fieldDocCount, docCount);
assertEquals(fieldSumDocIDs, sumDocIds);
}
@ -40,6 +40,7 @@ import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;

@ -228,6 +229,12 @@ class MergeReaderWrapper extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}

@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
}

@Override
public int numDocs() {
return in.numDocs();

@ -88,6 +88,7 @@ public class MismatchedLeafReader extends FilterLeafReader {
oldInfo.getPointIndexDimensionCount(), // index dimension count
oldInfo.getPointNumBytes(), // dimension numBytes
oldInfo.getVectorDimension(), // number of dimensions of the field's vector
oldInfo.getVectorEncoding(), // numeric type of vector samples
// distance function for calculating similarity of the field's vector
oldInfo.getVectorSimilarityFunction(),
oldInfo.isSoftDeletesField()); // used as soft-deletes field

@ -62,6 +62,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.tests.IndexPackageAccess;
import org.apache.lucene.internal.tests.TestSecrets;

@ -163,6 +164,7 @@ public class RandomPostingsTester {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
fieldUpto++;

@ -734,6 +736,7 @@ public class RandomPostingsTester {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
}

@ -233,6 +233,12 @@ public class QueryUtils {
return null;
}

@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}

@Override
public FieldInfos getFieldInfos() {
return FieldInfos.EMPTY;