Add BitVectors format and make flat vectors format easier to extend (#13288)

Instead of making a separate pluggable component inside the field format, this keeps the vector similarities as they are, but allows a custom scorer to be provided to the flat vector storage used by HNSW.

This idea is akin to the compression extensions we already have, but in this case it's for vector scorers.
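To make the extension point concrete, here is a rough sketch of the wiring (not code from this PR; `CustomScoredFlatFormatExample` and its method names are made up for illustration). The flat storage format simply takes whatever `FlatVectorsScorer` it is handed:

```java
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;

/** Illustrative only: the flat vector storage accepts a pluggable scorer. */
class CustomScoredFlatFormatExample {
  /** Build a flat storage format that scores its raw vectors with the given scorer. */
  static FlatVectorsFormat flatFormatWith(FlatVectorsScorer scorer) {
    return new Lucene99FlatVectorsFormat(scorer);
  }

  /** The stock behavior: float/byte scoring via the configured vector similarity function. */
  static FlatVectorsFormat defaultScored() {
    return flatFormatWith(new DefaultFlatVectorScorer());
  }
}
```

This is the same pattern the new `HnswBitVectorsFormat` below uses with its `FlatBitVectorsScorer`.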

To show how this would work in practice, I took the liberty of adding a new `HnswBitVectorsFormat` in the codecs module (`org.apache.lucene.codecs.bitvectors`).
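A minimal indexing sketch, adapted from the new test in this change (`BitVectorIndexingExample`, the field name, directory, and vector bytes are arbitrary; everything else comes from classes touched in this diff):

```java
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.bitvectors.HnswBitVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class BitVectorIndexingExample {
  public static void main(String[] args) throws Exception {
    IndexWriterConfig iwc = new IndexWriterConfig();
    // pick the bit vector format for every KNN field
    iwc.setCodec(
        new Lucene99Codec() {
          @Override
          public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
            return new HnswBitVectorsFormat();
          }
        });
    try (Directory dir = new ByteBuffersDirectory();
        IndexWriter writer = new IndexWriter(dir, iwc)) {
      Document doc = new Document();
      // 16 bit-dimensions packed into two bytes; the similarity function parameter is
      // effectively unused here because the bit scorer applies its own Hamming-based score
      doc.add(
          new KnnByteVectorField(
              "bits",
              new byte[] {(byte) 0b10101110, (byte) 0b01010111},
              VectorSimilarityFunction.DOT_PRODUCT));
      writer.addDocument(doc);
      writer.commit();
    }
  }
}
```

Search then goes through the normal byte-vector KNN path (e.g. `LeafReader#searchNearestVectors`), as exercised by `TestHnswBitVectorsFormat` further down.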

A larger part of the change is a refactor of `RandomAccessVectorValues<T>` to remove the `<T>`. Nothing actually uses the type parameter any longer, and we should instead rely on well-defined classes and stop casting with generics (yuck).
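Roughly, a call site that used to cast to `RandomAccessVectorValues<float[]>` can now look like this (a sketch; `TypedVectorValuesExample` and the `DOT_PRODUCT` choice are illustrative, the typed `Floats`/`Bytes` sub-interfaces and the `fromFloats` helper come from this change):

```java
import java.io.IOException;
import java.util.List;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

class TypedVectorValuesExample {
  // Before: RandomVectorScorerSupplier.createFloats((RandomAccessVectorValues<float[]>) values, fn)
  // After: a typed view of the vectors plus an explicit scorer, no generic casts involved.
  static RandomVectorScorerSupplier supplierFor(List<float[]> vectors, int dim) throws IOException {
    RandomAccessVectorValues values = RandomAccessVectorValues.fromFloats(vectors, dim);
    return new DefaultFlatVectorScorer()
        .getRandomVectorScorerSupplier(VectorSimilarityFunction.DOT_PRODUCT, values);
  }
}
```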
Benjamin Trent 2024-04-17 13:13:51 -04:00 committed by GitHub
parent bc678ac67e
commit 3d86ff2e6a
73 changed files with 1363 additions and 535 deletions


@ -254,6 +254,8 @@ New Features
* GITHUB#13268: Add ability for UnifiedHighlighter to highlight a field based on combined matches from multiple fields.
(Mayya Sharipova, Jim Ferenczi)
* GITHUB#13288: Make HNSW and Flat storage vector formats easier to extend with new FlatVectorScorer interface. Add
new Hnsw format for binary quantized vectors. (Ben Trent)
Improvements
---------------------


@ -29,7 +29,7 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
*
* @lucene.experimental
*/
public class Word2VecModel implements RandomAccessVectorValues<float[]> {
public class Word2VecModel implements RandomAccessVectorValues.Floats {
private final int dictionarySize;
private final int vectorDimension;
@ -88,7 +88,7 @@ public class Word2VecModel implements RandomAccessVectorValues<float[]> {
}
@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
public Word2VecModel copy() throws IOException {
return new Word2VecModel(
this.dictionarySize, this.vectorDimension, this.termsAndVectors, this.word2Vec);
}


@ -23,6 +23,7 @@ import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
@ -44,6 +45,7 @@ public class Word2VecSynonymProvider {
VectorSimilarityFunction.DOT_PRODUCT;
private final Word2VecModel word2VecModel;
private final OnHeapHnswGraph hnswGraph;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
/**
* Word2VecSynonymProvider constructor
@ -53,7 +55,7 @@ public class Word2VecSynonymProvider {
public Word2VecSynonymProvider(Word2VecModel model) throws IOException {
this.word2VecModel = model;
RandomVectorScorerSupplier scorerSupplier =
RandomVectorScorerSupplier.createFloats(word2VecModel, SIMILARITY_FUNCTION);
defaultFlatVectorScorer.getRandomVectorScorerSupplier(SIMILARITY_FUNCTION, word2VecModel);
HnswGraphBuilder builder =
HnswGraphBuilder.create(
scorerSupplier,
@ -75,7 +77,7 @@ public class Word2VecSynonymProvider {
float[] query = word2VecModel.vectorValue(term);
if (query != null) {
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(word2VecModel, SIMILARITY_FUNCTION, query);
defaultFlatVectorScorer.getRandomVectorScorer(SIMILARITY_FUNCTION, word2VecModel, query);
KnnCollector synonyms =
HnswGraphSearcher.search(
scorer,


@ -49,7 +49,7 @@ public final class Lucene90HnswGraphBuilder {
private final Lucene90NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues<float[]> vectorValues;
private final RandomAccessVectorValues.Floats vectorValues;
private final SplittableRandom random;
private final Lucene90BoundsChecker bound;
final Lucene90OnHeapHnswGraph hnsw;
@ -58,7 +58,7 @@ public final class Lucene90HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
private final RandomAccessVectorValues<float[]> buildVectors;
private final RandomAccessVectorValues.Floats buildVectors;
/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
@ -73,7 +73,7 @@ public final class Lucene90HnswGraphBuilder {
* to ensure repeatable construction.
*/
public Lucene90HnswGraphBuilder(
RandomAccessVectorValues<float[]> vectors,
RandomAccessVectorValues.Floats vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
@ -104,8 +104,7 @@ public final class Lucene90HnswGraphBuilder {
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* accessor for the vectors
*/
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues<float[]> vectors)
throws IOException {
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@ -231,7 +230,7 @@ public final class Lucene90HnswGraphBuilder {
float[] candidate,
float score,
Lucene90NeighborArray neighbors,
RandomAccessVectorValues<float[]> vectorValues)
RandomAccessVectorValues.Floats vectorValues)
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {


@ -350,7 +350,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
/** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues<float[]> {
implements RandomAccessVectorValues.Floats {
final int dimension;
final int[] ordToDoc;
@ -419,7 +419,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
@Override
public RandomAccessVectorValues<float[]> copy() {
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(dimension, ordToDoc, dataIn.clone());
}


@ -74,7 +74,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
float[] query,
int topK,
int numSeed,
RandomAccessVectorValues<float[]> vectors,
RandomAccessVectorValues.Floats vectors,
VectorSimilarityFunction similarityFunction,
HnswGraph graphValues,
Bits acceptOrds,


@ -26,6 +26,7 @@ import java.util.Map;
import java.util.function.IntUnaryOperator;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
@ -56,6 +57,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
Lucene91HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
@ -233,7 +235,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
defaultFlatVectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction, vectorValues, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
@ -387,7 +390,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
/** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues<float[]> {
implements RandomAccessVectorValues.Floats {
private final int dimension;
private final int size;
@ -464,7 +467,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
@Override
public RandomAccessVectorValues<float[]> copy() {
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(dimension, size, ordToDoc, dataIn.clone());
}


@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
@ -55,6 +56,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
@ -232,7 +234,8 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
defaultFlatVectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction, vectorValues, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),


@ -28,7 +28,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues<float[]> {
implements RandomAccessVectorValues.Floats {
protected final int dimension;
protected final int size;
@ -114,7 +114,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
}
@ -173,7 +173,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
public OffHeapFloatVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
}
@ -240,7 +240,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
public OffHeapFloatVectorValues copy() throws IOException {
throw new UnsupportedOperationException();
}


@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
@ -56,6 +57,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
@ -269,7 +271,8 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
defaultFlatVectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction, vectorValues, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
@ -288,7 +291,8 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target);
defaultFlatVectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction, vectorValues, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),


@ -30,7 +30,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues<byte[]> {
implements RandomAccessVectorValues.Bytes {
protected final int dimension;
protected final int size;
@ -124,7 +124,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<byte[]> copy() throws IOException {
public OffHeapByteVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@ -186,7 +186,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<byte[]> copy() throws IOException {
public OffHeapByteVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@ -253,7 +253,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<byte[]> copy() throws IOException {
public OffHeapByteVectorValues copy() throws IOException {
throw new UnsupportedOperationException();
}


@ -28,7 +28,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues<float[]> {
implements RandomAccessVectorValues.Floats {
protected final int dimension;
protected final int size;
@ -120,7 +120,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
public OffHeapFloatVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@ -182,7 +182,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
public OffHeapFloatVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@ -249,7 +249,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
public OffHeapFloatVectorValues copy() throws IOException {
throw new UnsupportedOperationException();
}


@ -24,8 +24,9 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
@ -65,6 +66,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
Lucene95HnswVectorsReader(SegmentReadState state) throws IOException {
this.fieldInfos = state.fieldInfos;
@ -300,7 +302,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
fieldEntry.vectorDataLength,
vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
defaultFlatVectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction, vectorValues, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
@ -328,7 +331,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
fieldEntry.vectorDataLength,
vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target);
defaultFlatVectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction, vectorValues, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),


@ -231,7 +231,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
private void writeGraph(
IndexOutput graphData,
RandomAccessVectorValues<float[]> vectorValues,
RandomAccessVectorValues.Floats vectorValues,
VectorSimilarityFunction similarityFunction,
long graphDataOffset,
long[] offsets,


@ -24,6 +24,7 @@ import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
@ -54,8 +55,9 @@ public final class Lucene91HnswGraphBuilder {
private final double ml;
private final Lucene91NeighborArray scratch;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues<float[]> vectorValues;
private final RandomAccessVectorValues.Floats vectorValues;
private final SplittableRandom random;
private final Lucene91BoundsChecker bound;
private final HnswGraphSearcher graphSearcher;
@ -66,7 +68,7 @@ public final class Lucene91HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
private RandomAccessVectorValues<float[]> buildVectors;
private RandomAccessVectorValues.Floats buildVectors;
/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
@ -81,7 +83,7 @@ public final class Lucene91HnswGraphBuilder {
* to ensure repeatable construction.
*/
public Lucene91HnswGraphBuilder(
RandomAccessVectorValues<float[]> vectors,
RandomAccessVectorValues.Floats vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
@ -118,8 +120,7 @@ public final class Lucene91HnswGraphBuilder {
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* accessor for the vectors
*/
public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues<float[]> vectors)
throws IOException {
public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@ -146,7 +147,7 @@ public final class Lucene91HnswGraphBuilder {
/** Inserts a doc with vector value to the graph */
void addGraphNode(int node, float[] value) throws IOException {
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, similarityFunction, value);
defaultFlatVectorScorer.getRandomVectorScorer(similarityFunction, vectorValues, value);
HnswGraphBuilder.GraphBuilderKnnCollector candidates;
final int nodeLevel = getRandomGraphLevel(ml, random);
int curMaxLevel = hnsw.numLevels() - 1;
@ -253,7 +254,7 @@ public final class Lucene91HnswGraphBuilder {
float[] candidate,
float score,
Lucene91NeighborArray neighbors,
RandomAccessVectorValues<float[]> vectorValues)
RandomAccessVectorValues.Floats vectorValues)
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {


@ -239,7 +239,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
}
private Lucene91OnHeapHnswGraph writeGraph(
RandomAccessVectorValues<float[]> vectorValues, VectorSimilarityFunction similarityFunction)
RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
// build graph


@ -26,6 +26,7 @@ import java.nio.ByteOrder;
import java.util.Arrays;
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
@ -273,12 +274,12 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
}
private OnHeapHnswGraph writeGraph(
RandomAccessVectorValues<float[]> vectorValues, VectorSimilarityFunction similarityFunction)
RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
// build graph
RandomVectorScorerSupplier scorerSupplier =
RandomVectorScorerSupplier.createFloats(vectorValues, similarityFunction);
defaultFlatVectorScorer.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
HnswGraphBuilder hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);


@ -29,6 +29,7 @@ import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
@ -409,6 +410,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
// TODO: separate random access vector values from DocIdSetIterator?
int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
OnHeapHnswGraph graph = null;
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
if (docsWithField.cardinality() != 0) {
// build graph
graph =
@ -421,8 +423,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
vectorDataInput,
byteSize);
RandomVectorScorerSupplier scorerSupplier =
RandomVectorScorerSupplier.createBytes(
vectorValues, fieldInfo.getVectorSimilarityFunction());
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(), vectorValues);
HnswGraphBuilder hnswGraphBuilder =
HnswGraphBuilder.create(
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
@ -437,8 +439,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
vectorDataInput,
byteSize);
RandomVectorScorerSupplier scorerSupplier =
RandomVectorScorerSupplier.createFloats(
vectorValues, fieldInfo.getVectorSimilarityFunction());
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(), vectorValues);
HnswGraphBuilder hnswGraphBuilder =
HnswGraphBuilder.create(
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
@ -656,15 +658,15 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
RandomAccessVectorValues<T> raVectors = new RAVectorValues<>(vectors, dim);
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
(RandomAccessVectorValues<byte[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
(RandomAccessVectorValues<float[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
case BYTE -> defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromBytes((List<byte[]>) vectors, dim));
case FLOAT32 -> defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromFloats((List<float[]>) vectors, dim));
};
hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
@ -708,34 +710,4 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
+ hnswGraphBuilder.getGraph().ramBytesUsed();
}
}
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
private final List<T> vectors;
private final int dim;
RAVectorValues(List<T> vectors, int dim) {
this.vectors = vectors;
this.dim = dim;
}
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public T vectorValue(int targetOrd) throws IOException {
return vectors.get(targetOrd);
}
@Override
public RAVectorValues<T> copy() throws IOException {
return this;
}
}
}


@ -29,6 +29,7 @@ import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
@ -436,27 +437,28 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
OnHeapHnswGraph graph = null;
int[][] vectorIndexNodeOffsets = null;
if (docsWithField.cardinality() != 0) {
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
final RandomVectorScorerSupplier scorerSupplier;
switch (fieldInfo.getVectorEncoding()) {
case BYTE:
scorerSupplier =
RandomVectorScorerSupplier.createBytes(
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
byteSize),
fieldInfo.getVectorSimilarityFunction());
byteSize));
break;
case FLOAT32:
scorerSupplier =
RandomVectorScorerSupplier.createFloats(
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
byteSize),
fieldInfo.getVectorSimilarityFunction());
byteSize));
break;
default:
throw new IllegalArgumentException(
@ -695,15 +697,15 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
RAVectorValues<T> raVectors = new RAVectorValues<>(vectors, dim);
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
(RandomAccessVectorValues<byte[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
(RandomAccessVectorValues<float[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
case BYTE -> defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromBytes((List<byte[]>) vectors, dim));
case FLOAT32 -> defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromFloats((List<float[]>) vectors, dim));
};
hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
@ -746,34 +748,4 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
+ hnswGraphBuilder.getGraph().ramBytesUsed();
}
}
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
private final List<T> vectors;
private final int dim;
RAVectorValues(List<T> vectors, int dim) {
this.vectors = vectors;
this.dim = dim;
}
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public T vectorValue(int targetOrd) throws IOException {
return vectors.get(targetOrd);
}
@Override
public RandomAccessVectorValues<T> copy() throws IOException {
return this;
}
}
}


@ -18,9 +18,11 @@
package org.apache.lucene.backward_codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.codecs.FlatVectorsFormat;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
@ -56,12 +58,16 @@ class Lucene99RWHnswScalarQuantizationVectorsFormat
}
static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat {
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat();
private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99ScalarQuantizedVectorsWriter(
state, null, rawVectorFormat.fieldsWriter(state));
state,
null,
rawVectorFormat.fieldsWriter(state),
new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()));
}
}
}


@ -15,10 +15,13 @@
* limitations under the License.
*/
import org.apache.lucene.codecs.bitvectors.HnswBitVectorsFormat;
/** Lucene codecs and postings formats */
module org.apache.lucene.codecs {
requires org.apache.lucene.core;
exports org.apache.lucene.codecs.bitvectors;
exports org.apache.lucene.codecs.blockterms;
exports org.apache.lucene.codecs.blocktreeords;
exports org.apache.lucene.codecs.bloom;
@ -27,6 +30,8 @@ module org.apache.lucene.codecs {
exports org.apache.lucene.codecs.uniformsplit;
exports org.apache.lucene.codecs.uniformsplit.sharedterms;
provides org.apache.lucene.codecs.KnnVectorsFormat with
HnswBitVectorsFormat;
provides org.apache.lucene.codecs.PostingsFormat with
org.apache.lucene.codecs.blocktreeords.BlockTreeOrdsPostingsFormat,
org.apache.lucene.codecs.bloom.BloomFilteringPostingsFormat,


@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.bitvectors;
import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
/** A bit vector scorer for scoring byte vectors. */
public class FlatBitVectorsScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
throws IOException {
assert vectorValues instanceof RandomAccessVectorValues.Bytes;
if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) {
return new BitRandomVectorScorerSupplier(byteVectorValues);
}
throw new IllegalArgumentException(
"vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target)
throws IOException {
throw new IllegalArgumentException("bit vectors do not support float[] targets");
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target)
throws IOException {
assert vectorValues instanceof RandomAccessVectorValues.Bytes;
if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) {
return new BitRandomVectorScorer(byteVectorValues, target);
}
throw new IllegalArgumentException(
"vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}
static class BitRandomVectorScorer implements RandomVectorScorer {
private final RandomAccessVectorValues.Bytes vectorValues;
private final int bitDimensions;
private final byte[] query;
BitRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) {
this.query = query;
this.bitDimensions = vectorValues.dimension() * Byte.SIZE;
this.vectorValues = vectorValues;
}
@Override
public float score(int node) throws IOException {
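// normalized Hamming similarity: 1 - (xor bit count / total bits), so identical bit
// vectors score 1.0 and fully complementary vectors score 0.0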
return (bitDimensions - VectorUtil.xorBitCount(query, vectorValues.vectorValue(node)))
/ (float) bitDimensions;
}
@Override
public int maxOrd() {
return vectorValues.size();
}
@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return vectorValues.getAcceptOrds(acceptDocs);
}
}
static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
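// two independent copies of the vector reader: one fetches the query vector for an ordinal
// while the other scores candidates, so a shared scratch buffer is never read and
// overwritten at the same time (same reason graph builders keep two vector sources)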
protected final RandomAccessVectorValues.Bytes vectorValues;
protected final RandomAccessVectorValues.Bytes vectorValues1;
protected final RandomAccessVectorValues.Bytes vectorValues2;
public BitRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues)
throws IOException {
this.vectorValues = vectorValues;
this.vectorValues1 = vectorValues.copy();
this.vectorValues2 = vectorValues.copy();
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] query = vectorValues1.vectorValue(ord);
return new BitRandomVectorScorer(vectorValues2, query);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new BitRandomVectorScorerSupplier(vectorValues.copy());
}
}
@Override
public String toString() {
return "FlatBitVectorsScorer()";
}
}


@ -0,0 +1,207 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.bitvectors;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.hnsw.HnswGraph;
/**
* Encodes bit vector values into an associated graph connecting the documents having values. The
* graph is used to power HNSW search. The format consists of two files, and uses {@link
Lucene99FlatVectorsFormat} to store the actual vectors, but with a custom scorer implementation.
* For details on graph storage and file extensions, see {@link Lucene99HnswVectorsFormat}.
*
* @lucene.experimental
*/
public final class HnswBitVectorsFormat extends KnnVectorsFormat {
public static final String NAME = "HnswBitVectorsFormat";
/**
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
* {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
*/
private final int maxConn;
/**
* The number of candidate neighbors to track while searching the graph for each newly inserted
* node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
* for details.
*/
private final int beamWidth;
/** The format for storing, reading, merging vectors on disk */
private final FlatVectorsFormat flatVectorsFormat;
private final int numMergeWorkers;
private final TaskExecutor mergeExec;
/** Constructs a format using default graph construction parameters */
public HnswBitVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
}
/**
* Constructs a format using the given graph construction parameters.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
*/
public HnswBitVectorsFormat(int maxConn, int beamWidth) {
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
}
/**
* Constructs a format using the given graph construction parameters.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
* generated by this format to do the merge
*/
public HnswBitVectorsFormat(
int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
super(NAME);
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to "
+ MAXIMUM_MAX_CONN
+ "; maxConn="
+ maxConn);
}
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
throw new IllegalArgumentException(
"beamWidth must be positive and less than or equal to "
+ MAXIMUM_BEAM_WIDTH
+ "; beamWidth="
+ beamWidth);
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException(
"No executor service is needed as we'll use single thread to merge");
}
this.numMergeWorkers = numMergeWorkers;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}
this.flatVectorsFormat = new Lucene99FlatVectorsFormat(new FlatBitVectorsScorer());
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new FlatBitVectorsWriter(
new Lucene99HnswVectorsWriter(
state,
maxConn,
beamWidth,
flatVectorsFormat.fieldsWriter(state),
numMergeWorkers,
mergeExec));
}
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}
@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}
@Override
public String toString() {
return "HnswBitVectorsFormat(name=HnswBitVectorsFormat, maxConn="
+ maxConn
+ ", beamWidth="
+ beamWidth
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
}
private static class FlatBitVectorsWriter extends KnnVectorsWriter {
private final KnnVectorsWriter delegate;
public FlatBitVectorsWriter(KnnVectorsWriter delegate) {
this.delegate = delegate;
}
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
delegate.mergeOneField(fieldInfo, mergeState);
}
@Override
public void finish() throws IOException {
delegate.finish();
}
@Override
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
throw new IllegalArgumentException("HnswBitVectorsFormat only supports BYTE encoding");
}
return delegate.addField(fieldInfo);
}
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
delegate.flush(maxDoc, sortMap);
}
@Override
public void close() throws IOException {
delegate.close();
}
@Override
public long ramBytesUsed() {
return delegate.ramBytesUsed();
}
}
}


@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* A simple bit-vector format that supports hamming distance and storing vectors in an HNSW graph
*/
package org.apache.lucene.codecs.bitvectors;


@ -47,7 +47,6 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.StringHelper;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Reads vector values from a simple text format. All vectors are read up front and cached in RAM in
@ -282,8 +281,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
}
private static class SimpleTextFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues<float[]> {
private static class SimpleTextFloatVectorValues extends FloatVectorValues {
private final BytesRefBuilder scratch = new BytesRefBuilder();
private final FieldEntry entry;
@ -315,11 +313,6 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return values[curOrd];
}
@Override
public RandomAccessVectorValues<float[]> copy() {
return this;
}
@Override
public int docID() {
if (curOrd == -1) {
@ -364,15 +357,9 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
value[i] = Float.parseFloat(floatStrings[i]);
}
}
@Override
public float[] vectorValue(int targetOrd) throws IOException {
return values[targetOrd];
}
}
private static class SimpleTextByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues<BytesRef> {
private static class SimpleTextByteVectorValues extends ByteVectorValues {
private final BytesRefBuilder scratch = new BytesRefBuilder();
private final FieldEntry entry;
@ -408,11 +395,6 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return binaryValue.bytes;
}
@Override
public RandomAccessVectorValues<BytesRef> copy() {
return this;
}
@Override
public int docID() {
if (curOrd == -1) {
@ -457,12 +439,6 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
value[i] = (byte) Float.parseFloat(floatStrings[i]);
}
}
@Override
public BytesRef vectorValue(int targetOrd) throws IOException {
binaryValue.bytes = values[curOrd];
return binaryValue;
}
}
private int readInt(IndexInput in, BytesRef field) throws IOException {


@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
org.apache.lucene.codecs.bitvectors.HnswBitVectorsFormat


@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.bitvectors;
import static org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase.randomVector8;
import java.io.IOException;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseIndexFileFormatTestCase;
public class TestHnswBitVectorsFormat extends BaseIndexFileFormatTestCase {
@Override
protected Codec getCodec() {
return new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new HnswBitVectorsFormat();
}
};
}
@Override
protected void addRandomFields(Document doc) {
doc.add(new KnnByteVectorField("v2", randomVector8(30), VectorSimilarityFunction.DOT_PRODUCT));
}
public void testFloatVectorFails() throws IOException {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new KnnFloatVectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc));
assertTrue(e.getMessage().contains("HnswBitVectorsFormat only supports BYTE encoding"));
}
}
public void testIndexAndSearchBitVectors() throws IOException {
byte[][] vectors =
new byte[][] {
new byte[] {(byte) 0b10101110, (byte) 0b01010111},
new byte[] {(byte) 0b11110000, (byte) 0b00001111},
new byte[] {(byte) 0b11001100, (byte) 0b00110011},
new byte[] {(byte) 0b11111111, (byte) 0b00000000},
new byte[] {(byte) 0b00000000, (byte) 0b00000000}
};
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
int id = 0;
for (byte[] vector : vectors) {
Document doc = new Document();
doc.add(new KnnByteVectorField("v1", vector, VectorSimilarityFunction.DOT_PRODUCT));
doc.add(new StringField("id", Integer.toString(id++), Field.Store.YES));
w.addDocument(doc);
}
w.commit();
w.forceMerge(1);
try (IndexReader reader = DirectoryReader.open(w)) {
LeafReader r = getOnlyLeafReader(reader);
TopKnnCollector collector = new TopKnnCollector(3, Integer.MAX_VALUE);
r.searchNearestVectors("v1", vectors[0], collector, null);
TopDocs topDocs = collector.topDocs();
assertEquals(3, topDocs.scoreDocs.length);
StoredFields fields = r.storedFields();
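// expected scores follow the bit scorer's formula (16 - xorBitCount) / 16:
// doc 0 matches the query exactly (1.0), doc 2 differs in 6 bits (0.625), doc 1 in 8 bits (0.5)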
assertEquals("0", fields.document(topDocs.scoreDocs[0].doc).get("id"));
assertEquals(1.0, topDocs.scoreDocs[0].score, 1e-12);
assertEquals("2", fields.document(topDocs.scoreDocs[1].doc).get("id"));
assertEquals(0.625, topDocs.scoreDocs[1].score, 1e-12);
assertEquals("1", fields.document(topDocs.scoreDocs[2].doc).get("id"));
assertEquals(0.5, topDocs.scoreDocs[2].score, 1e-12);
}
}
}
public void testToString() {
FilterCodec customCodec =
new FilterCodec("foo", Codec.getDefault()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new HnswBitVectorsFormat(10, 20);
}
};
String expectedString =
"HnswBitVectorsFormat(name=HnswBitVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=FlatBitVectorsScorer()))";
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
}
public void testLimits() {
expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(-1, 20));
expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(0, 20));
expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(20, 0));
expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(20, -1));
expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(512 + 1, 20));
expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(20, 3201));
}
}


@ -63,6 +63,7 @@ module org.apache.lucene.core {
org.apache.lucene.test_framework;
exports org.apache.lucene.util.quantization;
exports org.apache.lucene.codecs.hnsw;
provides org.apache.lucene.analysis.TokenizerFactory with
org.apache.lucene.analysis.standard.StandardTokenizerFactory;


@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
/**
* Default implementation of {@link FlatVectorsScorer}.
*
* @lucene.experimental
*/
public class DefaultFlatVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
throws IOException {
if (vectorValues instanceof RandomAccessVectorValues.Floats floatVectorValues) {
return new FloatScoringSupplier(floatVectorValues, similarityFunction);
} else if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) {
return new ByteScoringSupplier(byteVectorValues, similarityFunction);
}
throw new IllegalArgumentException(
"vectorValues must be an instance of RandomAccessVectorValues.Floats or RandomAccessVectorValues.Bytes");
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target)
throws IOException {
assert vectorValues instanceof RandomAccessVectorValues.Floats;
if (target.length != vectorValues.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ target.length
+ " differs from field dimension: "
+ vectorValues.dimension());
}
return new FloatVectorScorer(
(RandomAccessVectorValues.Floats) vectorValues, target, similarityFunction);
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target)
throws IOException {
assert vectorValues instanceof RandomAccessVectorValues.Bytes;
if (target.length != vectorValues.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ target.length
+ " differs from field dimension: "
+ vectorValues.dimension());
}
return new ByteVectorScorer(
(RandomAccessVectorValues.Bytes) vectorValues, target, similarityFunction);
}
@Override
public String toString() {
return "DefaultFlatVectorScorer()";
}
/** RandomVectorScorerSupplier for bytes vector */
private static final class ByteScoringSupplier implements RandomVectorScorerSupplier {
private final RandomAccessVectorValues.Bytes vectors;
private final RandomAccessVectorValues.Bytes vectors1;
private final RandomAccessVectorValues.Bytes vectors2;
private final VectorSimilarityFunction similarityFunction;
private ByteScoringSupplier(
RandomAccessVectorValues.Bytes vectors, VectorSimilarityFunction similarityFunction)
throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
this.similarityFunction = similarityFunction;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return new ByteVectorScorer(vectors2, vectors1.vectorValue(ord), similarityFunction);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new ByteScoringSupplier(vectors, similarityFunction);
}
}
/** RandomVectorScorerSupplier for Float vector */
private static final class FloatScoringSupplier implements RandomVectorScorerSupplier {
private final RandomAccessVectorValues.Floats vectors;
private final RandomAccessVectorValues.Floats vectors1;
private final RandomAccessVectorValues.Floats vectors2;
private final VectorSimilarityFunction similarityFunction;
private FloatScoringSupplier(
RandomAccessVectorValues.Floats vectors, VectorSimilarityFunction similarityFunction)
throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
this.similarityFunction = similarityFunction;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return new FloatVectorScorer(vectors2, vectors1.vectorValue(ord), similarityFunction);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new FloatScoringSupplier(vectors, similarityFunction);
}
}
/** A {@link RandomVectorScorer} for float vectors. */
private static class FloatVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
private final RandomAccessVectorValues.Floats values;
private final float[] query;
private final VectorSimilarityFunction similarityFunction;
public FloatVectorScorer(
RandomAccessVectorValues.Floats values,
float[] query,
VectorSimilarityFunction similarityFunction) {
super(values);
this.values = values;
this.query = query;
this.similarityFunction = similarityFunction;
}
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, values.vectorValue(node));
}
}
/** A {@link RandomVectorScorer} for byte vectors. */
private static class ByteVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
private final RandomAccessVectorValues.Bytes values;
private final byte[] query;
private final VectorSimilarityFunction similarityFunction;
public ByteVectorScorer(
RandomAccessVectorValues.Bytes values,
byte[] query,
VectorSimilarityFunction similarityFunction) {
super(values);
this.values = values;
this.query = query;
this.similarityFunction = similarityFunction;
}
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, values.vectorValue(node));
}
}
}
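A minimal usage sketch of the scorer above (not part of this change): it exercises only the entry points shown in this file plus the RandomAccessVectorValues.fromFloats helper added later in this commit; the class name and vector values are illustrative.
import java.io.IOException;
import java.util.List;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

public class DefaultFlatVectorScorerExample {
  public static void main(String[] args) throws IOException {
    // Two toy 2-dimensional vectors held on heap (illustrative values).
    RandomAccessVectorValues.Floats vectors =
        RandomAccessVectorValues.fromFloats(
            List.of(new float[] {1f, 0f}, new float[] {0f, 1f}), 2);
    // Score every stored vector against an ad hoc query vector.
    RandomVectorScorer queryScorer =
        new DefaultFlatVectorScorer()
            .getRandomVectorScorer(
                VectorSimilarityFunction.EUCLIDEAN, vectors, new float[] {0.5f, 0.5f});
    for (int ord = 0; ord < vectors.size(); ord++) {
      System.out.println("ord=" + ord + " score=" + queryScorer.score(ord));
    }
  }
}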

View File

@ -15,7 +15,9 @@
* limitations under the License.
*/
package org.apache.lucene.codecs;
package org.apache.lucene.codecs.hnsw;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
/**
* Vectors' writer for a field

View File

@ -15,14 +15,15 @@
* limitations under the License.
*/
package org.apache.lucene.codecs;
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
/**
* Encodes/decodes per-document vectors
* Encodes/decodes per-document vectors and provides a scoring interface for the flat stored vectors.
*
* @lucene.experimental
*/

View File

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.lucene.codecs;
package org.apache.lucene.codecs.hnsw;
import java.io.Closeable;
import java.io.IOException;
@ -41,8 +41,20 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
*/
public abstract class FlatVectorsReader implements Closeable, Accountable {
/** Scorer for flat vectors */
protected final FlatVectorsScorer vectorScorer;
/** Sole constructor */
protected FlatVectorsReader() {}
protected FlatVectorsReader(FlatVectorsScorer vectorsScorer) {
this.vectorScorer = vectorsScorer;
}
/**
* @return the {@link FlatVectorsScorer} for this reader.
*/
public FlatVectorsScorer getFlatVectorScorer() {
return vectorScorer;
}
/**
* Returns a {@link RandomVectorScorer} for the given field and target vector.

View File

@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
/**
* Provides mechanisms to score vectors that are stored in a flat file. The purpose of this class is
* to provide flexibility to the codec utilizing the vectors.
*
* @lucene.experimental
*/
public interface FlatVectorsScorer {
/**
* Returns a {@link RandomVectorScorerSupplier} that can be used to score vectors
*
* @param similarityFunction the similarity function to use
* @param vectorValues the vector values to score
* @return a {@link RandomVectorScorerSupplier} that can be used to score vectors
* @throws IOException if an I/O error occurs
*/
RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
throws IOException;
/**
* Returns a {@link RandomVectorScorer} for the given set of vectors and target vector.
*
* @param similarityFunction the similarity function to use
* @param vectorValues the vector values to score
* @param target the target vector
* @return a {@link RandomVectorScorer} for the given vectors and target vector.
* @throws IOException if an I/O error occurs when reading from the index.
*/
RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target)
throws IOException;
/**
* Returns a {@link RandomVectorScorer} for the given set of vectors and target vector.
*
* @param similarityFunction the similarity function to use
* @param vectorValues the vector values to score
* @param target the target vector
* @return a {@link RandomVectorScorer} for the given vectors and target vector.
* @throws IOException if an I/O error occurs when reading from the index.
*/
RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target)
throws IOException;
}
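To make the extension point concrete, here is a hedged sketch of a custom implementation (hypothetical class, not part of this change). It delegates everything except the float query path to DefaultFlatVectorScorer and scores that path with the existing VectorUtil.dotProduct utility, purely for illustration:
import java.io.IOException;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

/** Hypothetical scorer that always uses a raw dot product for float queries. */
public class DotProductOnlyFlatVectorScorer implements FlatVectorsScorer {
  private final FlatVectorsScorer delegate = new DefaultFlatVectorScorer();

  @Override
  public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
      VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
      throws IOException {
    // Graph-side (ordinal-to-ordinal) scoring is delegated unchanged.
    return delegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
  }

  @Override
  public RandomVectorScorer getRandomVectorScorer(
      VectorSimilarityFunction similarityFunction,
      RandomAccessVectorValues vectorValues,
      float[] target)
      throws IOException {
    // Assumes float storage, as the default scorer does for this overload.
    RandomAccessVectorValues.Floats floats = (RandomAccessVectorValues.Floats) vectorValues;
    return new RandomVectorScorer.AbstractRandomVectorScorer(floats) {
      @Override
      public float score(int node) throws IOException {
        // Ignore the requested similarity and score with a plain dot product.
        return VectorUtil.dotProduct(target, floats.vectorValue(node));
      }
    };
  }

  @Override
  public RandomVectorScorer getRandomVectorScorer(
      VectorSimilarityFunction similarityFunction,
      RandomAccessVectorValues vectorValues,
      byte[] target)
      throws IOException {
    return delegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
  }
}
A production implementation would also validate dimensions and handle byte storage, as DefaultFlatVectorScorer does.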

View File

@ -15,10 +15,11 @@
* limitations under the License.
*/
package org.apache.lucene.codecs;
package org.apache.lucene.codecs.hnsw;
import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
@ -32,9 +33,20 @@ import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
* @lucene.experimental
*/
public abstract class FlatVectorsWriter implements Accountable, Closeable {
/** Scorer for flat vectors */
protected final FlatVectorsScorer vectorsScorer;
/** Sole constructor */
protected FlatVectorsWriter() {}
protected FlatVectorsWriter(FlatVectorsScorer vectorsScorer) {
this.vectorsScorer = vectorsScorer;
}
/**
* @return the {@link FlatVectorsScorer} for this writer.
*/
public FlatVectorsScorer getFlatVectorScorer() {
return vectorsScorer;
}
/**
* Add a new field for indexing, allowing the user to provide a writer that the flat vectors

View File

@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs;
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import org.apache.lucene.util.hnsw.HnswGraph;

View File

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorer;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.apache.lucene.util.quantization.ScalarQuantizer;
/**
* Default scalar quantized implementation of {@link FlatVectorsScorer}.
*
* @lucene.experimental
*/
public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
private final FlatVectorsScorer nonQuantizedDelegate;
public ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) {
nonQuantizedDelegate = flatVectorsScorer;
}
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
throws IOException {
if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) {
return new ScalarQuantizedRandomVectorScorerSupplier(
similarityFunction,
quantizedByteVectorValues.getScalarQuantizer(),
quantizedByteVectorValues);
}
// It is possible to get to this branch during initial indexing and flush
return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target)
throws IOException {
if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) {
ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer();
byte[] targetBytes = new byte[target.length];
float offsetCorrection =
ScalarQuantizedRandomVectorScorer.quantizeQuery(
target, targetBytes, similarityFunction, scalarQuantizer);
ScalarQuantizedVectorSimilarity scalarQuantizedVectorSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction,
scalarQuantizer.getConstantMultiplier(),
scalarQuantizer.getBits());
return new ScalarQuantizedRandomVectorScorer(
scalarQuantizedVectorSimilarity,
quantizedByteVectorValues,
targetBytes,
offsetCorrection);
}
// It is possible to get to this branch during initial indexing and flush
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target)
throws IOException {
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}
@Override
public String toString() {
return "ScalarQuantizedVectorScorer(" + "nonQuantizedDelegate=" + nonQuantizedDelegate + ')';
}
}
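This delegation is also how the scorer is typically composed: the quantized path handles RandomAccessQuantizedByteVectorValues, and raw vectors (for example during initial indexing and flush, as the branches above note) fall through to the wrapped scorer. A wiring sketch mirroring what Lucene99ScalarQuantizedVectorsFormat does later in this change:
FlatVectorsScorer scorer = new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());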

View File

@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* HNSW vector helper classes. The classes in this package provide scoring and storage mechanisms
* for vectors stored in a flat file. This allows HNSW formats to be extended with other flat
* storage formats or scoring implementations without significant changes to the HNSW code. Some examples for
* scoring include {@link org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer} and {@link
* org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer}. Some examples for storing include {@link
* org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat} and {@link
* org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat}.
*/
package org.apache.lucene.codecs.hnsw;
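As a concrete sketch of the composition described above (constructor signatures as introduced in this change; the variable name is illustrative), a flat storage format is handed its scorer at construction time, and a custom FlatVectorsScorer implementation can be supplied in its place:
// Lucene99 flat storage scored with the default float/byte scorer.
FlatVectorsFormat rawFormat = new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());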

View File

@ -30,7 +30,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
public abstract class OffHeapByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues<byte[]> {
implements RandomAccessVectorValues.Bytes {
protected final int dimension;
protected final int size;
@ -85,12 +85,11 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
return new EmptyOffHeapVectorValues(dimension);
}
IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
int byteSize = dimension;
if (configuration.isDense()) {
return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, byteSize);
return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, dimension);
} else {
return new SparseOffHeapVectorValues(
configuration, vectorData, bytesSlice, dimension, byteSize);
configuration, vectorData, bytesSlice, dimension, dimension);
}
}
@ -131,7 +130,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<byte[]> copy() throws IOException {
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@ -194,7 +193,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<byte[]> copy() throws IOException {
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
configuration, dataIn, slice.clone(), dimension, byteSize);
}
@ -262,7 +261,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public RandomAccessVectorValues<byte[]> copy() throws IOException {
public EmptyOffHeapVectorValues copy() throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -29,7 +29,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
public abstract class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues<float[]> {
implements RandomAccessVectorValues.Floats {
protected final int dimension;
protected final int size;
@ -125,7 +125,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@ -188,7 +188,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
configuration, dataIn, slice.clone(), dimension, byteSize);
}
@ -256,7 +256,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
public EmptyOffHeapVectorValues copy() throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -18,9 +18,10 @@
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.codecs.FlatVectorsFormat;
import org.apache.lucene.codecs.FlatVectorsReader;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
@ -75,24 +76,25 @@ public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {
public static final int VERSION_CURRENT = VERSION_START;
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
private final FlatVectorsScorer vectorsScorer;
/** Constructs a format */
public Lucene99FlatVectorsFormat() {
super();
public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
this.vectorsScorer = vectorsScorer;
}
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99FlatVectorsWriter(state);
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
}
@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99FlatVectorsReader(state);
return new Lucene99FlatVectorsReader(state, vectorsScorer);
}
@Override
public String toString() {
return "Lucene99FlatVectorsFormat()";
return "Lucene99FlatVectorsFormat(" + "vectorsScorer=" + vectorsScorer + ')';
}
}

View File

@ -24,7 +24,8 @@ import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
@ -59,7 +60,9 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
public Lucene99FlatVectorsReader(SegmentReadState state) throws IOException {
public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
throws IOException {
super(scorer);
int versionMeta = readMetadata(state);
boolean success = false;
try {
@ -217,7 +220,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
return RandomVectorScorer.createFloats(
return vectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction,
OffHeapFloatVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
@ -225,7 +229,6 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData),
fieldEntry.similarityFunction,
target);
}
@ -235,7 +238,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return null;
}
return RandomVectorScorer.createBytes(
return vectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction,
OffHeapByteVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
@ -243,7 +247,6 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData),
fieldEntry.similarityFunction,
target);
}

View File

@ -27,10 +27,11 @@ import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
@ -71,7 +72,9 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
public Lucene99FlatVectorsWriter(SegmentWriteState state) throws IOException {
public Lucene99FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer)
throws IOException {
super(scorer);
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
@ -305,20 +308,20 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
final IndexInput finalVectorDataInput = vectorDataInput;
final RandomVectorScorerSupplier randomVectorScorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
case BYTE -> vectorsScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
finalVectorDataInput,
fieldInfo.getVectorDimension() * Byte.BYTES),
fieldInfo.getVectorSimilarityFunction());
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
fieldInfo.getVectorDimension() * Byte.BYTES));
case FLOAT32 -> vectorsScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
finalVectorDataInput,
fieldInfo.getVectorDimension() * Float.BYTES),
fieldInfo.getVectorSimilarityFunction());
fieldInfo.getVectorDimension() * Float.BYTES));
};
return new FlatCloseableRandomVectorScorerSupplier(
() -> {

View File

@ -25,10 +25,10 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMU
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.FlatVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;

View File

@ -19,10 +19,11 @@ package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.FlatVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.MergePolicy;
import org.apache.lucene.index.MergeScheduler;
@ -101,7 +102,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
* <p>NOTE: We eagerly populate `float[MAX_CONN*2]` and `int[MAX_CONN*2]`, so exceptionally large
* numbers here will use an inordinate amount of heap
*/
static final int MAXIMUM_MAX_CONN = 512;
public static final int MAXIMUM_MAX_CONN = 512;
/** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16;
@ -111,7 +112,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
* maximum value preserves the ratio of the DEFAULT_BEAM_WIDTH/DEFAULT_MAX_CONN i.e. `6.25 * 16 =
* 3200`
*/
static final int MAXIMUM_BEAM_WIDTH = 3200;
public static final int MAXIMUM_BEAM_WIDTH = 3200;
/**
* Default size of the queue maintained while searching during graph construction.
@ -137,7 +138,8 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
private final int beamWidth;
/** The format for storing, reading, merging vectors on disk */
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat();
private static final FlatVectorsFormat flatVectorsFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
private final int numMergeWorkers;
private final TaskExecutor mergeExec;

View File

@ -25,9 +25,9 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsReader;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;

View File

@ -25,9 +25,10 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
@ -127,7 +128,12 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
@Override
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
FieldWriter<?> newField =
FieldWriter.create(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
FieldWriter.create(
flatVectorWriter.getFlatVectorScorer(),
fieldInfo,
M,
beamWidth,
segmentWriteState.infoStream);
fields.add(newField);
return flatVectorWriter.addField(fieldInfo, newField);
}
@ -542,29 +548,32 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
private int lastDocID = -1;
private int node = 0;
static FieldWriter<?> create(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
static FieldWriter<?> create(
FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException {
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream);
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream);
case BYTE -> new FieldWriter<byte[]>(scorer, fieldInfo, M, beamWidth, infoStream);
case FLOAT32 -> new FieldWriter<float[]>(scorer, fieldInfo, M, beamWidth, infoStream);
};
}
@SuppressWarnings("unchecked")
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
FieldWriter(
FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException {
this.fieldInfo = fieldInfo;
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
RAVectorValues<T> raVectors = new RAVectorValues<>(vectors, fieldInfo.getVectorDimension());
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
(RandomAccessVectorValues<byte[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
(RandomAccessVectorValues<float[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
case BYTE -> scorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromBytes(
(List<byte[]>) vectors, fieldInfo.getVectorDimension()));
case FLOAT32 -> scorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromFloats(
(List<float[]>) vectors, fieldInfo.getVectorDimension()));
};
hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
@ -609,34 +618,4 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
+ hnswGraphBuilder.getGraph().ramBytesUsed();
}
}
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
private final List<T> vectors;
private final int dim;
RAVectorValues(List<T> vectors, int dim) {
this.vectors = vectors;
this.dim = dim;
}
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public T vectorValue(int targetOrd) throws IOException {
return vectors.get(targetOrd);
}
@Override
public RandomAccessVectorValues<T> copy() throws IOException {
return this;
}
}
}

View File

@ -18,9 +18,11 @@
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.codecs.FlatVectorsFormat;
import org.apache.lucene.codecs.FlatVectorsReader;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
@ -46,7 +48,8 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
static final String META_EXTENSION = "vemq";
static final String VECTOR_DATA_EXTENSION = "veq";
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat();
private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
/** The minimum confidence interval */
private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;
@ -62,6 +65,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
final byte bits;
final boolean compress;
final ScalarQuantizedVectorScorer flatVectorScorer;
/** Constructs a format using default graph construction parameters */
public Lucene99ScalarQuantizedVectorsFormat() {
@ -98,6 +102,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
this.bits = (byte) bits;
this.confidenceInterval = confidenceInterval;
this.compress = compress;
this.flatVectorScorer = new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
}
public static float calculateDefaultConfidenceInterval(int vectorDimension) {
@ -115,6 +120,8 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
+ bits
+ ", compress="
+ compress
+ ", flatVectorScorer="
+ flatVectorScorer
+ ", rawVectorFormat="
+ rawVectorFormat
+ ")";
@ -123,11 +130,17 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99ScalarQuantizedVectorsWriter(
state, confidenceInterval, bits, compress, rawVectorFormat.fieldsWriter(state));
state,
confidenceInterval,
bits,
compress,
rawVectorFormat.fieldsWriter(state),
flatVectorScorer);
}
@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99ScalarQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state));
return new Lucene99ScalarQuantizedVectorsReader(
state, rawVectorFormat.fieldsReader(state), flatVectorScorer);
}
}

View File

@ -24,7 +24,8 @@ import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
@ -45,7 +46,6 @@ import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorer;
import org.apache.lucene.util.quantization.ScalarQuantizer;
/**
@ -64,7 +64,9 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
private final FlatVectorsReader rawVectorsReader;
public Lucene99ScalarQuantizedVectorsReader(
SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer)
throws IOException {
super(scorer);
this.rawVectorsReader = rawVectorsReader;
int versionMeta = -1;
String metaFileName =
@ -224,13 +226,12 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
fieldEntry.ordToDoc,
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.bits,
fieldEntry.scalarQuantizer,
fieldEntry.compress,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
quantizedVectorData);
return new ScalarQuantizedRandomVectorScorer(
fieldEntry.similarityFunction, fieldEntry.scalarQuantizer, vectorValues, target);
return vectorScorer.getRandomVectorScorer(fieldEntry.similarityFunction, vectorValues, target);
}
@Override
@ -280,7 +281,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
fieldEntry.ordToDoc,
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.bits,
fieldEntry.scalarQuantizer,
fieldEntry.compress,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,

View File

@ -30,11 +30,12 @@ import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.DocIDMerger;
@ -59,7 +60,6 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.ScalarQuantizer;
/**
@ -102,7 +102,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private boolean finished;
public Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state, Float confidenceInterval, FlatVectorsWriter rawVectorDelegate)
SegmentWriteState state,
Float confidenceInterval,
FlatVectorsWriter rawVectorDelegate,
FlatVectorsScorer scorer)
throws IOException {
this(
state,
@ -110,7 +113,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
confidenceInterval,
(byte) 7,
false,
rawVectorDelegate);
rawVectorDelegate,
scorer);
}
public Lucene99ScalarQuantizedVectorsWriter(
@ -118,7 +122,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
Float confidenceInterval,
byte bits,
boolean compress,
FlatVectorsWriter rawVectorDelegate)
FlatVectorsWriter rawVectorDelegate,
FlatVectorsScorer scorer)
throws IOException {
this(
state,
@ -126,7 +131,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
confidenceInterval,
bits,
compress,
rawVectorDelegate);
rawVectorDelegate,
scorer);
}
private Lucene99ScalarQuantizedVectorsWriter(
@ -135,8 +141,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
Float confidenceInterval,
byte bits,
boolean compress,
FlatVectorsWriter rawVectorDelegate)
FlatVectorsWriter rawVectorDelegate,
FlatVectorsScorer scorer)
throws IOException {
super(scorer);
this.confidenceInterval = confidenceInterval;
this.bits = bits;
this.compress = compress;
@ -511,13 +519,12 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName());
},
docsWithField.cardinality(),
new ScalarQuantizedRandomVectorScorerSupplier(
vectorsScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
mergedQuantizationState,
new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
bits,
mergedQuantizationState,
compress,
quantizationDataInput)));
} finally {
@ -1091,12 +1098,12 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
static final class ScalarQuantizedCloseableRandomVectorScorerSupplier
implements CloseableRandomVectorScorerSupplier {
private final ScalarQuantizedRandomVectorScorerSupplier supplier;
private final RandomVectorScorerSupplier supplier;
private final Closeable onClose;
private final int numVectors;
ScalarQuantizedCloseableRandomVectorScorerSupplier(
Closeable onClose, int numVectors, ScalarQuantizedRandomVectorScorerSupplier supplier) {
Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) {
this.onClose = onClose;
this.supplier = supplier;
this.numVectors = numVectors;

View File

@ -26,6 +26,7 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.packed.DirectMonotonicReader;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
/**
* Read the quantized vector values and their score correction values from the index input. This
@ -37,7 +38,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
protected final int dimension;
protected final int size;
protected final int numBytes;
protected final byte bits;
protected final ScalarQuantizer scalarQuantizer;
protected final boolean compress;
protected final IndexInput slice;
@ -81,13 +82,17 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
}
OffHeapQuantizedByteVectorValues(
int dimension, int size, byte bits, boolean compress, IndexInput slice) {
int dimension,
int size,
ScalarQuantizer scalarQuantizer,
boolean compress,
IndexInput slice) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
this.bits = bits;
this.scalarQuantizer = scalarQuantizer;
this.compress = compress;
if (bits <= 4 && compress) {
if (scalarQuantizer.getBits() <= 4 && compress) {
this.numBytes = (dimension + 1) >> 1;
} else {
this.numBytes = dimension;
@ -97,6 +102,11 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
binaryValue = byteBuffer.array();
}
@Override
public ScalarQuantizer getScalarQuantizer() {
return scalarQuantizer;
}
@Override
public int dimension() {
return dimension;
@ -129,7 +139,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
OrdToDocDISIReaderConfiguration configuration,
int dimension,
int size,
byte bits,
ScalarQuantizer scalarQuantizer,
boolean compress,
long quantizedVectorDataOffset,
long quantizedVectorDataLength,
@ -142,10 +152,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
vectorData.slice(
"quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
if (configuration.isDense()) {
return new DenseOffHeapVectorValues(dimension, size, bits, compress, bytesSlice);
return new DenseOffHeapVectorValues(dimension, size, scalarQuantizer, compress, bytesSlice);
} else {
return new SparseOffHeapVectorValues(
configuration, dimension, size, bits, compress, vectorData, bytesSlice);
configuration, dimension, size, scalarQuantizer, compress, vectorData, bytesSlice);
}
}
@ -158,8 +168,12 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
private int doc = -1;
public DenseOffHeapVectorValues(
int dimension, int size, byte bits, boolean compress, IndexInput slice) {
super(dimension, size, bits, compress, slice);
int dimension,
int size,
ScalarQuantizer scalarQuantizer,
boolean compress,
IndexInput slice) {
super(dimension, size, scalarQuantizer, compress, slice);
}
@Override
@ -188,7 +202,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, bits, compress, slice.clone());
return new DenseOffHeapVectorValues(
dimension, size, scalarQuantizer, compress, slice.clone());
}
@Override
@ -208,12 +223,12 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
OrdToDocDISIReaderConfiguration configuration,
int dimension,
int size,
byte bits,
ScalarQuantizer scalarQuantizer,
boolean compress,
IndexInput dataIn,
IndexInput slice)
throws IOException {
super(dimension, size, bits, compress, slice);
super(dimension, size, scalarQuantizer, compress, slice);
this.configuration = configuration;
this.dataIn = dataIn;
this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
@ -244,7 +259,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
configuration, dimension, size, bits, compress, dataIn, slice.clone());
configuration, dimension, size, scalarQuantizer, compress, dataIn, slice.clone());
}
@Override
@ -274,7 +289,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, (byte) 7, false, null);
super(dimension, 0, new ScalarQuantizer(-1, 1, (byte) 7), false, null);
}
private int doc = -1;

View File

@ -182,6 +182,30 @@ public final class VectorUtil {
return IMPL.int4DotProduct(a, b);
}
/**
* XOR bit count computed over signed bytes.
*
* @param a bytes containing a vector
* @param b bytes containing another vector, of the same dimension
* @return the value of the XOR bit count of the two vectors
*/
public static int xorBitCount(byte[] a, byte[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
int distance = 0, i = 0;
for (final int upperBound = a.length & ~(Long.BYTES - 1); i < upperBound; i += Long.BYTES) {
distance +=
Long.bitCount(
(long) BitUtil.VH_NATIVE_LONG.get(a, i) ^ (long) BitUtil.VH_NATIVE_LONG.get(b, i));
}
// tail:
for (; i < a.length; i++) {
distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
}
return distance;
}
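A small, self-contained illustration of the helper above (hypothetical class name; the similarity mapping at the end is just one possible choice, not necessarily what any shipped format uses):
import org.apache.lucene.util.VectorUtil;

/** Toy illustration of xorBitCount as a Hamming distance between packed bit vectors. */
public class XorBitCountExample {
  public static void main(String[] args) {
    byte[] a = {(byte) 0b10110100, (byte) 0b00001111}; // 16 bits packed into 2 bytes
    byte[] b = {(byte) 0b10010100, (byte) 0b11111111};
    int hamming = VectorUtil.xorBitCount(a, b); // 1 + 4 = 5 differing bits
    // One possible way to turn the distance into a similarity in (0, 1].
    float score = 1f / (1f + hamming);
    System.out.println("hamming=" + hamming + " score=" + score);
  }
}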
/**
* Dot product score computed over signed bytes, scaled to be in [0, 1].
*

View File

@ -17,7 +17,7 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TaskExecutor;

View File

@ -19,8 +19,8 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;

View File

@ -18,6 +18,7 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.util.Bits;
/**
@ -26,7 +27,7 @@ import org.apache.lucene.util.Bits;
*
* @lucene.experimental
*/
public interface RandomAccessVectorValues<T> {
public interface RandomAccessVectorValues {
/** Return the number of vector values */
int size();
@ -34,19 +35,11 @@ public interface RandomAccessVectorValues<T> {
/** Return the dimension of the returned vector values */
int dimension();
/**
* Return the vector value indexed at the given ordinal.
*
* @param targetOrd a valid ordinal, &ge; 0 and &lt; {@link #size()}.
*/
T vectorValue(int targetOrd) throws IOException;
/**
* Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to
* access different values at once, to avoid overwriting the underlying float vector returned by
* {@link RandomAccessVectorValues#vectorValue}.
* access different values at once, to avoid overwriting the underlying vector returned.
*/
RandomAccessVectorValues<T> copy() throws IOException;
RandomAccessVectorValues copy() throws IOException;
/**
* Translates vector ordinal to the correct document ID. By default, this is an identity function.
@ -67,4 +60,92 @@ public interface RandomAccessVectorValues<T> {
default Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
/** Float vector values. */
interface Floats extends RandomAccessVectorValues {
@Override
RandomAccessVectorValues.Floats copy() throws IOException;
/**
* Return the vector value indexed at the given ordinal.
*
* @param targetOrd a valid ordinal, &ge; 0 and &lt; {@link #size()}.
*/
float[] vectorValue(int targetOrd) throws IOException;
}
/** Byte vector values. */
interface Bytes extends RandomAccessVectorValues {
@Override
RandomAccessVectorValues.Bytes copy() throws IOException;
/**
* Return the vector value indexed at the given ordinal.
*
* @param targetOrd a valid ordinal, &ge; 0 and &lt; {@link #size()}.
*/
byte[] vectorValue(int targetOrd) throws IOException;
}
/**
* Creates a {@link RandomAccessVectorValues.Floats} from a list of float arrays.
*
* @param vectors the list of float arrays
* @param dim the dimension of the vectors
* @return a {@link RandomAccessVectorValues.Floats} instance
*/
static RandomAccessVectorValues.Floats fromFloats(List<float[]> vectors, int dim) {
return new RandomAccessVectorValues.Floats() {
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public float[] vectorValue(int targetOrd) throws IOException {
return vectors.get(targetOrd);
}
@Override
public RandomAccessVectorValues.Floats copy() throws IOException {
return this;
}
};
}
/**
* Creates a {@link RandomAccessVectorValues.Bytes} from a list of byte arrays.
*
* @param vectors the list of byte arrays
* @param dim the dimension of the vectors
* @return a {@link RandomAccessVectorValues.Bytes} instance
*/
static RandomAccessVectorValues.Bytes fromBytes(List<byte[]> vectors, int dim) {
return new RandomAccessVectorValues.Bytes() {
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public byte[] vectorValue(int targetOrd) throws IOException {
return vectors.get(targetOrd);
}
@Override
public RandomAccessVectorValues.Bytes copy() throws IOException {
return this;
}
};
}
}
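A short sketch (not part of this change) combining the new on-heap factory with the DefaultFlatVectorScorer supplier path; the class name and values are illustrative:
import java.io.IOException;
import java.util.List;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

public class FromBytesExample {
  public static void main(String[] args) throws IOException {
    // Three toy 2-dimensional byte vectors (illustrative values).
    RandomAccessVectorValues.Bytes vectors =
        RandomAccessVectorValues.fromBytes(
            List.of(new byte[] {1, 2}, new byte[] {2, 1}, new byte[] {0, 3}), 2);
    // Ordinal-to-ordinal scoring, as used when building a graph over already stored vectors.
    RandomVectorScorerSupplier supplier =
        new DefaultFlatVectorScorer()
            .getRandomVectorScorerSupplier(VectorSimilarityFunction.EUCLIDEAN, vectors);
    RandomVectorScorer fromOrd0 = supplier.scorer(0);
    System.out.println("score(0,1)=" + fromOrd0.score(1));
    System.out.println("score(0,2)=" + fromOrd0.score(2));
  }
}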

View File

@ -18,7 +18,6 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
/** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */
@ -56,82 +55,16 @@ public interface RandomVectorScorer {
return acceptDocs;
}
/**
* Creates a default scorer for float vectors.
*
* <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
* using it after calling this function. If you plan to use it again outside the returned {@link
* RandomVectorScorer}, think about passing a copied version ({@link
* RandomAccessVectorValues#copy}).
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to score vectors
* @param query the actual query
*/
static RandomVectorScorer createFloats(
final RandomAccessVectorValues<float[]> vectors,
final VectorSimilarityFunction similarityFunction,
final float[] query) {
if (query.length != vectors.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ query.length
+ " differs from field dimension: "
+ vectors.dimension());
}
return new AbstractRandomVectorScorer<>(vectors) {
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, vectors.vectorValue(node));
}
};
}
/**
* Creates a default scorer for byte vectors.
*
* <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
* using it after calling this function. If you plan to use it again outside the returned {@link
* RandomVectorScorer}, think about passing a copied version ({@link
* RandomAccessVectorValues#copy}).
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to use to score vectors
* @param query the actual query
*/
static RandomVectorScorer createBytes(
final RandomAccessVectorValues<byte[]> vectors,
final VectorSimilarityFunction similarityFunction,
final byte[] query) {
if (query.length != vectors.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ query.length
+ " differs from field dimension: "
+ vectors.dimension());
}
return new AbstractRandomVectorScorer<>(vectors) {
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, vectors.vectorValue(node));
}
};
}
/**
* Creates a default scorer for random access vectors.
*
* @param <T> the type of the vector values
*/
abstract class AbstractRandomVectorScorer<T> implements RandomVectorScorer {
private final RandomAccessVectorValues<T> values;
/** Creates a default scorer for random access vectors. */
abstract class AbstractRandomVectorScorer implements RandomVectorScorer {
private final RandomAccessVectorValues values;
/**
* Creates a new scorer for the given vector values.
*
* @param values the vector values
*/
public AbstractRandomVectorScorer(RandomAccessVectorValues<T> values) {
public AbstractRandomVectorScorer(RandomAccessVectorValues values) {
this.values = values;
}

View File

@ -18,7 +18,6 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
/** A supplier that creates {@link RandomVectorScorer} from an ordinal. */
public interface RandomVectorScorerSupplier {
@ -36,100 +35,4 @@ public interface RandomVectorScorerSupplier {
* be used in other threads.
*/
RandomVectorScorerSupplier copy() throws IOException;
/**
* Creates a {@link RandomVectorScorerSupplier} to compare float vectors. The vectorValues passed
* in will be copied and the original copy will not be used.
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to score vectors
*/
static RandomVectorScorerSupplier createFloats(
final RandomAccessVectorValues<float[]> vectors,
final VectorSimilarityFunction similarityFunction)
throws IOException {
// We copy the provided random accessor just once during the supplier's initialization
// and then reuse it consistently across all scorers for conducting vector comparisons.
return new FloatScoringSupplier(vectors, similarityFunction);
}
/**
* Creates a {@link RandomVectorScorerSupplier} to compare byte vectors. The vectorValues passed
* in will be copied and the original copy will not be used.
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to score vectors
*/
static RandomVectorScorerSupplier createBytes(
final RandomAccessVectorValues<byte[]> vectors,
final VectorSimilarityFunction similarityFunction)
throws IOException {
// We copy the provided random accessor only during the supplier's initialization
// and then reuse it consistently across all scorers for conducting vector comparisons.
return new ByteScoringSupplier(vectors, similarityFunction);
}
/** RandomVectorScorerSupplier for bytes vector */
final class ByteScoringSupplier implements RandomVectorScorerSupplier {
private final RandomAccessVectorValues<byte[]> vectors;
private final RandomAccessVectorValues<byte[]> vectors1;
private final RandomAccessVectorValues<byte[]> vectors2;
private final VectorSimilarityFunction similarityFunction;
private ByteScoringSupplier(
RandomAccessVectorValues<byte[]> vectors, VectorSimilarityFunction similarityFunction)
throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
this.similarityFunction = similarityFunction;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return new RandomVectorScorer.AbstractRandomVectorScorer<>(vectors) {
@Override
public float score(int cand) throws IOException {
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
}
};
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new ByteScoringSupplier(vectors, similarityFunction);
}
}
/** RandomVectorScorerSupplier for Float vector */
final class FloatScoringSupplier implements RandomVectorScorerSupplier {
private final RandomAccessVectorValues<float[]> vectors;
private final RandomAccessVectorValues<float[]> vectors1;
private final RandomAccessVectorValues<float[]> vectors2;
private final VectorSimilarityFunction similarityFunction;
private FloatScoringSupplier(
RandomAccessVectorValues<float[]> vectors, VectorSimilarityFunction similarityFunction)
throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
this.similarityFunction = similarityFunction;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return new RandomVectorScorer.AbstractRandomVectorScorer<>(vectors) {
@Override
public float score(int cand) throws IOException {
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
}
};
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new FloatScoringSupplier(vectors, similarityFunction);
}
}
}

View File

@ -25,7 +25,10 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
*
* @lucene.experimental
*/
public interface RandomAccessQuantizedByteVectorValues extends RandomAccessVectorValues<byte[]> {
public interface RandomAccessQuantizedByteVectorValues extends RandomAccessVectorValues.Bytes {
ScalarQuantizer getScalarQuantizer();
float getScoreCorrectionConstant();
@Override

View File

@ -28,7 +28,7 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
* @lucene.experimental
*/
public class ScalarQuantizedRandomVectorScorer
extends RandomVectorScorer.AbstractRandomVectorScorer<byte[]> {
extends RandomVectorScorer.AbstractRandomVectorScorer {
public static float quantizeQuery(
float[] query,
@ -64,22 +64,6 @@ public class ScalarQuantizedRandomVectorScorer
this.values = values;
}
public ScalarQuantizedRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
ScalarQuantizer scalarQuantizer,
RandomAccessQuantizedByteVectorValues values,
float[] query) {
super(values);
byte[] quantizedQuery = new byte[query.length];
float correction = quantizeQuery(query, quantizedQuery, similarityFunction, scalarQuantizer);
this.quantizedQuery = quantizedQuery;
this.queryOffset = correction;
this.similarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits());
this.values = values;
}
@Override
public float score(int node) throws IOException {
byte[] storedVectorValue = values.vectorValue(node);

View File

@ -30,6 +30,7 @@ public class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorSc
private final RandomAccessQuantizedByteVectorValues values;
private final ScalarQuantizedVectorSimilarity similarity;
private final VectorSimilarityFunction vectorSimilarityFunction;
public ScalarQuantizedRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction,
@ -39,12 +40,16 @@ public class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorSc
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits());
this.values = values;
this.vectorSimilarityFunction = similarityFunction;
}
private ScalarQuantizedRandomVectorScorerSupplier(
ScalarQuantizedVectorSimilarity similarity, RandomAccessQuantizedByteVectorValues values) {
ScalarQuantizedVectorSimilarity similarity,
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessQuantizedByteVectorValues values) {
this.similarity = similarity;
this.values = values;
this.vectorSimilarityFunction = vectorSimilarityFunction;
}
@Override
@ -57,6 +62,7 @@ public class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorSc
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new ScalarQuantizedRandomVectorScorerSupplier(similarity, values.copy());
return new ScalarQuantizedRandomVectorScorerSupplier(
similarity, vectorSimilarityFunction, values.copy());
}
}

View File

@ -182,7 +182,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
}
};
String expectedString =
"Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, rawVectorFormat=Lucene99FlatVectorsFormat()))";
"Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())))";
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
}

View File

@ -38,7 +38,7 @@ public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
}
};
String expectedString =
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat())";
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer()))";
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
}

View File

@ -42,6 +42,7 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
@ -1031,8 +1032,8 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
try (Directory directory = newDirectory()) {
MockAnalyzer mockAnalyzer = new MockAnalyzer(random());
IndexWriterConfig iwc = newIndexWriterConfig(mockAnalyzer);
KnnVectorsFormat format1 = randomVectorFormat();
KnnVectorsFormat format2 = randomVectorFormat();
KnnVectorsFormat format1 = randomVectorFormat(VectorEncoding.FLOAT32);
KnnVectorsFormat format2 = randomVectorFormat(VectorEncoding.FLOAT32);
iwc.setCodec(
new AssertingCodec() {
@Override

View File

@ -22,7 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import org.apache.lucene.util.BytesRef;
abstract class AbstractMockVectorValues<T> implements RandomAccessVectorValues<T> {
abstract class AbstractMockVectorValues<T> implements RandomAccessVectorValues {
protected final int dimension;
protected final T[] denseValues;
@ -52,7 +52,6 @@ abstract class AbstractMockVectorValues<T> implements RandomAccessVectorValues<T
return dimension;
}
@Override
public T vectorValue(int targetOrd) {
return denseValues[targetOrd];
}

View File

@ -40,6 +40,7 @@ import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
@ -87,6 +88,7 @@ import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
abstract class HnswGraphTestCase<T> extends LuceneTestCase {
VectorSimilarityFunction similarityFunction;
DefaultFlatVectorScorer flatVectorScorer = new DefaultFlatVectorScorer();
abstract VectorEncoding getVectorEncoding();
@ -109,30 +111,23 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction);
abstract RandomAccessVectorValues<T> circularVectorValues(int nDoc);
abstract RandomAccessVectorValues circularVectorValues(int nDoc);
abstract T getTargetVector();
@SuppressWarnings("unchecked")
protected RandomVectorScorerSupplier buildScorerSupplier(RandomAccessVectorValues<T> vectors)
protected RandomVectorScorerSupplier buildScorerSupplier(RandomAccessVectorValues vectors)
throws IOException {
return switch (getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
(RandomAccessVectorValues<byte[]>) vectors, similarityFunction);
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
(RandomAccessVectorValues<float[]>) vectors, similarityFunction);
};
return flatVectorScorer.getRandomVectorScorerSupplier(similarityFunction, vectors);
}
@SuppressWarnings("unchecked")
protected RandomVectorScorer buildScorer(RandomAccessVectorValues<T> vectors, T query)
protected RandomVectorScorer buildScorer(RandomAccessVectorValues vectors, T query)
throws IOException {
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
RandomAccessVectorValues vectorsCopy = vectors.copy();
return switch (getVectorEncoding()) {
case BYTE -> RandomVectorScorer.createBytes(
(RandomAccessVectorValues<byte[]>) vectorsCopy, similarityFunction, (byte[]) query);
case FLOAT32 -> RandomVectorScorer.createFloats(
(RandomAccessVectorValues<float[]>) vectorsCopy, similarityFunction, (float[]) query);
case BYTE -> flatVectorScorer.getRandomVectorScorer(
similarityFunction, vectorsCopy, (byte[]) query);
case FLOAT32 -> flatVectorScorer.getRandomVectorScorer(
similarityFunction, vectorsCopy, (float[]) query);
};
}
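Put together, the test helpers above correspond to the following flow; a sketch under the same assumptions as the earlier one (`vectors` is a RandomAccessVectorValues.Floats, `query` is a float[] of matching dimension; imports and exception handling elided):

// Build the graph from a scorer supplier, then obtain a query-time scorer for searching it.
DefaultFlatVectorScorer flatVectorScorer = new DefaultFlatVectorScorer();
RandomVectorScorerSupplier scorerSupplier =
    flatVectorScorer.getRandomVectorScorerSupplier(VectorSimilarityFunction.DOT_PRODUCT, vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, 42);
OnHeapHnswGraph hnsw = builder.build(vectors.size());
RandomVectorScorer queryScorer =
    flatVectorScorer.getRandomVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, vectors.copy(), query);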
@ -464,7 +459,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
public void testAknnDiverse() throws IOException {
int nDoc = 100;
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
@ -496,7 +491,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
@SuppressWarnings("unchecked")
public void testSearchWithAcceptOrds() throws IOException {
int nDoc = 100;
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
@ -521,7 +516,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
@SuppressWarnings("unchecked")
public void testSearchWithSelectiveAcceptOrds() throws IOException {
int nDoc = 100;
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
@ -714,7 +709,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
public void testVisitedLimit() throws IOException {
int nDoc = 500;
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
@ -749,7 +744,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int M = randomIntBetween(4, 96);
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
RandomAccessVectorValues<T> vectors = vectorValues(size, dim);
RandomAccessVectorValues vectors = vectorValues(size, dim);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder =
@ -1078,7 +1073,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
/** Returns vectors evenly distributed around the upper unit semicircle. */
static class CircularFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues<float[]> {
implements RandomAccessVectorValues.Floats {
private final int size;
private final float[] value;
@ -1137,7 +1132,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
/** Returns vectors evenly distributed around the upper unit semicircle. */
static class CircularByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues<byte[]> {
implements RandomAccessVectorValues.Bytes {
private final int size;
private final float[] value;
private final byte[] bValue;

View File

@ -20,7 +20,8 @@ package org.apache.lucene.util.hnsw;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
class MockByteVectorValues extends AbstractMockVectorValues<byte[]> {
class MockByteVectorValues extends AbstractMockVectorValues<byte[]>
implements RandomAccessVectorValues.Bytes {
private final byte[] scratch;
static MockByteVectorValues fromValues(byte[][] values) {
@ -55,6 +56,11 @@ class MockByteVectorValues extends AbstractMockVectorValues<byte[]> {
numVectors);
}
@Override
public byte[] vectorValue(int ord) {
return values[ord];
}
@Override
public byte[] vectorValue() {
if (LuceneTestCase.random().nextBoolean()) {

View File

@ -20,7 +20,8 @@ package org.apache.lucene.util.hnsw;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
class MockVectorValues extends AbstractMockVectorValues<float[]> {
class MockVectorValues extends AbstractMockVectorValues<float[]>
implements RandomAccessVectorValues.Floats {
private final float[] scratch;
static MockVectorValues fromValues(float[][] values) {

View File

@ -132,7 +132,7 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {
}
@Override
RandomAccessVectorValues<byte[]> circularVectorValues(int nDoc) {
CircularByteVectorValues circularVectorValues(int nDoc) {
return new CircularByteVectorValues(nDoc);
}

View File

@ -117,7 +117,7 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
}
@Override
RandomAccessVectorValues<float[]> circularVectorValues(int nDoc) {
CircularFloatVectorValues circularVectorValues(int nDoc) {
return new CircularFloatVectorValues(nDoc);
}
@ -129,7 +129,7 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
public void testSearchWithSkewedAcceptOrds() throws IOException {
int nDoc = 1000;
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
RandomAccessVectorValues<float[]> vectors = circularVectorValues(nDoc);
RandomAccessVectorValues.Floats vectors = circularVectorValues(nDoc);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());

View File

@ -18,11 +18,11 @@
package org.apache.lucene.tests.codecs.asserting;
import java.io.IOException;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;

View File

@ -101,7 +101,7 @@ import org.apache.lucene.util.StringHelper;
import org.apache.lucene.util.Version;
/** Common tests to all index formats. */
abstract class BaseIndexFileFormatTestCase extends LuceneTestCase {
public abstract class BaseIndexFileFormatTestCase extends LuceneTestCase {
private static final IndexWriterAccess INDEX_WRITER_ACCESS = TestSecrets.getIndexWriterAccess();

View File

@ -1244,7 +1244,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
iw.updateDocument(idTerm, doc);
}
protected float[] randomVector(int dim) {
public static float[] randomVector(int dim) {
assert dim > 0;
float[] v = new float[dim];
double squareSum = 0.0;
@ -1259,13 +1259,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
return v;
}
protected float[] randomNormalizedVector(int dim) {
public static float[] randomNormalizedVector(int dim) {
float[] v = randomVector(dim);
VectorUtil.l2normalize(v);
return v;
}
protected byte[] randomVector8(int dim) {
public static byte[] randomVector8(int dim) {
assert dim > 0;
float[] v = randomNormalizedVector(dim);
byte[] b = new byte[dim];
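With these helpers now public and static, other vector-format tests (for example, ones exercising the new bit-vectors format) can reuse them directly; a hypothetical caller might do:

// Sketch: reuse the shared helper from another test class.
float[] query = BaseKnnVectorsFormatTestCase.randomNormalizedVector(64);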

View File

@ -102,6 +102,7 @@ import java.util.regex.Pattern;
import junit.framework.AssertionFailedError;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.bitvectors.HnswBitVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Store;
@ -152,6 +153,7 @@ import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.index.TermsEnum.SeekStatus;
import org.apache.lucene.index.TieredMergePolicy;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.internal.tests.IndexPackageAccess;
import org.apache.lucene.internal.tests.TestSecrets;
import org.apache.lucene.search.DocIdSetIterator;
@ -3213,11 +3215,17 @@ public abstract class LuceneTestCase extends Assert {
return it;
}
protected KnnVectorsFormat randomVectorFormat() {
protected KnnVectorsFormat randomVectorFormat(VectorEncoding vectorEncoding) {
ServiceLoader<KnnVectorsFormat> formats = java.util.ServiceLoader.load(KnnVectorsFormat.class);
List<KnnVectorsFormat> availableFormats = new ArrayList<>();
for (KnnVectorsFormat f : formats) {
availableFormats.add(f);
if (f.getName().equals(HnswBitVectorsFormat.NAME)) {
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
availableFormats.add(f);
}
} else {
availableFormats.add(f);
}
}
return RandomPicks.randomFrom(random(), availableFormats);
}
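The sandbox HnswBitVectorsFormat only supports byte-encoded vectors, which is why it is excluded here unless a BYTE encoding is requested. A hypothetical wiring sketch, assuming the sandbox module is on the classpath and that the test-framework helper TestUtil.alwaysKnnVectorsFormat and the format's default constructor are available:

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.bitvectors.HnswBitVectorsFormat;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.tests.util.TestUtil;

// Assumed wiring for a test: use the sandbox bit-vectors format for every vector field.
Codec codec = TestUtil.alwaysKnnVectorsFormat(new HnswBitVectorsFormat());
IndexWriterConfig config = new IndexWriterConfig().setCodec(codec);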

View File

@ -23,7 +23,7 @@ import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.FilterLeafReader;