mirror of https://github.com/apache/lucene.git
Add new VectorScorer interface to vector value iterators (#13181)
With quantized vectors, and with current vectors, we separate "scoring" from "iteration", requiring the user to always iterate the raw vectors and provide their own similarity function. While this is flexible, it creates frustration in two cases:
- Just iterating and scoring, especially since the field already has a similarity function stored. Why can't we just know which one to use and use it?
- Iterating and scoring quantized vectors. By default it would be good to be able to iterate and score quantized vectors (e.g. without going through the HNSW graph). The current separation significantly hampers support for true exact kNN search.
This commit extends the vector value iterators so that they can return a scorer for a given target vector (this is what this PR demonstrates). The scorer contains a copy of the originating iterator and allows iterating and scoring in the most optimized way the provided codec can offer. Users can still iterate vector values directly, read them on heap, and score them any way they please.
This commit is contained in:
parent
8d7e4174af
commit
b60e86c4b9
|
@ -258,6 +258,9 @@ New Features
|
|||
* GITHUB#13288: Make HNSW and Flat storage vector formats easier to extend with new FlatVectorScorer interface. Add
|
||||
new Hnsw format for binary quantized vectors. (Ben Trent)
|
||||
|
||||
* GITHUB#13181: Add new VectorScorer interface to vector value iterators. This allows for vector codecs to supply
|
||||
simpler and more optimized vector scoring when iterating vector values directly. (Ben Trent)
|
||||
|
||||
Improvements
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -34,7 +34,9 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.DataInput;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
@ -272,7 +274,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
throws IOException {
|
||||
IndexInput bytesSlice =
|
||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||
return new OffHeapFloatVectorValues(fieldEntry.dimension, fieldEntry.ordToDoc, bytesSlice);
|
||||
return new OffHeapFloatVectorValues(
|
||||
fieldEntry.dimension, fieldEntry.ordToDoc, fieldEntry.similarityFunction, bytesSlice);
|
||||
}
|
||||
|
||||
private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
|
||||
|
@ -359,14 +362,20 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
final int byteSize;
|
||||
int lastOrd = -1;
|
||||
final float[] value;
|
||||
final VectorSimilarityFunction similarityFunction;
|
||||
|
||||
int ord = -1;
|
||||
int doc = -1;
|
||||
|
||||
OffHeapFloatVectorValues(int dimension, int[] ordToDoc, IndexInput dataIn) {
|
||||
OffHeapFloatVectorValues(
|
||||
int dimension,
|
||||
int[] ordToDoc,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
IndexInput dataIn) {
|
||||
this.dimension = dimension;
|
||||
this.ordToDoc = ordToDoc;
|
||||
this.dataIn = dataIn;
|
||||
this.similarityFunction = similarityFunction;
|
||||
|
||||
byteSize = Float.BYTES * dimension;
|
||||
value = new float[dimension];
|
||||
|
@ -420,7 +429,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
|
||||
@Override
|
||||
public OffHeapFloatVectorValues copy() {
|
||||
return new OffHeapFloatVectorValues(dimension, ordToDoc, dataIn.clone());
|
||||
return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -433,6 +442,22 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
lastOrd = targetOrd;
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
OffHeapFloatVectorValues values = this.copy();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.similarityFunction.compare(values.vectorValue(), target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/** Read the nearest-neighbors graph from the index input */
|
||||
|
|
|
@ -35,7 +35,9 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.DataInput;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
@ -255,7 +257,11 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
IndexInput bytesSlice =
|
||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||
return new OffHeapFloatVectorValues(
|
||||
fieldEntry.dimension, fieldEntry.size(), fieldEntry.ordToDoc, bytesSlice);
|
||||
fieldEntry.dimension,
|
||||
fieldEntry.size(),
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.similarityFunction,
|
||||
bytesSlice);
|
||||
}
|
||||
|
||||
private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
|
||||
|
@ -399,16 +405,23 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
private final IndexInput dataIn;
|
||||
private final int byteSize;
|
||||
private final float[] value;
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
|
||||
private int ord = -1;
|
||||
private int doc = -1;
|
||||
|
||||
OffHeapFloatVectorValues(int dimension, int size, int[] ordToDoc, IndexInput dataIn) {
|
||||
OffHeapFloatVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
int[] ordToDoc,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
IndexInput dataIn) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.ordToDoc = ordToDoc;
|
||||
ordToDocOperator = ordToDoc == null ? IntUnaryOperator.identity() : (ord) -> ordToDoc[ord];
|
||||
this.dataIn = dataIn;
|
||||
this.similarityFunction = similarityFunction;
|
||||
byteSize = Float.BYTES * dimension;
|
||||
value = new float[dimension];
|
||||
}
|
||||
|
@ -468,7 +481,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
|
||||
@Override
|
||||
public OffHeapFloatVectorValues copy() {
|
||||
return new OffHeapFloatVectorValues(dimension, size, ordToDoc, dataIn.clone());
|
||||
return new OffHeapFloatVectorValues(
|
||||
dimension, size, ordToDoc, similarityFunction, dataIn.clone());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -477,6 +491,22 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
dataIn.readFloats(value, 0, value.length);
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
OffHeapFloatVectorValues values = this.copy();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.similarityFunction.compare(values.vectorValue(), target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/** Read the nearest-neighbors graph from the index input */
|
||||
|
|
|
@ -20,6 +20,9 @@ package org.apache.lucene.backward_codecs.lucene92;
|
|||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -36,13 +39,20 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
protected final int byteSize;
|
||||
protected int lastOrd = -1;
|
||||
protected final float[] value;
|
||||
protected final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
;
|
||||
|
||||
OffHeapFloatVectorValues(int dimension, int size, IndexInput slice) {
|
||||
OffHeapFloatVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
IndexInput slice) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.slice = slice;
|
||||
byteSize = Float.BYTES * dimension;
|
||||
value = new float[dimension];
|
||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -75,9 +85,11 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
vectorData.slice(
|
||||
"vector-data", fieldEntry.vectorDataOffset(), fieldEntry.vectorDataLength());
|
||||
if (fieldEntry.docsWithFieldOffset() == -1) {
|
||||
return new DenseOffHeapVectorValues(fieldEntry.dimension(), fieldEntry.size(), bytesSlice);
|
||||
return new DenseOffHeapVectorValues(
|
||||
fieldEntry.dimension(), fieldEntry.size(), fieldEntry.similarityFunction(), bytesSlice);
|
||||
} else {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice);
|
||||
return new SparseOffHeapVectorValues(
|
||||
fieldEntry, vectorData, fieldEntry.similarityFunction(), bytesSlice);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -85,8 +97,12 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) {
|
||||
super(dimension, size, slice);
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
IndexInput slice) {
|
||||
super(dimension, size, vectorSimilarityFunction, slice);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -115,13 +131,29 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
|
||||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
|
||||
return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone());
|
||||
}
|
||||
|
||||
// Returns the supplied acceptDocs unchanged, with no translation applied.
// NOTE(review): presumably valid because ordinals coincide with doc ids in the dense case - confirm.
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
DenseOffHeapVectorValues values = this.copy();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
|
@ -132,10 +164,13 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
private final Lucene92HnswVectorsReader.FieldEntry fieldEntry;
|
||||
|
||||
public SparseOffHeapVectorValues(
|
||||
Lucene92HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice)
|
||||
Lucene92HnswVectorsReader.FieldEntry fieldEntry,
|
||||
IndexInput dataIn,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
IndexInput slice)
|
||||
throws IOException {
|
||||
|
||||
super(fieldEntry.dimension(), fieldEntry.size(), slice);
|
||||
super(fieldEntry.dimension(), fieldEntry.size(), vectorSimilarityFunction, slice);
|
||||
this.fieldEntry = fieldEntry;
|
||||
final RandomAccessInput addressesData =
|
||||
dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength());
|
||||
|
@ -173,8 +208,9 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public OffHeapFloatVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
fieldEntry, dataIn, vectorSimilarityFunction, slice.clone());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -199,12 +235,28 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
SparseOffHeapVectorValues values = this.copy();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
|
||||
public EmptyOffHeapVectorValues(int dimension) {
|
||||
super(dimension, 0, null);
|
||||
super(dimension, 0, VectorSimilarityFunction.COSINE, null);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
@ -258,5 +310,10 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Always returns null: this implementation supplies no scorer for any query.
// NOTE(review): appears to belong to the empty (size 0) values class - callers presumably must handle null; confirm.
@Override
|
||||
public VectorScorer scorer(float[] query) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,9 @@ import java.nio.ByteBuffer;
|
|||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -39,12 +42,19 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
protected final byte[] binaryValue;
|
||||
protected final ByteBuffer byteBuffer;
|
||||
protected final int byteSize;
|
||||
protected final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
|
||||
OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
OffHeapByteVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
IndexInput slice,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
int byteSize) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.slice = slice;
|
||||
this.byteSize = byteSize;
|
||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||
byteBuffer = ByteBuffer.allocate(byteSize);
|
||||
binaryValue = byteBuffer.array();
|
||||
}
|
||||
|
@ -85,9 +95,14 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
int byteSize = fieldEntry.dimension();
|
||||
if (fieldEntry.docsWithFieldOffset() == -1) {
|
||||
return new DenseOffHeapVectorValues(
|
||||
fieldEntry.dimension(), fieldEntry.size(), bytesSlice, byteSize);
|
||||
fieldEntry.dimension(),
|
||||
fieldEntry.size(),
|
||||
bytesSlice,
|
||||
fieldEntry.similarityFunction(),
|
||||
byteSize);
|
||||
} else {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
|
||||
return new SparseOffHeapVectorValues(
|
||||
fieldEntry, vectorData, bytesSlice, fieldEntry.similarityFunction(), byteSize);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -95,8 +110,13 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
super(dimension, size, slice, byteSize);
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
IndexInput slice,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
int byteSize) {
|
||||
super(dimension, size, slice, vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -124,14 +144,31 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public OffHeapByteVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension, size, slice.clone(), vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
// Returns the supplied acceptDocs unchanged, with no translation applied.
// NOTE(review): presumably valid because ordinals coincide with doc ids in the dense case - confirm.
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] query) throws IOException {
|
||||
DenseOffHeapVectorValues copy = this.copy();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return vectorSimilarityFunction.compare(copy.vectorValue(), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
@ -145,10 +182,11 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
Lucene94HnswVectorsReader.FieldEntry fieldEntry,
|
||||
IndexInput dataIn,
|
||||
IndexInput slice,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
int byteSize)
|
||||
throws IOException {
|
||||
|
||||
super(fieldEntry.dimension(), fieldEntry.size(), slice, byteSize);
|
||||
super(fieldEntry.dimension(), fieldEntry.size(), slice, vectorSimilarityFunction, byteSize);
|
||||
this.fieldEntry = fieldEntry;
|
||||
final RandomAccessInput addressesData =
|
||||
dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength());
|
||||
|
@ -186,8 +224,9 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public OffHeapByteVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -212,12 +251,28 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] query) throws IOException {
|
||||
SparseOffHeapVectorValues copy = this.copy();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return vectorSimilarityFunction.compare(copy.vectorValue(), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
||||
public EmptyOffHeapVectorValues(int dimension) {
|
||||
super(dimension, 0, null, 0);
|
||||
super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
@ -271,5 +326,10 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Always returns null: this implementation supplies no scorer for any query.
// NOTE(review): appears to belong to the empty (size 0) values class - callers presumably must handle null; confirm.
@Override
|
||||
public VectorScorer scorer(byte[] query) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,9 @@ package org.apache.lucene.backward_codecs.lucene94;
|
|||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -36,13 +39,20 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
protected final int byteSize;
|
||||
protected int lastOrd = -1;
|
||||
protected final float[] value;
|
||||
protected final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
|
||||
OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
OffHeapFloatVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
IndexInput slice,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
int byteSize) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.slice = slice;
|
||||
this.byteSize = byteSize;
|
||||
value = new float[dimension];
|
||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -81,9 +91,14 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
};
|
||||
if (fieldEntry.docsWithFieldOffset() == -1) {
|
||||
return new DenseOffHeapVectorValues(
|
||||
fieldEntry.dimension(), fieldEntry.size(), bytesSlice, byteSize);
|
||||
fieldEntry.dimension(),
|
||||
fieldEntry.size(),
|
||||
bytesSlice,
|
||||
fieldEntry.similarityFunction(),
|
||||
byteSize);
|
||||
} else {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
|
||||
return new SparseOffHeapVectorValues(
|
||||
fieldEntry, vectorData, bytesSlice, fieldEntry.similarityFunction(), byteSize);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,8 +106,13 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
super(dimension, size, slice, byteSize);
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
IndexInput slice,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
int byteSize) {
|
||||
super(dimension, size, slice, vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -120,14 +140,31 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public OffHeapFloatVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension, size, slice.clone(), vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
// Returns the supplied acceptDocs unchanged, with no translation applied.
// NOTE(review): presumably valid because ordinals coincide with doc ids in the dense case - confirm.
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
DenseOffHeapVectorValues values = this.copy();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
|
@ -141,10 +178,11 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
Lucene94HnswVectorsReader.FieldEntry fieldEntry,
|
||||
IndexInput dataIn,
|
||||
IndexInput slice,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
int byteSize)
|
||||
throws IOException {
|
||||
|
||||
super(fieldEntry.dimension(), fieldEntry.size(), slice, byteSize);
|
||||
super(fieldEntry.dimension(), fieldEntry.size(), slice, vectorSimilarityFunction, byteSize);
|
||||
this.fieldEntry = fieldEntry;
|
||||
final RandomAccessInput addressesData =
|
||||
dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength());
|
||||
|
@ -182,8 +220,9 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public OffHeapFloatVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -208,12 +247,28 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
SparseOffHeapVectorValues values = this.copy();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
|
||||
public EmptyOffHeapVectorValues(int dimension) {
|
||||
super(dimension, 0, null, 0);
|
||||
super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
@ -267,5 +322,10 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Always returns null: this implementation supplies no scorer for any query.
// NOTE(review): appears to belong to the empty (size 0) values class - callers presumably must handle null; confirm.
@Override
|
||||
public VectorScorer scorer(float[] query) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -253,6 +253,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
|||
+ VectorEncoding.FLOAT32);
|
||||
}
|
||||
return OffHeapFloatVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
defaultFlatVectorScorer,
|
||||
fieldEntry.ordToDocVectorValues,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
|
@ -274,6 +276,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
|||
+ VectorEncoding.BYTE);
|
||||
}
|
||||
return OffHeapByteVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
defaultFlatVectorScorer,
|
||||
fieldEntry.ordToDocVectorValues,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
|
@ -295,6 +299,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
|||
|
||||
OffHeapFloatVectorValues vectorValues =
|
||||
OffHeapFloatVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
defaultFlatVectorScorer,
|
||||
fieldEntry.ordToDocVectorValues,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
|
@ -324,6 +330,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
|
|||
|
||||
OffHeapByteVectorValues vectorValues =
|
||||
OffHeapByteVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
defaultFlatVectorScorer,
|
||||
fieldEntry.ordToDocVectorValues,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
|
|
|
@ -133,7 +133,10 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
// build the graph using the temporary vector data
|
||||
Lucene90HnswVectorsReader.OffHeapFloatVectorValues offHeapVectors =
|
||||
new Lucene90HnswVectorsReader.OffHeapFloatVectorValues(
|
||||
floatVectorValues.dimension(), docIds, vectorDataInput);
|
||||
floatVectorValues.dimension(),
|
||||
docIds,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
vectorDataInput);
|
||||
|
||||
long[] offsets = new long[docIds.length];
|
||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||
|
|
|
@ -68,4 +68,9 @@ public class TestLucene90HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
|
|||
public void testSortedIndexBytes() throws Exception {
|
||||
// unimplemented
|
||||
}
|
||||
|
||||
// Overrides the inherited byte-vector scorer iteration test with a no-op.
@Override
|
||||
public void testByteVectorScorerIteration() {
|
||||
// unimplemented
// NOTE(review): presumably skipped because this legacy format has no byte-vector support - confirm.
|
||||
}
|
||||
}
|
||||
|
|
|
@ -138,7 +138,11 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
// TODO: separate random access vector values from DocIdSetIterator?
|
||||
Lucene91HnswVectorsReader.OffHeapFloatVectorValues offHeapVectors =
|
||||
new Lucene91HnswVectorsReader.OffHeapFloatVectorValues(
|
||||
floatVectorValues.dimension(), docsWithField.cardinality(), null, vectorDataInput);
|
||||
floatVectorValues.dimension(),
|
||||
docsWithField.cardinality(),
|
||||
null,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
vectorDataInput);
|
||||
Lucene91OnHeapHnswGraph graph =
|
||||
offHeapVectors.size() == 0
|
||||
? null
|
||||
|
|
|
@ -67,4 +67,9 @@ public class TestLucene91HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
|
|||
public void testSortedIndexBytes() throws Exception {
|
||||
// unimplemented
|
||||
}
|
||||
|
||||
// Overrides the inherited byte-vector scorer iteration test with a no-op.
@Override
|
||||
public void testByteVectorScorerIteration() {
|
||||
// unimplemented
// NOTE(review): presumably skipped because this legacy format has no byte-vector support - confirm.
|
||||
}
|
||||
}
|
||||
|
|
|
@ -146,7 +146,10 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
// TODO: separate random access vector values from DocIdSetIterator?
|
||||
OffHeapFloatVectorValues offHeapVectors =
|
||||
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
|
||||
floatVectorValues.dimension(), docsWithField.cardinality(), vectorDataInput);
|
||||
floatVectorValues.dimension(),
|
||||
docsWithField.cardinality(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
vectorDataInput);
|
||||
OnHeapHnswGraph graph =
|
||||
offHeapVectors.size() == 0
|
||||
? null
|
||||
|
|
|
@ -57,4 +57,9 @@ public class TestLucene92HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
|
|||
public void testSortedIndexBytes() throws Exception {
|
||||
// unimplemented
|
||||
}
|
||||
|
||||
// Overrides the inherited byte-vector scorer iteration test with a no-op.
@Override
|
||||
public void testByteVectorScorerIteration() {
|
||||
// unimplemented
// NOTE(review): presumably skipped because this legacy format has no byte-vector support - confirm.
|
||||
}
|
||||
}
|
||||
|
|
|
@ -421,6 +421,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
byteSize);
|
||||
RandomVectorScorerSupplier scorerSupplier =
|
||||
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
|
||||
|
@ -437,6 +438,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
byteSize);
|
||||
RandomVectorScorerSupplier scorerSupplier =
|
||||
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
|
||||
|
|
|
@ -70,6 +70,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
private final IndexOutput meta, vectorData, vectorIndex;
|
||||
private final int M;
|
||||
private final int beamWidth;
|
||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||
|
||||
private final List<FieldWriter<?>> fields = new ArrayList<>();
|
||||
private boolean finished;
|
||||
|
@ -437,7 +438,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
OnHeapHnswGraph graph = null;
|
||||
int[][] vectorIndexNodeOffsets = null;
|
||||
if (docsWithField.cardinality() != 0) {
|
||||
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||
final RandomVectorScorerSupplier scorerSupplier;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE:
|
||||
|
@ -448,7 +448,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
byteSize));
|
||||
byteSize,
|
||||
defaultFlatVectorScorer,
|
||||
fieldInfo.getVectorSimilarityFunction()));
|
||||
break;
|
||||
case FLOAT32:
|
||||
scorerSupplier =
|
||||
|
@ -458,7 +460,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
byteSize));
|
||||
byteSize,
|
||||
defaultFlatVectorScorer,
|
||||
fieldInfo.getVectorSimilarityFunction()));
|
||||
break;
|
||||
default:
|
||||
throw new IllegalArgumentException(
|
||||
|
@ -667,6 +671,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
private final DocsWithFieldSet docsWithField;
|
||||
private final List<T> vectors;
|
||||
private final HnswGraphBuilder hnswGraphBuilder;
|
||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||
|
||||
private int lastDocID = -1;
|
||||
private int node = 0;
|
||||
|
@ -697,7 +702,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
this.dim = fieldInfo.getVectorDimension();
|
||||
this.docsWithField = new DocsWithFieldSet();
|
||||
vectors = new ArrayList<>();
|
||||
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||
RandomVectorScorerSupplier scorerSupplier =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> defaultFlatVectorScorer.getRandomVectorScorerSupplier(
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.apache.lucene.index.SegmentReadState;
|
|||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.BufferedChecksumIndexInput;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
|
@ -92,7 +93,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
assert fieldEntries.containsKey(fieldName) == false;
|
||||
fieldEntries.put(
|
||||
fieldName, new FieldEntry(dimension, vectorDataOffset, vectorDataLength, docIds));
|
||||
fieldName,
|
||||
new FieldEntry(
|
||||
dimension,
|
||||
vectorDataOffset,
|
||||
vectorDataLength,
|
||||
docIds,
|
||||
readState.fieldInfos.fieldInfo(fieldName).getVectorSimilarityFunction()));
|
||||
fieldNumber = readInt(in, FIELD_NUMBER);
|
||||
}
|
||||
SimpleTextUtil.checkFooter(in);
|
||||
|
@ -275,7 +282,11 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
private record FieldEntry(
|
||||
int dimension, long vectorDataOffset, long vectorDataLength, int[] ordToDoc) {
|
||||
int dimension,
|
||||
long vectorDataOffset,
|
||||
long vectorDataLength,
|
||||
int[] ordToDoc,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
int size() {
|
||||
return ordToDoc.length;
|
||||
}
|
||||
|
@ -298,6 +309,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
readAllVectors();
|
||||
}
|
||||
|
||||
private SimpleTextFloatVectorValues(SimpleTextFloatVectorValues other) {
|
||||
this.entry = other.entry;
|
||||
this.in = other.in.clone();
|
||||
this.values = other.values;
|
||||
this.curOrd = other.curOrd;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return entry.dimension;
|
||||
|
@ -340,6 +358,25 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
return slowAdvance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
SimpleTextFloatVectorValues simpleTextFloatVectorValues =
|
||||
new SimpleTextFloatVectorValues(this);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return entry
|
||||
.similarityFunction()
|
||||
.compare(simpleTextFloatVectorValues.vectorValue(), target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return simpleTextFloatVectorValues;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private void readAllVectors() throws IOException {
|
||||
for (float[] value : values) {
|
||||
readVector(value);
|
||||
|
@ -379,6 +416,15 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
readAllVectors();
|
||||
}
|
||||
|
||||
private SimpleTextByteVectorValues(SimpleTextByteVectorValues other) {
|
||||
this.entry = other.entry;
|
||||
this.in = other.in.clone();
|
||||
this.values = other.values;
|
||||
this.binaryValue = new BytesRef(entry.dimension);
|
||||
this.binaryValue.length = binaryValue.bytes.length;
|
||||
this.curOrd = other.curOrd;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return entry.dimension;
|
||||
|
@ -422,6 +468,24 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
return slowAdvance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
SimpleTextByteVectorValues simpleTextByteVectorValues = new SimpleTextByteVectorValues(this);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return entry
|
||||
.similarityFunction()
|
||||
.compare(simpleTextByteVectorValues.vectorValue(), target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return simpleTextByteVectorValues;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private void readAllVectors() throws IOException {
|
||||
for (byte[] value : values) {
|
||||
readVector(value);
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
||||
|
@ -159,6 +160,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
public int advance(int target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
|
||||
|
@ -216,6 +222,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
public int advance(int target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -354,6 +365,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
private static class BufferedByteVectorValues extends ByteVectorValues {
|
||||
|
@ -414,5 +430,10 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.lucene.index.MergeState;
|
|||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
|
||||
/** Writes vectors to an index. */
|
||||
|
@ -188,7 +189,6 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
private final List<VectorValuesSub> subs;
|
||||
private final DocIDMerger<VectorValuesSub> docIdMerger;
|
||||
private final int size;
|
||||
|
||||
private int docId;
|
||||
VectorValuesSub current;
|
||||
|
||||
|
@ -239,6 +239,11 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
public int dimension() {
|
||||
return subs.get(0).values.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
static class MergedByteVectorValues extends ByteVectorValues {
|
||||
|
@ -296,6 +301,11 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
public int dimension() {
|
||||
return subs.get(0).values.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,13 +19,18 @@ package org.apache.lucene.codecs.lucene95;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
|
@ -39,14 +44,24 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
protected final byte[] binaryValue;
|
||||
protected final ByteBuffer byteBuffer;
|
||||
protected final int byteSize;
|
||||
protected final VectorSimilarityFunction similarityFunction;
|
||||
protected final FlatVectorsScorer flatVectorsScorer;
|
||||
|
||||
OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
OffHeapByteVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
IndexInput slice,
|
||||
int byteSize,
|
||||
FlatVectorsScorer flatVectorsScorer,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.slice = slice;
|
||||
this.byteSize = byteSize;
|
||||
byteBuffer = ByteBuffer.allocate(byteSize);
|
||||
binaryValue = byteBuffer.array();
|
||||
this.similarityFunction = similarityFunction;
|
||||
this.flatVectorsScorer = flatVectorsScorer;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -79,6 +94,8 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
}
|
||||
|
||||
public static OffHeapByteVectorValues load(
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
FlatVectorsScorer flatVectorsScorer,
|
||||
OrdToDocDISIReaderConfiguration configuration,
|
||||
VectorEncoding vectorEncoding,
|
||||
int dimension,
|
||||
|
@ -87,14 +104,26 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
IndexInput vectorData)
|
||||
throws IOException {
|
||||
if (configuration.isEmpty() || vectorEncoding != VectorEncoding.BYTE) {
|
||||
return new EmptyOffHeapVectorValues(dimension);
|
||||
return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction);
|
||||
}
|
||||
IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
|
||||
if (configuration.isDense()) {
|
||||
return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, dimension);
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension,
|
||||
configuration.size,
|
||||
bytesSlice,
|
||||
dimension,
|
||||
flatVectorsScorer,
|
||||
vectorSimilarityFunction);
|
||||
} else {
|
||||
return new SparseOffHeapVectorValues(
|
||||
configuration, vectorData, bytesSlice, dimension, dimension);
|
||||
configuration,
|
||||
vectorData,
|
||||
bytesSlice,
|
||||
dimension,
|
||||
dimension,
|
||||
flatVectorsScorer,
|
||||
vectorSimilarityFunction);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -106,8 +135,14 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
super(dimension, size, slice, byteSize);
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
IndexInput slice,
|
||||
int byteSize,
|
||||
FlatVectorsScorer flatVectorsScorer,
|
||||
VectorSimilarityFunction vectorSimilarityFunction) {
|
||||
super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -136,13 +171,32 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
|
||||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] query) throws IOException {
|
||||
DenseOffHeapVectorValues copy = copy();
|
||||
RandomVectorScorer scorer =
|
||||
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return scorer.score(copy.doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
@ -157,10 +211,18 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
IndexInput dataIn,
|
||||
IndexInput slice,
|
||||
int dimension,
|
||||
int byteSize)
|
||||
int byteSize,
|
||||
FlatVectorsScorer flatVectorsScorer,
|
||||
VectorSimilarityFunction vectorSimilarityFunction)
|
||||
throws IOException {
|
||||
|
||||
super(dimension, configuration.size, slice, byteSize);
|
||||
super(
|
||||
dimension,
|
||||
configuration.size,
|
||||
slice,
|
||||
byteSize,
|
||||
flatVectorsScorer,
|
||||
vectorSimilarityFunction);
|
||||
this.configuration = configuration;
|
||||
final RandomAccessInput addressesData =
|
||||
dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength);
|
||||
|
@ -200,7 +262,13 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
@Override
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
configuration, dataIn, slice.clone(), dimension, byteSize);
|
||||
configuration,
|
||||
dataIn,
|
||||
slice.clone(),
|
||||
dimension,
|
||||
byteSize,
|
||||
flatVectorsScorer,
|
||||
similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -225,12 +293,33 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] query) throws IOException {
|
||||
SparseOffHeapVectorValues copy = copy();
|
||||
RandomVectorScorer scorer =
|
||||
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return scorer.score(copy.disi.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
||||
public EmptyOffHeapVectorValues(int dimension) {
|
||||
super(dimension, 0, null, 0);
|
||||
public EmptyOffHeapVectorValues(
|
||||
int dimension,
|
||||
FlatVectorsScorer flatVectorsScorer,
|
||||
VectorSimilarityFunction vectorSimilarityFunction) {
|
||||
super(dimension, 0, null, 0, flatVectorsScorer, vectorSimilarityFunction);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
@ -284,5 +373,10 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] query) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,13 +18,18 @@
|
|||
package org.apache.lucene.codecs.lucene95;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
|
@ -37,12 +42,22 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
protected final int byteSize;
|
||||
protected int lastOrd = -1;
|
||||
protected final float[] value;
|
||||
protected final VectorSimilarityFunction similarityFunction;
|
||||
protected final FlatVectorsScorer flatVectorsScorer;
|
||||
|
||||
OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
OffHeapFloatVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
IndexInput slice,
|
||||
int byteSize,
|
||||
FlatVectorsScorer flatVectorsScorer,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.slice = slice;
|
||||
this.byteSize = byteSize;
|
||||
this.similarityFunction = similarityFunction;
|
||||
this.flatVectorsScorer = flatVectorsScorer;
|
||||
value = new float[dimension];
|
||||
}
|
||||
|
||||
|
@ -73,6 +88,8 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
|
||||
public static OffHeapFloatVectorValues load(
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
FlatVectorsScorer flatVectorsScorer,
|
||||
OrdToDocDISIReaderConfiguration configuration,
|
||||
VectorEncoding vectorEncoding,
|
||||
int dimension,
|
||||
|
@ -81,15 +98,27 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
IndexInput vectorData)
|
||||
throws IOException {
|
||||
if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
return new EmptyOffHeapVectorValues(dimension);
|
||||
return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction);
|
||||
}
|
||||
IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
|
||||
int byteSize = dimension * Float.BYTES;
|
||||
if (configuration.docsWithFieldOffset == -1) {
|
||||
return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, byteSize);
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension,
|
||||
configuration.size,
|
||||
bytesSlice,
|
||||
byteSize,
|
||||
flatVectorsScorer,
|
||||
vectorSimilarityFunction);
|
||||
} else {
|
||||
return new SparseOffHeapVectorValues(
|
||||
configuration, vectorData, bytesSlice, dimension, byteSize);
|
||||
configuration,
|
||||
vectorData,
|
||||
bytesSlice,
|
||||
dimension,
|
||||
byteSize,
|
||||
flatVectorsScorer,
|
||||
vectorSimilarityFunction);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,8 +130,14 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
super(dimension, size, slice, byteSize);
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
IndexInput slice,
|
||||
int byteSize,
|
||||
FlatVectorsScorer flatVectorsScorer,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -131,13 +166,32 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
|
||||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
DenseOffHeapVectorValues copy = copy();
|
||||
RandomVectorScorer randomVectorScorer =
|
||||
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return randomVectorScorer.score(copy.doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
|
@ -152,10 +206,12 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
IndexInput dataIn,
|
||||
IndexInput slice,
|
||||
int dimension,
|
||||
int byteSize)
|
||||
int byteSize,
|
||||
FlatVectorsScorer flatVectorsScorer,
|
||||
VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
|
||||
super(dimension, configuration.size, slice, byteSize);
|
||||
super(dimension, configuration.size, slice, byteSize, flatVectorsScorer, similarityFunction);
|
||||
this.configuration = configuration;
|
||||
final RandomAccessInput addressesData =
|
||||
dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength);
|
||||
|
@ -195,7 +251,13 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
@Override
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
configuration, dataIn, slice.clone(), dimension, byteSize);
|
||||
configuration,
|
||||
dataIn,
|
||||
slice.clone(),
|
||||
dimension,
|
||||
byteSize,
|
||||
flatVectorsScorer,
|
||||
similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -220,12 +282,33 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
SparseOffHeapVectorValues copy = copy();
|
||||
RandomVectorScorer randomVectorScorer =
|
||||
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return randomVectorScorer.score(copy.disi.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
|
||||
public EmptyOffHeapVectorValues(int dimension) {
|
||||
super(dimension, 0, null, 0);
|
||||
public EmptyOffHeapVectorValues(
|
||||
int dimension,
|
||||
FlatVectorsScorer flatVectorsScorer,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
@ -256,17 +339,17 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
public int advance(int target) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmptyOffHeapVectorValues copy() throws IOException {
|
||||
public EmptyOffHeapVectorValues copy() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) throws IOException {
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
@ -279,5 +362,10 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -185,6 +185,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
|||
+ VectorEncoding.FLOAT32);
|
||||
}
|
||||
return OffHeapFloatVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
vectorScorer,
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
|
@ -206,6 +208,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
|||
+ VectorEncoding.BYTE);
|
||||
}
|
||||
return OffHeapByteVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
vectorScorer,
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
|
@ -223,6 +227,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
|||
return vectorScorer.getRandomVectorScorer(
|
||||
fieldEntry.similarityFunction,
|
||||
OffHeapFloatVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
vectorScorer,
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
|
@ -241,6 +247,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
|||
return vectorScorer.getRandomVectorScorer(
|
||||
fieldEntry.similarityFunction,
|
||||
OffHeapByteVectorValues.load(
|
||||
fieldEntry.similarityFunction,
|
||||
vectorScorer,
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
|
|
|
@ -314,14 +314,18 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
|
|||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
finalVectorDataInput,
|
||||
fieldInfo.getVectorDimension() * Byte.BYTES));
|
||||
fieldInfo.getVectorDimension() * Byte.BYTES,
|
||||
vectorsScorer,
|
||||
fieldInfo.getVectorSimilarityFunction()));
|
||||
case FLOAT32 -> vectorsScorer.getRandomVectorScorerSupplier(
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
|
||||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
finalVectorDataInput,
|
||||
fieldInfo.getVectorDimension() * Float.BYTES));
|
||||
fieldInfo.getVectorDimension() * Float.BYTES,
|
||||
vectorsScorer,
|
||||
fieldInfo.getVectorSimilarityFunction()));
|
||||
};
|
||||
return new FlatCloseableRandomVectorScorerSupplier(
|
||||
() -> {
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.apache.lucene.index.IndexFileNames;
|
|||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
@ -164,7 +165,24 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
return rawVectorsReader.getFloatVectorValues(field);
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
return null;
|
||||
}
|
||||
final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field);
|
||||
OffHeapQuantizedByteVectorValues quantizedByteVectorValues =
|
||||
OffHeapQuantizedByteVectorValues.load(
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.dimension,
|
||||
fieldEntry.size,
|
||||
fieldEntry.scalarQuantizer,
|
||||
fieldEntry.similarityFunction,
|
||||
vectorScorer,
|
||||
fieldEntry.compress,
|
||||
fieldEntry.vectorDataOffset,
|
||||
fieldEntry.vectorDataLength,
|
||||
quantizedVectorData);
|
||||
return new QuantizedVectorValues(rawVectorValues, quantizedByteVectorValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -227,6 +245,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
fieldEntry.dimension,
|
||||
fieldEntry.size,
|
||||
fieldEntry.scalarQuantizer,
|
||||
fieldEntry.similarityFunction,
|
||||
vectorScorer,
|
||||
fieldEntry.compress,
|
||||
fieldEntry.vectorDataOffset,
|
||||
fieldEntry.vectorDataLength,
|
||||
|
@ -282,6 +302,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
fieldEntry.dimension,
|
||||
fieldEntry.size,
|
||||
fieldEntry.scalarQuantizer,
|
||||
fieldEntry.similarityFunction,
|
||||
vectorScorer,
|
||||
fieldEntry.compress,
|
||||
fieldEntry.vectorDataOffset,
|
||||
fieldEntry.vectorDataLength,
|
||||
|
@ -369,4 +391,56 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
return SHALLOW_SIZE + RamUsageEstimator.sizeOf(ordToDoc);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * {@link FloatVectorValues} view that pairs the raw float vectors with their quantized
 * counterparts. Dimension, size, vector values and the current doc id are read from the raw
 * values, both iterators are moved together, and scoring is delegated to the quantized values.
 */
private static final class QuantizedVectorValues extends FloatVectorValues {
|
||||
// Answers dimension(), size(), vectorValue() and docID().
private final FloatVectorValues rawVectorValues;
|
||||
// Supplies the VectorScorer returned by scorer(float[]).
private final OffHeapQuantizedByteVectorValues quantizedVectorValues;
|
||||
|
||||
/** Stores both value sources as given; nothing is copied or advanced here. */
QuantizedVectorValues(
|
||||
FloatVectorValues rawVectorValues, OffHeapQuantizedByteVectorValues quantizedVectorValues) {
|
||||
this.rawVectorValues = rawVectorValues;
|
||||
this.quantizedVectorValues = quantizedVectorValues;
|
||||
}
|
||||
|
||||
/** Returns the dimension reported by the raw vector values. */
@Override
|
||||
public int dimension() {
|
||||
return rawVectorValues.dimension();
|
||||
}
|
||||
|
||||
/** Returns the size reported by the raw vector values. */
@Override
|
||||
public int size() {
|
||||
return rawVectorValues.size();
|
||||
}
|
||||
|
||||
/** Returns the raw (unquantized) float vector for the current position. */
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return rawVectorValues.vectorValue();
|
||||
}
|
||||
|
||||
/** Returns the doc id of the raw iterator. */
@Override
|
||||
public int docID() {
|
||||
return rawVectorValues.docID();
|
||||
}
|
||||
|
||||
/**
 * Steps both iterators and returns the quantized iterator's doc id. The assert (evaluated only
 * when assertions are enabled) checks that the raw iterator landed on the same document.
 */
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
int rawDocId = rawVectorValues.nextDoc();
|
||||
int quantizedDocId = quantizedVectorValues.nextDoc();
|
||||
assert rawDocId == quantizedDocId;
|
||||
return quantizedDocId;
|
||||
}
|
||||
|
||||
/**
 * Advances both iterators to {@code target} and returns the quantized iterator's doc id. As in
 * {@link #nextDoc()}, agreement between the two iterators is only checked by an assert.
 */
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
int rawDocId = rawVectorValues.advance(target);
|
||||
int quantizedDocId = quantizedVectorValues.advance(target);
|
||||
assert rawDocId == quantizedDocId;
|
||||
return quantizedDocId;
|
||||
}
|
||||
|
||||
/**
 * Builds the scorer from the quantized values rather than the raw floats, by delegating to
 * {@code quantizedVectorValues.vectorScorer(query)}.
 */
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
return quantizedVectorValues.vectorScorer(query);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,6 +49,7 @@ import org.apache.lucene.index.Sorter;
|
|||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
|
@ -526,6 +527,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
docsWithField.cardinality(),
|
||||
mergedQuantizationState,
|
||||
compress,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
vectorsScorer,
|
||||
quantizationDataInput)));
|
||||
} finally {
|
||||
if (success == false) {
|
||||
|
@ -890,6 +893,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
curDoc = target;
|
||||
return docID();
|
||||
}
|
||||
|
||||
/**
 * Scoring is not supported by this vector values implementation.
 *
 * @throws UnsupportedOperationException always
 */
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
|
||||
|
@ -1013,6 +1021,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
public float getScoreCorrectionConstant() throws IOException {
|
||||
return current.values.getScoreCorrectionConstant();
|
||||
}
|
||||
|
||||
/**
 * Scoring is not supported by this quantized values implementation.
 *
 * @throws UnsupportedOperationException always
 */
@Override
|
||||
public VectorScorer vectorScorer(float[] target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
|
||||
|
@ -1082,6 +1095,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
return doc;
|
||||
}
|
||||
|
||||
/**
 * Scoring is not supported by this quantized values implementation.
 *
 * @throws UnsupportedOperationException always
 */
@Override
|
||||
public VectorScorer vectorScorer(float[] target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
private void quantize() throws IOException {
|
||||
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
|
||||
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
|
||||
|
@ -1182,5 +1200,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
public int advance(int target) throws IOException {
|
||||
return in.advance(target);
|
||||
}
|
||||
|
||||
/**
 * Scoring is not supported by this wrapper; unlike {@code advance}, the call is not forwarded
 * to the wrapped values.
 *
 * @throws UnsupportedOperationException always
 */
@Override
|
||||
public VectorScorer vectorScorer(float[] target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,10 +19,15 @@ package org.apache.lucene.codecs.lucene99;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
|
@ -39,6 +44,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
protected final int size;
|
||||
protected final int numBytes;
|
||||
protected final ScalarQuantizer scalarQuantizer;
|
||||
protected final VectorSimilarityFunction similarityFunction;
|
||||
protected final FlatVectorsScorer vectorsScorer;
|
||||
protected final boolean compress;
|
||||
|
||||
protected final IndexInput slice;
|
||||
|
@ -85,6 +92,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
int dimension,
|
||||
int size,
|
||||
ScalarQuantizer scalarQuantizer,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
FlatVectorsScorer vectorsScorer,
|
||||
boolean compress,
|
||||
IndexInput slice) {
|
||||
this.dimension = dimension;
|
||||
|
@ -100,6 +109,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
this.byteSize = this.numBytes + Float.BYTES;
|
||||
byteBuffer = ByteBuffer.allocate(dimension);
|
||||
binaryValue = byteBuffer.array();
|
||||
this.similarityFunction = similarityFunction;
|
||||
this.vectorsScorer = vectorsScorer;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -160,22 +171,39 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
int dimension,
|
||||
int size,
|
||||
ScalarQuantizer scalarQuantizer,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
FlatVectorsScorer vectorsScorer,
|
||||
boolean compress,
|
||||
long quantizedVectorDataOffset,
|
||||
long quantizedVectorDataLength,
|
||||
IndexInput vectorData)
|
||||
throws IOException {
|
||||
if (configuration.isEmpty()) {
|
||||
return new EmptyOffHeapVectorValues(dimension);
|
||||
return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer);
|
||||
}
|
||||
IndexInput bytesSlice =
|
||||
vectorData.slice(
|
||||
"quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
|
||||
if (configuration.isDense()) {
|
||||
return new DenseOffHeapVectorValues(dimension, size, scalarQuantizer, compress, bytesSlice);
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension,
|
||||
size,
|
||||
scalarQuantizer,
|
||||
compress,
|
||||
similarityFunction,
|
||||
vectorsScorer,
|
||||
bytesSlice);
|
||||
} else {
|
||||
return new SparseOffHeapVectorValues(
|
||||
configuration, dimension, size, scalarQuantizer, compress, vectorData, bytesSlice);
|
||||
configuration,
|
||||
dimension,
|
||||
size,
|
||||
scalarQuantizer,
|
||||
compress,
|
||||
vectorData,
|
||||
similarityFunction,
|
||||
vectorsScorer,
|
||||
bytesSlice);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -192,8 +220,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
int size,
|
||||
ScalarQuantizer scalarQuantizer,
|
||||
boolean compress,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
FlatVectorsScorer vectorsScorer,
|
||||
IndexInput slice) {
|
||||
super(dimension, size, scalarQuantizer, compress, slice);
|
||||
super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -223,13 +253,37 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension, size, scalarQuantizer, compress, slice.clone());
|
||||
dimension,
|
||||
size,
|
||||
scalarQuantizer,
|
||||
compress,
|
||||
similarityFunction,
|
||||
vectorsScorer,
|
||||
slice.clone());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
|
||||
/**
 * Creates a {@link VectorScorer} for {@code target} that works on a private copy of these
 * values, so iterating the returned scorer does not move this instance.
 *
 * @param target the query vector
 * @return a scorer whose iterator is the copy and whose score is computed for the copy's
 *     current position
 */
@Override
|
||||
public VectorScorer vectorScorer(float[] target) throws IOException {
|
||||
// The copy is both the vector source handed to the scorer and the iterator exposed to callers.
DenseOffHeapVectorValues copy = copy();
|
||||
RandomVectorScorer vectorScorer =
|
||||
vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
// NOTE(review): passes the copy's current doc as the node to score; presumably doc id and
// vector ordinal coincide in the dense case — confirm.
return vectorScorer.score(copy.doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class SparseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
|
||||
|
@ -246,9 +300,11 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
ScalarQuantizer scalarQuantizer,
|
||||
boolean compress,
|
||||
IndexInput dataIn,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
FlatVectorsScorer vectorsScorer,
|
||||
IndexInput slice)
|
||||
throws IOException {
|
||||
super(dimension, size, scalarQuantizer, compress, slice);
|
||||
super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice);
|
||||
this.configuration = configuration;
|
||||
this.dataIn = dataIn;
|
||||
this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
|
||||
|
@ -279,7 +335,15 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
@Override
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
configuration, dimension, size, scalarQuantizer, compress, dataIn, slice.clone());
|
||||
configuration,
|
||||
dimension,
|
||||
size,
|
||||
scalarQuantizer,
|
||||
compress,
|
||||
dataIn,
|
||||
similarityFunction,
|
||||
vectorsScorer,
|
||||
slice.clone());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -304,12 +368,40 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Creates a {@link VectorScorer} for {@code target} that works on a private copy of these
 * values, so iterating the returned scorer does not move this instance.
 *
 * @param target the query vector
 * @return a scorer whose iterator is the copy and whose score is computed for the copy's
 *     current position
 */
@Override
|
||||
public VectorScorer vectorScorer(float[] target) throws IOException {
|
||||
// The copy is both the vector source handed to the scorer and the iterator exposed to callers.
SparseOffHeapVectorValues copy = copy();
|
||||
RandomVectorScorer vectorScorer =
|
||||
vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
// Scores by the copy's disi index rather than by doc id; presumably that index is the
// vector ordinal of the current document — confirm.
return vectorScorer.score(copy.disi.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
|
||||
|
||||
public EmptyOffHeapVectorValues(int dimension) {
|
||||
super(dimension, 0, new ScalarQuantizer(-1, 1, (byte) 7), false, null);
|
||||
public EmptyOffHeapVectorValues(
|
||||
int dimension,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
FlatVectorsScorer vectorsScorer) {
|
||||
super(
|
||||
dimension,
|
||||
0,
|
||||
new ScalarQuantizer(-1, 1, (byte) 7),
|
||||
similarityFunction,
|
||||
vectorsScorer,
|
||||
false,
|
||||
null);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
@ -363,5 +455,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
 * Scoring is not supported by the empty values implementation.
 *
 * @throws UnsupportedOperationException always
 */
@Override
|
||||
public VectorScorer vectorScorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.index;
|
|||
import java.io.IOException;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
|
||||
/**
|
||||
* This class provides access to per-document byte vector values indexed as {@link
|
||||
|
@ -75,4 +76,14 @@ public abstract class ByteVectorValues extends DocIdSetIterator {
|
|||
+ ")");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a {@link VectorScorer} for the given query vector. The iterator for the scorer is not
|
||||
* the same instance as the iterator for this {@link ByteVectorValues}. It is a copy, and
|
||||
* iteration over the scorer will not affect the iteration of this {@link ByteVectorValues}.
|
||||
*
* <p>Some implementations do not support scoring and throw {@link
* UnsupportedOperationException} instead of returning a scorer.
*
|
||||
* @param query the query vector
|
||||
* @return a {@link VectorScorer} instance
* @throws IOException declared so that implementations may perform I/O while creating the scorer
|
||||
*/
|
||||
public abstract VectorScorer scorer(byte[] query) throws IOException;
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.lucene.index.FilterLeafReader.FilterTerms;
|
|||
import org.apache.lucene.index.FilterLeafReader.FilterTermsEnum;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.automaton.CompiledAutomaton;
|
||||
|
@ -476,6 +477,11 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
return vectorValues.size();
|
||||
}
|
||||
|
||||
/**
 * Forwards to the wrapped vector values and returns their scorer unchanged.
 *
 * <p>NOTE(review): the returned scorer comes straight from the delegate, so iterating through
 * it does not appear to pass through this wrapper's timeout checks — confirm.
 */
@Override
|
||||
public VectorScorer scorer(float[] target) throws IOException {
|
||||
return vectorValues.scorer(target);
|
||||
}
|
||||
|
||||
/**
|
||||
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
|
||||
* if {@link Thread#interrupted()} returns true.
|
||||
|
@ -543,6 +549,11 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
return vectorValues.vectorValue();
|
||||
}
|
||||
|
||||
/**
 * Forwards to the wrapped vector values and returns their scorer unchanged.
 *
 * <p>NOTE(review): the returned scorer comes straight from the delegate, so iterating through
 * it does not appear to pass through this wrapper's timeout checks — confirm.
 */
@Override
|
||||
public VectorScorer scorer(byte[] target) throws IOException {
|
||||
return vectorValues.scorer(target);
|
||||
}
|
||||
|
||||
/**
|
||||
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
|
||||
* if {@link Thread#interrupted()} returns true.
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.index;
|
|||
import java.io.IOException;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
|
||||
/**
|
||||
* This class provides access to per-document floating point vector values indexed as {@link
|
||||
|
@ -75,4 +76,15 @@ public abstract class FloatVectorValues extends DocIdSetIterator {
|
|||
+ ")");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a {@link VectorScorer} for the given query vector and the current {@link
|
||||
* FloatVectorValues}. The iterator for the scorer is not the same instance as the iterator for
|
||||
* this {@link FloatVectorValues}. It is a copy, and iteration over the scorer will not affect the
|
||||
* iteration of this {@link FloatVectorValues}.
|
||||
*
* <p>Some implementations do not support scoring and throw {@link
* UnsupportedOperationException} instead of returning a scorer.
*
|
||||
* @param query the query vector
|
||||
* @return a {@link VectorScorer} instance
* @throws IOException declared so that implementations may perform I/O while creating the scorer
|
||||
*/
|
||||
public abstract VectorScorer scorer(float[] query) throws IOException;
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.apache.lucene.index.MultiDocValues.MultiSortedDocValues;
|
|||
import org.apache.lucene.index.MultiDocValues.MultiSortedSetDocValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
|
@ -883,6 +884,11 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
|
|||
public int advance(int target) throws IOException {
|
||||
return mergedIterator.advance(target);
|
||||
}
|
||||
|
||||
/**
 * Scoring is not supported by this vector values implementation.
 *
 * @throws UnsupportedOperationException always
 */
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -937,6 +943,11 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
|
|||
public int advance(int target) throws IOException {
|
||||
return mergedIterator.advance(target);
|
||||
}
|
||||
|
||||
/**
 * Scoring is not supported by this vector values implementation.
 *
 * @throws UnsupportedOperationException always
 */
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ import org.apache.lucene.codecs.TermVectorsReader;
|
|||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.IOSupplier;
|
||||
|
@ -266,6 +267,11 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||
}
|
||||
return docId = docsWithField.nextSetBit(target);
|
||||
}
|
||||
|
||||
/**
 * Scoring is not supported by this sorted vector values view.
 *
 * @throws UnsupportedOperationException always
 */
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
private static class SortingByteVectorValues extends ByteVectorValues {
|
||||
|
@ -320,6 +326,11 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||
}
|
||||
return docId = docsWithField.nextSetBit(target);
|
||||
}
|
||||
|
||||
/**
 * Scoring is not supported by this sorted vector values view.
 *
 * @throws UnsupportedOperationException always
 */
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -205,17 +205,17 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
HitQueue queue = new HitQueue(queueSize, true);
|
||||
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
|
||||
ScoreDoc topDoc = queue.top();
|
||||
DocIdSetIterator vectorIterator = vectorScorer.iterator();
|
||||
DocIdSetIterator conjunction =
|
||||
ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of());
|
||||
int doc;
|
||||
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
// Mark results as partial if timeout is met
|
||||
if (queryTimeout != null && queryTimeout.shouldExit()) {
|
||||
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||
break;
|
||||
}
|
||||
|
||||
boolean advanced = vectorScorer.advanceExact(doc);
|
||||
assert advanced;
|
||||
|
||||
assert vectorIterator.docID() == doc;
|
||||
float score = vectorScorer.score();
|
||||
if (score > topDoc.score) {
|
||||
topDoc.score = score;
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.search;
|
|||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
|
@ -83,7 +84,10 @@ abstract class AbstractVectorSimilarityQuery extends Query {
|
|||
VectorScorer scorer = createVectorScorer(context);
|
||||
if (scorer == null) {
|
||||
return Explanation.noMatch("Not indexed as the correct vector field");
|
||||
} else if (scorer.advanceExact(doc)) {
|
||||
}
|
||||
DocIdSetIterator iterator = scorer.iterator();
|
||||
int docId = iterator.advance(doc);
|
||||
if (docId == doc) {
|
||||
float score = scorer.score();
|
||||
if (score >= resultSimilarity) {
|
||||
return Explanation.match(boost * score, "Score above threshold");
|
||||
|
@ -256,15 +260,15 @@ abstract class AbstractVectorSimilarityQuery extends Query {
|
|||
DocIdSetIterator acceptDocs,
|
||||
float threshold) {
|
||||
float[] cachedScore = new float[1];
|
||||
DocIdSetIterator vectorIterator = scorer.iterator();
|
||||
DocIdSetIterator conjunction =
|
||||
ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptDocs), List.of());
|
||||
DocIdSetIterator iterator =
|
||||
new FilteredDocIdSetIterator(acceptDocs) {
|
||||
new FilteredDocIdSetIterator(conjunction) {
|
||||
@Override
|
||||
protected boolean match(int doc) throws IOException {
|
||||
// Advance the scorer
|
||||
if (!scorer.advanceExact(doc)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
assert doc == vectorIterator.docID();
|
||||
// Compute the dot product
|
||||
float score = scorer.score();
|
||||
cachedScore[0] = score * boost;
|
||||
|
|
|
@ -21,9 +21,8 @@ import java.util.Arrays;
|
|||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
/**
|
||||
|
@ -98,12 +97,11 @@ public class ByteVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
|
|||
|
||||
@Override
|
||||
VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {
|
||||
@SuppressWarnings("resource")
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorEncoding() != VectorEncoding.BYTE) {
|
||||
ByteVectorValues vectorValues = context.reader().getByteVectorValues(field);
|
||||
if (vectorValues == null) {
|
||||
return null;
|
||||
}
|
||||
return VectorScorer.create(context, fi, target);
|
||||
return vectorValues.scorer(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -22,7 +22,6 @@ import java.util.Arrays;
|
|||
import java.util.Objects;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
||||
/**
|
||||
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
|
||||
|
@ -37,26 +36,13 @@ class ByteVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
|
|||
}
|
||||
|
||||
@Override
|
||||
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
|
||||
public VectorScorer getScorer(LeafReaderContext ctx) throws IOException {
|
||||
final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
|
||||
if (vectorValues == null) {
|
||||
ByteVectorValues.checkField(ctx.reader(), fieldName);
|
||||
return DoubleValues.EMPTY;
|
||||
return null;
|
||||
}
|
||||
VectorSimilarityFunction function =
|
||||
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
|
||||
return new DoubleValues() {
|
||||
@Override
|
||||
public double doubleValue() throws IOException {
|
||||
return function.compare(queryVector, vectorValues.vectorValue());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean advanceExact(int doc) throws IOException {
|
||||
return doc >= vectorValues.docID()
|
||||
&& (vectorValues.docID() == doc || vectorValues.advance(doc) == doc);
|
||||
}
|
||||
};
|
||||
return vectorValues.scorer(queryVector);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -21,9 +21,8 @@ import java.util.Arrays;
|
|||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
|
@ -100,11 +99,11 @@ public class FloatVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
|
|||
@Override
|
||||
VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {
|
||||
@SuppressWarnings("resource")
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||
FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field);
|
||||
if (vectorValues == null) {
|
||||
return null;
|
||||
}
|
||||
return VectorScorer.create(context, fi, target);
|
||||
return vectorValues.scorer(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -22,7 +22,6 @@ import java.util.Arrays;
|
|||
import java.util.Objects;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
||||
/**
|
||||
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
|
||||
|
@ -44,22 +43,32 @@ class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
|
|||
FloatVectorValues.checkField(ctx.reader(), fieldName);
|
||||
return DoubleValues.EMPTY;
|
||||
}
|
||||
VectorSimilarityFunction function =
|
||||
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
|
||||
return new DoubleValues() {
|
||||
private final VectorScorer scorer = vectorValues.scorer(queryVector);
|
||||
private final DocIdSetIterator iterator = scorer.iterator();
|
||||
|
||||
@Override
|
||||
public double doubleValue() throws IOException {
|
||||
return function.compare(queryVector, vectorValues.vectorValue());
|
||||
return scorer.score();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean advanceExact(int doc) throws IOException {
|
||||
return doc >= vectorValues.docID()
|
||||
&& (vectorValues.docID() == doc || vectorValues.advance(doc) == doc);
|
||||
return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Returns a scorer of the query vector against the float vectors of {@code fieldName} in the
 * given leaf, or {@code null} when the reader has no float vector values for that field.
 */
@Override
|
||||
public VectorScorer getScorer(LeafReaderContext ctx) throws IOException {
|
||||
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
|
||||
if (vectorValues == null) {
|
||||
// Presumably rejects a field that exists but is not indexed as float vectors — confirm.
FloatVectorValues.checkField(ctx.reader(), fieldName);
|
||||
return null;
|
||||
}
|
||||
return vectorValues.scorer(queryVector);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(fieldName, Arrays.hashCode(queryVector));
|
||||
|
|
|
@ -98,7 +98,12 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
|
|||
|
||||
@Override
|
||||
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
|
||||
return VectorScorer.create(context, fi, target);
|
||||
ByteVectorValues vectorValues = context.reader().getByteVectorValues(field);
|
||||
if (vectorValues == null) {
|
||||
ByteVectorValues.checkField(context.reader(), field);
|
||||
return null;
|
||||
}
|
||||
return vectorValues.scorer(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -99,7 +99,12 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
|
|||
|
||||
@Override
|
||||
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
|
||||
return VectorScorer.create(context, fi, target);
|
||||
FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field);
|
||||
if (vectorValues == null) {
|
||||
FloatVectorValues.checkField(context.reader(), field);
|
||||
return null;
|
||||
}
|
||||
return vectorValues.scorer(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -17,117 +17,25 @@
|
|||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
||||
/**
|
||||
* Computes the similarity score between a given query vector and different document vectors. This
|
||||
* is primarily used by {@link KnnFloatVectorQuery} to run an exact, exhaustive search over the
|
||||
* vectors.
|
||||
* is used for exact searching and scoring
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
abstract class VectorScorer {
|
||||
protected final VectorSimilarityFunction similarity;
|
||||
public interface VectorScorer {
|
||||
|
||||
/**
|
||||
* Create a new vector scorer instance.
|
||||
* Compute the score for the current document ID.
|
||||
*
|
||||
* @param context the reader context
|
||||
* @param fi the FieldInfo for the field containing document vectors
|
||||
* @param query the query vector to compute the similarity for
|
||||
* @return the score for the current document ID
|
||||
* @throws IOException if an exception occurs during score computation
|
||||
*/
|
||||
static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
|
||||
throws IOException {
|
||||
FloatVectorValues values = context.reader().getFloatVectorValues(fi.name);
|
||||
if (values == null) {
|
||||
FloatVectorValues.checkField(context.reader(), fi.name);
|
||||
return null;
|
||||
}
|
||||
final VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
|
||||
return new FloatVectorScorer(values, query, similarity);
|
||||
}
|
||||
float score() throws IOException;
|
||||
|
||||
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, byte[] query)
|
||||
throws IOException {
|
||||
ByteVectorValues values = context.reader().getByteVectorValues(fi.name);
|
||||
if (values == null) {
|
||||
ByteVectorValues.checkField(context.reader(), fi.name);
|
||||
return null;
|
||||
}
|
||||
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
|
||||
return new ByteVectorScorer(values, query, similarity);
|
||||
}
|
||||
|
||||
VectorScorer(VectorSimilarityFunction similarity) {
|
||||
this.similarity = similarity;
|
||||
}
|
||||
|
||||
/** Compute the similarity score for the current document. */
|
||||
abstract float score() throws IOException;
|
||||
|
||||
abstract boolean advanceExact(int doc) throws IOException;
|
||||
|
||||
private static class ByteVectorScorer extends VectorScorer {
|
||||
private final byte[] query;
|
||||
private final ByteVectorValues values;
|
||||
|
||||
protected ByteVectorScorer(
|
||||
ByteVectorValues values, byte[] query, VectorSimilarityFunction similarity) {
|
||||
super(similarity);
|
||||
this.values = values;
|
||||
this.query = query;
|
||||
}
|
||||
|
||||
/**
|
||||
* Advance the instance to the given document ID and return true if there is a value for that
|
||||
* document.
|
||||
*/
|
||||
@Override
|
||||
public boolean advanceExact(int doc) throws IOException {
|
||||
int vectorDoc = values.docID();
|
||||
if (vectorDoc < doc) {
|
||||
vectorDoc = values.advance(doc);
|
||||
}
|
||||
return vectorDoc == doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
assert values.docID() != -1 : getClass().getSimpleName() + " is not positioned";
|
||||
return similarity.compare(query, values.vectorValue());
|
||||
}
|
||||
}
|
||||
|
||||
private static class FloatVectorScorer extends VectorScorer {
|
||||
private final float[] query;
|
||||
private final FloatVectorValues values;
|
||||
|
||||
protected FloatVectorScorer(
|
||||
FloatVectorValues values, float[] query, VectorSimilarityFunction similarity) {
|
||||
super(similarity);
|
||||
this.query = query;
|
||||
this.values = values;
|
||||
}
|
||||
|
||||
/**
|
||||
* Advance the instance to the given document ID and return true if there is a value for that
|
||||
* document.
|
||||
*/
|
||||
@Override
|
||||
public boolean advanceExact(int doc) throws IOException {
|
||||
int vectorDoc = values.docID();
|
||||
if (vectorDoc < doc) {
|
||||
vectorDoc = values.advance(doc);
|
||||
}
|
||||
return vectorDoc == doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
assert values.docID() != -1 : getClass().getSimpleName() + " is not positioned";
|
||||
return similarity.compare(query, values.vectorValue());
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @return a {@link DocIdSetIterator} over the documents.
|
||||
*/
|
||||
DocIdSetIterator iterator();
|
||||
}
|
||||
|
|
|
@ -33,8 +33,26 @@ abstract class VectorSimilarityValuesSource extends DoubleValuesSource {
|
|||
}
|
||||
|
||||
@Override
|
||||
public abstract DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores)
|
||||
throws IOException;
|
||||
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
|
||||
VectorScorer scorer = getScorer(ctx);
|
||||
if (scorer == null) {
|
||||
return DoubleValues.EMPTY;
|
||||
}
|
||||
DocIdSetIterator iterator = scorer.iterator();
|
||||
return new DoubleValues() {
|
||||
@Override
|
||||
public double doubleValue() throws IOException {
|
||||
return scorer.score();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean advanceExact(int doc) throws IOException {
|
||||
return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Returns the {@link VectorScorer} that {@link #getValues} uses for the given leaf. May return
 * {@code null}, in which case {@link #getValues} returns {@link DoubleValues#EMPTY}.
 */
protected abstract VectorScorer getScorer(LeafReaderContext ctx) throws IOException;
|
||||
|
||||
@Override
|
||||
public boolean needsScores() {
|
||||
|
|
|
@ -18,6 +18,8 @@ package org.apache.lucene.util.quantization;
|
|||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
|
||||
/**
|
||||
* A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for
|
||||
|
@ -25,6 +27,31 @@ import org.apache.lucene.index.ByteVectorValues;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class QuantizedByteVectorValues extends ByteVectorValues {
|
||||
public abstract class QuantizedByteVectorValues extends DocIdSetIterator {
|
||||
/**
 * Return the score correction constant.
 *
 * <p>NOTE(review): presumably this applies to the vector at the iterator's current position —
 * confirm against implementations.
 */
public abstract float getScoreCorrectionConstant() throws IOException;
|
||||
|
||||
/**
 * Return the vector value as a {@code byte[]}.
 *
 * <p>NOTE(review): presumably the value for the document this iterator is currently positioned
 * on — confirm against implementations.
 */
public abstract byte[] vectorValue() throws IOException;
|
||||
|
||||
/** Return the dimension of the vectors */
|
||||
public abstract int dimension();
|
||||
|
||||
/**
|
||||
* Return the number of vectors for this field.
|
||||
*
|
||||
* @return the number of vectors returned by this iterator
|
||||
*/
|
||||
public abstract int size();
|
||||
|
||||
/** Reports {@link #size()} as the cost of this iterator. */
@Override
|
||||
public final long cost() {
|
||||
return size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a {@link VectorScorer} for the given query vector.
|
||||
*
|
||||
* @param query the query vector
|
||||
* @return a {@link VectorScorer} instance
|
||||
*/
|
||||
public abstract VectorScorer vectorScorer(float[] query) throws IOException;
|
||||
}
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.apache.lucene.document.KnnByteVectorField;
|
|||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
|
@ -44,24 +43,22 @@ public class TestVectorScorer extends LuceneTestCase {
|
|||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
assert reader.leaves().size() == 1;
|
||||
LeafReaderContext context = reader.leaves().get(0);
|
||||
FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
|
||||
final VectorScorer vectorScorer;
|
||||
switch (encoding) {
|
||||
case BYTE:
|
||||
vectorScorer = VectorScorer.create(context, fieldInfo, new byte[] {1, 2});
|
||||
vectorScorer = context.reader().getByteVectorValues("field").scorer(new byte[] {1, 2});
|
||||
break;
|
||||
case FLOAT32:
|
||||
vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
|
||||
vectorScorer = context.reader().getFloatVectorValues("field").scorer(new float[] {1, 2});
|
||||
break;
|
||||
default:
|
||||
throw new IllegalArgumentException("unexpected vector encoding: " + encoding);
|
||||
}
|
||||
|
||||
DocIdSetIterator iterator = vectorScorer.iterator();
|
||||
int numDocs = 0;
|
||||
for (int i = 0; i < reader.maxDoc(); i++) {
|
||||
if (vectorScorer.advanceExact(i)) {
|
||||
numDocs++;
|
||||
}
|
||||
while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
numDocs++;
|
||||
}
|
||||
assertEquals(3, numDocs);
|
||||
}
|
||||
|
|
|
@ -72,6 +72,7 @@ import org.apache.lucene.search.SortField;
|
|||
import org.apache.lucene.search.TaskExecutor;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopKnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
|
@ -1128,6 +1129,11 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
public float[] vectorValue(int ord) {
|
||||
return unitVector2d(ord / (double) size, value);
|
||||
}
|
||||
|
||||
// Scoring is not implemented for this vector values class: every call throws
// UnsupportedOperationException.
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns vectors evenly distributed around the upper unit semicircle. */
|
||||
|
@ -1193,6 +1199,11 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
return bValue;
|
||||
}
|
||||
|
||||
// Scoring is not implemented for this vector values class: every call throws
// UnsupportedOperationException.
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
private static float[] unitVector2d(double piRadians) {
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.HashSet;
|
|||
import java.util.Set;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
|
||||
public class TestScalarQuantizer extends LuceneTestCase {
|
||||
|
@ -262,5 +263,10 @@ public class TestScalarQuantizer extends LuceneTestCase {
|
|||
curDoc = target - 1;
|
||||
return nextDoc();
|
||||
}
|
||||
|
||||
// Scoring is not implemented for this vector values class: every call throws
// UnsupportedOperationException.
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,10 +20,8 @@ import java.io.IOException;
|
|||
import java.util.Arrays;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.HitQueue;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
|
@ -34,6 +32,7 @@ import org.apache.lucene.search.ScoreDoc;
|
|||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopDocsCollector;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -92,14 +91,10 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
|
|||
return NO_RESULTS;
|
||||
}
|
||||
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
ParentBlockJoinByteVectorScorer vectorScorer =
|
||||
new ParentBlockJoinByteVectorScorer(
|
||||
byteVectorValues,
|
||||
acceptIterator,
|
||||
parentBitSet,
|
||||
query,
|
||||
fi.getVectorSimilarityFunction());
|
||||
VectorScorer scorer = byteVectorValues.scorer(query);
|
||||
DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer vectorScorer =
|
||||
new DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer(
|
||||
acceptIterator, parentBitSet, scorer);
|
||||
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
|
||||
HitQueue queue = new HitQueue(queueSize, true);
|
||||
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
|
||||
|
@ -177,59 +172,4 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
|
|||
result = 31 * result + Arrays.hashCode(query);
|
||||
return result;
|
||||
}
|
||||
|
||||
private static class ParentBlockJoinByteVectorScorer {
|
||||
private final byte[] query;
|
||||
private final ByteVectorValues values;
|
||||
private final VectorSimilarityFunction similarity;
|
||||
private final DocIdSetIterator acceptedChildrenIterator;
|
||||
private final BitSet parentBitSet;
|
||||
private int currentParent = -1;
|
||||
private int bestChild = -1;
|
||||
private float currentScore = Float.NEGATIVE_INFINITY;
|
||||
|
||||
protected ParentBlockJoinByteVectorScorer(
|
||||
ByteVectorValues values,
|
||||
DocIdSetIterator acceptedChildrenIterator,
|
||||
BitSet parentBitSet,
|
||||
byte[] query,
|
||||
VectorSimilarityFunction similarity) {
|
||||
this.query = query;
|
||||
this.values = values;
|
||||
this.similarity = similarity;
|
||||
this.acceptedChildrenIterator = acceptedChildrenIterator;
|
||||
this.parentBitSet = parentBitSet;
|
||||
}
|
||||
|
||||
public int bestChild() {
|
||||
return bestChild;
|
||||
}
|
||||
|
||||
public int nextParent() throws IOException {
|
||||
int nextChild = acceptedChildrenIterator.docID();
|
||||
if (nextChild == -1) {
|
||||
nextChild = acceptedChildrenIterator.nextDoc();
|
||||
}
|
||||
if (nextChild == DocIdSetIterator.NO_MORE_DOCS) {
|
||||
currentParent = DocIdSetIterator.NO_MORE_DOCS;
|
||||
return currentParent;
|
||||
}
|
||||
currentScore = Float.NEGATIVE_INFINITY;
|
||||
currentParent = parentBitSet.nextSetBit(nextChild);
|
||||
do {
|
||||
values.advance(nextChild);
|
||||
float score = similarity.compare(query, values.vectorValue());
|
||||
if (score > currentScore) {
|
||||
bestChild = nextChild;
|
||||
currentScore = score;
|
||||
}
|
||||
} while ((nextChild = acceptedChildrenIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS
|
||||
&& nextChild < currentParent);
|
||||
return currentParent;
|
||||
}
|
||||
|
||||
public float score() throws IOException {
|
||||
return currentScore;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,11 +19,9 @@ package org.apache.lucene.search.join;
|
|||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.HitQueue;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
|
@ -34,6 +32,7 @@ import org.apache.lucene.search.ScoreDoc;
|
|||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopDocsCollector;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -92,14 +91,9 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
|
|||
return NO_RESULTS;
|
||||
}
|
||||
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
DiversifyingChildrenFloatVectorScorer vectorScorer =
|
||||
new DiversifyingChildrenFloatVectorScorer(
|
||||
floatVectorValues,
|
||||
acceptIterator,
|
||||
parentBitSet,
|
||||
query,
|
||||
fi.getVectorSimilarityFunction());
|
||||
DiversifyingChildrenVectorScorer vectorScorer =
|
||||
new DiversifyingChildrenVectorScorer(
|
||||
acceptIterator, parentBitSet, floatVectorValues.scorer(query));
|
||||
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
|
||||
HitQueue queue = new HitQueue(queueSize, true);
|
||||
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
|
||||
|
@ -178,26 +172,20 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
|
|||
return result;
|
||||
}
|
||||
|
||||
private static class DiversifyingChildrenFloatVectorScorer {
|
||||
private final float[] query;
|
||||
private final FloatVectorValues values;
|
||||
private final VectorSimilarityFunction similarity;
|
||||
static class DiversifyingChildrenVectorScorer {
|
||||
private final VectorScorer vectorScorer;
|
||||
private final DocIdSetIterator vectorIterator;
|
||||
private final DocIdSetIterator acceptedChildrenIterator;
|
||||
private final BitSet parentBitSet;
|
||||
private int currentParent = -1;
|
||||
private int bestChild = -1;
|
||||
private float currentScore = Float.NEGATIVE_INFINITY;
|
||||
|
||||
protected DiversifyingChildrenFloatVectorScorer(
|
||||
FloatVectorValues values,
|
||||
DocIdSetIterator acceptedChildrenIterator,
|
||||
BitSet parentBitSet,
|
||||
float[] query,
|
||||
VectorSimilarityFunction similarity) {
|
||||
this.query = query;
|
||||
this.values = values;
|
||||
this.similarity = similarity;
|
||||
protected DiversifyingChildrenVectorScorer(
|
||||
DocIdSetIterator acceptedChildrenIterator, BitSet parentBitSet, VectorScorer vectorScorer) {
|
||||
this.acceptedChildrenIterator = acceptedChildrenIterator;
|
||||
this.vectorScorer = vectorScorer;
|
||||
this.vectorIterator = vectorScorer.iterator();
|
||||
this.parentBitSet = parentBitSet;
|
||||
}
|
||||
|
||||
|
@ -217,8 +205,8 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
|
|||
currentScore = Float.NEGATIVE_INFINITY;
|
||||
currentParent = parentBitSet.nextSetBit(nextChild);
|
||||
do {
|
||||
values.advance(nextChild);
|
||||
float score = similarity.compare(query, values.vectorValue());
|
||||
vectorIterator.advance(nextChild);
|
||||
float score = vectorScorer.score();
|
||||
if (score > currentScore) {
|
||||
bestChild = nextChild;
|
||||
currentScore = score;
|
||||
|
|
|
@ -48,10 +48,12 @@ import org.apache.lucene.index.StoredFields;
|
|||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.FSDirectory;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
|
@ -718,6 +720,116 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
}
|
||||
}
|
||||
|
||||
public void testFloatVectorScorerIteration() throws Exception {
|
||||
IndexWriterConfig iwc = newIndexWriterConfig();
|
||||
if (random().nextBoolean()) {
|
||||
iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
|
||||
}
|
||||
String fieldName = "field";
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter iw = new IndexWriter(dir, iwc)) {
|
||||
int numDoc = atLeast(100);
|
||||
int dimension = atLeast(10);
|
||||
if (dimension % 2 != 0) {
|
||||
dimension++;
|
||||
}
|
||||
float[][] values = new float[numDoc][];
|
||||
for (int i = 0; i < numDoc; i++) {
|
||||
if (random().nextInt(7) != 3) {
|
||||
// usually index a vector value for a doc
|
||||
values[i] = randomNormalizedVector(dimension);
|
||||
}
|
||||
add(iw, fieldName, i, values[i], similarityFunction);
|
||||
if (random().nextInt(10) == 2) {
|
||||
iw.deleteDocuments(new Term("id", Integer.toString(random().nextInt(i + 1))));
|
||||
}
|
||||
if (random().nextInt(10) == 3) {
|
||||
iw.commit();
|
||||
}
|
||||
}
|
||||
float[] vectorToScore = randomNormalizedVector(dimension);
|
||||
try (IndexReader reader = DirectoryReader.open(iw)) {
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
|
||||
if (vectorValues == null) {
|
||||
continue;
|
||||
}
|
||||
VectorScorer scorer = vectorValues.scorer(vectorToScore);
|
||||
assertNotNull(scorer);
|
||||
DocIdSetIterator iterator = scorer.iterator();
|
||||
assertSame(iterator, scorer.iterator());
|
||||
assertNotSame(iterator, scorer);
|
||||
// verify scorer iteration scores are valid & iteration with vectorValues is consistent
|
||||
while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
float score = scorer.score();
|
||||
assertTrue(score >= 0f);
|
||||
assertEquals(iterator.docID(), vectorValues.docID());
|
||||
}
|
||||
// verify that a new scorer can be obtained after iteration
|
||||
VectorScorer newScorer = vectorValues.scorer(vectorToScore);
|
||||
assertNotNull(newScorer);
|
||||
assertNotSame(scorer, newScorer);
|
||||
assertNotSame(iterator, newScorer.iterator());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testByteVectorScorerIteration() throws Exception {
|
||||
IndexWriterConfig iwc = newIndexWriterConfig();
|
||||
if (random().nextBoolean()) {
|
||||
iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
|
||||
}
|
||||
String fieldName = "field";
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter iw = new IndexWriter(dir, iwc)) {
|
||||
int numDoc = atLeast(100);
|
||||
int dimension = atLeast(10);
|
||||
if (dimension % 2 != 0) {
|
||||
dimension++;
|
||||
}
|
||||
byte[][] values = new byte[numDoc][];
|
||||
for (int i = 0; i < numDoc; i++) {
|
||||
if (random().nextInt(7) != 3) {
|
||||
// usually index a vector value for a doc
|
||||
values[i] = randomVector8(dimension);
|
||||
}
|
||||
add(iw, fieldName, i, values[i], similarityFunction);
|
||||
if (random().nextInt(10) == 2) {
|
||||
iw.deleteDocuments(new Term("id", Integer.toString(random().nextInt(i + 1))));
|
||||
}
|
||||
if (random().nextInt(10) == 3) {
|
||||
iw.commit();
|
||||
}
|
||||
}
|
||||
byte[] vectorToScore = randomVector8(dimension);
|
||||
try (IndexReader reader = DirectoryReader.open(iw)) {
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
|
||||
if (vectorValues == null) {
|
||||
continue;
|
||||
}
|
||||
VectorScorer scorer = vectorValues.scorer(vectorToScore);
|
||||
assertNotNull(scorer);
|
||||
DocIdSetIterator iterator = scorer.iterator();
|
||||
assertSame(iterator, scorer.iterator());
|
||||
assertNotSame(iterator, scorer);
|
||||
// verify scorer iteration scores are valid & iteration with vectorValues is consistent
|
||||
while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
float score = scorer.score();
|
||||
assertTrue(score >= 0f);
|
||||
assertEquals(iterator.docID(), vectorValues.docID());
|
||||
}
|
||||
// verify that a new scorer can be obtained after iteration
|
||||
VectorScorer newScorer = vectorValues.scorer(vectorToScore);
|
||||
assertNotNull(newScorer);
|
||||
assertNotSame(scorer, newScorer);
|
||||
assertNotSame(iterator, newScorer.iterator());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected VectorSimilarityFunction randomSimilarity() {
|
||||
return VectorSimilarityFunction.values()[
|
||||
random().nextInt(VectorSimilarityFunction.values().length)];
|
||||
|
|
Loading…
Reference in New Issue