Add new VectorScorer interface to vector value iterators (#13181)

With quantized vectors, and with current vectors, we separate out the "scoring" vs. "iteration", requiring the user to always iterate the raw vectors and provide their own similarity function.

While this is flexible, it creates frustration in:

 - Just iterating and scoring, especially since the field already has a similarity function stored...Why can't we just know which one to use and use it!
 - Iterating and scoring quantized vectors. By default it would be good to be able to iterate and score quantized vectors (e.g. without going through the HNSW graph).

This significantly hampers support for true exact kNN search.

This commit extends the vector value iterators to be able to return a scorer given some vector value (what this PR demonstrates). The scorer contains a copy of the originating iterator and allows for iteration and scoring the most optimized way the provided codec can give. 

Users can still iterate vector values directly, read them on heap, and score any way they please.
This commit is contained in:
Benjamin Trent 2024-05-09 16:30:14 -04:00 committed by GitHub
parent 8d7e4174af
commit b60e86c4b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
47 changed files with 1166 additions and 340 deletions

View File

@ -258,6 +258,9 @@ New Features
* GITHUB#13288: Make HNSW and Flat storage vector formats easier to extend with new FlatVectorScorer interface. Add
new Hnsw format for binary quantized vectors. (Ben Trent)
* GITHUB#13181: Add new VectorScorer interface to vector value iterators. This allows for vector codecs to supply
simpler and more optimized vector scoring when iterating vector values directly. (Ben Trent)
Improvements
---------------------

View File

@ -34,7 +34,9 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
@ -272,7 +274,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
return new OffHeapFloatVectorValues(fieldEntry.dimension, fieldEntry.ordToDoc, bytesSlice);
return new OffHeapFloatVectorValues(
fieldEntry.dimension, fieldEntry.ordToDoc, fieldEntry.similarityFunction, bytesSlice);
}
private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
@ -359,14 +362,20 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
final int byteSize;
int lastOrd = -1;
final float[] value;
final VectorSimilarityFunction similarityFunction;
int ord = -1;
int doc = -1;
OffHeapFloatVectorValues(int dimension, int[] ordToDoc, IndexInput dataIn) {
OffHeapFloatVectorValues(
int dimension,
int[] ordToDoc,
VectorSimilarityFunction similarityFunction,
IndexInput dataIn) {
this.dimension = dimension;
this.ordToDoc = ordToDoc;
this.dataIn = dataIn;
this.similarityFunction = similarityFunction;
byteSize = Float.BYTES * dimension;
value = new float[dimension];
@ -420,7 +429,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
@Override
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(dimension, ordToDoc, dataIn.clone());
return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone());
}
@Override
@ -433,6 +442,22 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
lastOrd = targetOrd;
return value;
}
@Override
public VectorScorer scorer(float[] target) {
OffHeapFloatVectorValues values = this.copy();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.similarityFunction.compare(values.vectorValue(), target);
}
@Override
public DocIdSetIterator iterator() {
return values;
}
};
}
}
/** Read the nearest-neighbors graph from the index input */

View File

@ -35,7 +35,9 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
@ -255,7 +257,11 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
return new OffHeapFloatVectorValues(
fieldEntry.dimension, fieldEntry.size(), fieldEntry.ordToDoc, bytesSlice);
fieldEntry.dimension,
fieldEntry.size(),
fieldEntry.ordToDoc,
fieldEntry.similarityFunction,
bytesSlice);
}
private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
@ -399,16 +405,23 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
private final IndexInput dataIn;
private final int byteSize;
private final float[] value;
private final VectorSimilarityFunction similarityFunction;
private int ord = -1;
private int doc = -1;
OffHeapFloatVectorValues(int dimension, int size, int[] ordToDoc, IndexInput dataIn) {
OffHeapFloatVectorValues(
int dimension,
int size,
int[] ordToDoc,
VectorSimilarityFunction similarityFunction,
IndexInput dataIn) {
this.dimension = dimension;
this.size = size;
this.ordToDoc = ordToDoc;
ordToDocOperator = ordToDoc == null ? IntUnaryOperator.identity() : (ord) -> ordToDoc[ord];
this.dataIn = dataIn;
this.similarityFunction = similarityFunction;
byteSize = Float.BYTES * dimension;
value = new float[dimension];
}
@ -468,7 +481,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
@Override
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(dimension, size, ordToDoc, dataIn.clone());
return new OffHeapFloatVectorValues(
dimension, size, ordToDoc, similarityFunction, dataIn.clone());
}
@Override
@ -477,6 +491,22 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
dataIn.readFloats(value, 0, value.length);
return value;
}
@Override
public VectorScorer scorer(float[] target) {
OffHeapFloatVectorValues values = this.copy();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.similarityFunction.compare(values.vectorValue(), target);
}
@Override
public DocIdSetIterator iterator() {
return values;
}
};
}
}
/** Read the nearest-neighbors graph from the index input */

View File

@ -20,6 +20,9 @@ package org.apache.lucene.backward_codecs.lucene92;
import java.io.IOException;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
@ -36,13 +39,20 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
protected final int byteSize;
protected int lastOrd = -1;
protected final float[] value;
protected final VectorSimilarityFunction vectorSimilarityFunction;
;
OffHeapFloatVectorValues(int dimension, int size, IndexInput slice) {
OffHeapFloatVectorValues(
int dimension,
int size,
VectorSimilarityFunction vectorSimilarityFunction,
IndexInput slice) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
byteSize = Float.BYTES * dimension;
value = new float[dimension];
this.vectorSimilarityFunction = vectorSimilarityFunction;
}
@Override
@ -75,9 +85,11 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
vectorData.slice(
"vector-data", fieldEntry.vectorDataOffset(), fieldEntry.vectorDataLength());
if (fieldEntry.docsWithFieldOffset() == -1) {
return new DenseOffHeapVectorValues(fieldEntry.dimension(), fieldEntry.size(), bytesSlice);
return new DenseOffHeapVectorValues(
fieldEntry.dimension(), fieldEntry.size(), fieldEntry.similarityFunction(), bytesSlice);
} else {
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice);
return new SparseOffHeapVectorValues(
fieldEntry, vectorData, fieldEntry.similarityFunction(), bytesSlice);
}
}
@ -85,8 +97,12 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
private int doc = -1;
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) {
super(dimension, size, slice);
public DenseOffHeapVectorValues(
int dimension,
int size,
VectorSimilarityFunction vectorSimilarityFunction,
IndexInput slice) {
super(dimension, size, vectorSimilarityFunction, slice);
}
@Override
@ -115,13 +131,29 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone());
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
DenseOffHeapVectorValues values = this.copy();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
}
@Override
public DocIdSetIterator iterator() {
return values;
}
};
}
}
private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues {
@ -132,10 +164,13 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
private final Lucene92HnswVectorsReader.FieldEntry fieldEntry;
public SparseOffHeapVectorValues(
Lucene92HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice)
Lucene92HnswVectorsReader.FieldEntry fieldEntry,
IndexInput dataIn,
VectorSimilarityFunction vectorSimilarityFunction,
IndexInput slice)
throws IOException {
super(fieldEntry.dimension(), fieldEntry.size(), slice);
super(fieldEntry.dimension(), fieldEntry.size(), vectorSimilarityFunction, slice);
this.fieldEntry = fieldEntry;
final RandomAccessInput addressesData =
dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength());
@ -173,8 +208,9 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public OffHeapFloatVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
fieldEntry, dataIn, vectorSimilarityFunction, slice.clone());
}
@Override
@ -199,12 +235,28 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
};
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
SparseOffHeapVectorValues values = this.copy();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
}
@Override
public DocIdSetIterator iterator() {
return values;
}
};
}
}
private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null);
super(dimension, 0, VectorSimilarityFunction.COSINE, null);
}
private int doc = -1;
@ -258,5 +310,10 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
public Bits getAcceptOrds(Bits acceptDocs) {
return null;
}
@Override
public VectorScorer scorer(float[] query) {
return null;
}
}
}

View File

@ -22,6 +22,9 @@ import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
@ -39,12 +42,19 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
protected final byte[] binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
protected final VectorSimilarityFunction vectorSimilarityFunction;
OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
OffHeapByteVectorValues(
int dimension,
int size,
IndexInput slice,
VectorSimilarityFunction vectorSimilarityFunction,
int byteSize) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
this.byteSize = byteSize;
this.vectorSimilarityFunction = vectorSimilarityFunction;
byteBuffer = ByteBuffer.allocate(byteSize);
binaryValue = byteBuffer.array();
}
@ -85,9 +95,14 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
int byteSize = fieldEntry.dimension();
if (fieldEntry.docsWithFieldOffset() == -1) {
return new DenseOffHeapVectorValues(
fieldEntry.dimension(), fieldEntry.size(), bytesSlice, byteSize);
fieldEntry.dimension(),
fieldEntry.size(),
bytesSlice,
fieldEntry.similarityFunction(),
byteSize);
} else {
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
return new SparseOffHeapVectorValues(
fieldEntry, vectorData, bytesSlice, fieldEntry.similarityFunction(), byteSize);
}
}
@ -95,8 +110,13 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
private int doc = -1;
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
super(dimension, size, slice, byteSize);
public DenseOffHeapVectorValues(
int dimension,
int size,
IndexInput slice,
VectorSimilarityFunction vectorSimilarityFunction,
int byteSize) {
super(dimension, size, slice, vectorSimilarityFunction, byteSize);
}
@Override
@ -124,14 +144,31 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public OffHeapByteVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(
dimension, size, slice.clone(), vectorSimilarityFunction, byteSize);
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
@Override
public VectorScorer scorer(byte[] query) throws IOException {
DenseOffHeapVectorValues copy = this.copy();
return new VectorScorer() {
@Override
public float score() throws IOException {
return vectorSimilarityFunction.compare(copy.vectorValue(), query);
}
@Override
public DocIdSetIterator iterator() {
return copy;
}
};
}
}
private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
@ -145,10 +182,11 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
Lucene94HnswVectorsReader.FieldEntry fieldEntry,
IndexInput dataIn,
IndexInput slice,
VectorSimilarityFunction vectorSimilarityFunction,
int byteSize)
throws IOException {
super(fieldEntry.dimension(), fieldEntry.size(), slice, byteSize);
super(fieldEntry.dimension(), fieldEntry.size(), slice, vectorSimilarityFunction, byteSize);
this.fieldEntry = fieldEntry;
final RandomAccessInput addressesData =
dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength());
@ -186,8 +224,9 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public OffHeapByteVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize);
}
@Override
@ -212,12 +251,28 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
}
};
}
@Override
public VectorScorer scorer(byte[] query) throws IOException {
SparseOffHeapVectorValues copy = this.copy();
return new VectorScorer() {
@Override
public float score() throws IOException {
return vectorSimilarityFunction.compare(copy.vectorValue(), query);
}
@Override
public DocIdSetIterator iterator() {
return copy;
}
};
}
}
private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null, 0);
super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0);
}
private int doc = -1;
@ -271,5 +326,10 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
public Bits getAcceptOrds(Bits acceptDocs) {
return null;
}
@Override
public VectorScorer scorer(byte[] query) {
return null;
}
}
}

View File

@ -20,6 +20,9 @@ package org.apache.lucene.backward_codecs.lucene94;
import java.io.IOException;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
@ -36,13 +39,20 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
protected final int byteSize;
protected int lastOrd = -1;
protected final float[] value;
protected final VectorSimilarityFunction vectorSimilarityFunction;
OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
OffHeapFloatVectorValues(
int dimension,
int size,
IndexInput slice,
VectorSimilarityFunction vectorSimilarityFunction,
int byteSize) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
this.byteSize = byteSize;
value = new float[dimension];
this.vectorSimilarityFunction = vectorSimilarityFunction;
}
@Override
@ -81,9 +91,14 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
};
if (fieldEntry.docsWithFieldOffset() == -1) {
return new DenseOffHeapVectorValues(
fieldEntry.dimension(), fieldEntry.size(), bytesSlice, byteSize);
fieldEntry.dimension(),
fieldEntry.size(),
bytesSlice,
fieldEntry.similarityFunction(),
byteSize);
} else {
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
return new SparseOffHeapVectorValues(
fieldEntry, vectorData, bytesSlice, fieldEntry.similarityFunction(), byteSize);
}
}
@ -91,8 +106,13 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
private int doc = -1;
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
super(dimension, size, slice, byteSize);
public DenseOffHeapVectorValues(
int dimension,
int size,
IndexInput slice,
VectorSimilarityFunction vectorSimilarityFunction,
int byteSize) {
super(dimension, size, slice, vectorSimilarityFunction, byteSize);
}
@Override
@ -120,14 +140,31 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public OffHeapFloatVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(
dimension, size, slice.clone(), vectorSimilarityFunction, byteSize);
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
DenseOffHeapVectorValues values = this.copy();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
}
@Override
public DocIdSetIterator iterator() {
return values;
}
};
}
}
private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues {
@ -141,10 +178,11 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
Lucene94HnswVectorsReader.FieldEntry fieldEntry,
IndexInput dataIn,
IndexInput slice,
VectorSimilarityFunction vectorSimilarityFunction,
int byteSize)
throws IOException {
super(fieldEntry.dimension(), fieldEntry.size(), slice, byteSize);
super(fieldEntry.dimension(), fieldEntry.size(), slice, vectorSimilarityFunction, byteSize);
this.fieldEntry = fieldEntry;
final RandomAccessInput addressesData =
dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength());
@ -182,8 +220,9 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public OffHeapFloatVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize);
}
@Override
@ -208,12 +247,28 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
};
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
SparseOffHeapVectorValues values = this.copy();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
}
@Override
public DocIdSetIterator iterator() {
return values;
}
};
}
}
private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null, 0);
super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0);
}
private int doc = -1;
@ -267,5 +322,10 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
public Bits getAcceptOrds(Bits acceptDocs) {
return null;
}
@Override
public VectorScorer scorer(float[] query) {
return null;
}
}
}

View File

@ -253,6 +253,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
+ VectorEncoding.FLOAT32);
}
return OffHeapFloatVectorValues.load(
fieldEntry.similarityFunction,
defaultFlatVectorScorer,
fieldEntry.ordToDocVectorValues,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
@ -274,6 +276,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(
fieldEntry.similarityFunction,
defaultFlatVectorScorer,
fieldEntry.ordToDocVectorValues,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
@ -295,6 +299,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
OffHeapFloatVectorValues vectorValues =
OffHeapFloatVectorValues.load(
fieldEntry.similarityFunction,
defaultFlatVectorScorer,
fieldEntry.ordToDocVectorValues,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
@ -324,6 +330,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements
OffHeapByteVectorValues vectorValues =
OffHeapByteVectorValues.load(
fieldEntry.similarityFunction,
defaultFlatVectorScorer,
fieldEntry.ordToDocVectorValues,
fieldEntry.vectorEncoding,
fieldEntry.dimension,

View File

@ -133,7 +133,10 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
// build the graph using the temporary vector data
Lucene90HnswVectorsReader.OffHeapFloatVectorValues offHeapVectors =
new Lucene90HnswVectorsReader.OffHeapFloatVectorValues(
floatVectorValues.dimension(), docIds, vectorDataInput);
floatVectorValues.dimension(),
docIds,
fieldInfo.getVectorSimilarityFunction(),
vectorDataInput);
long[] offsets = new long[docIds.length];
long vectorIndexOffset = vectorIndex.getFilePointer();

View File

@ -68,4 +68,9 @@ public class TestLucene90HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
public void testSortedIndexBytes() throws Exception {
// unimplemented
}
@Override
public void testByteVectorScorerIteration() {
// unimplemented
}
}

View File

@ -138,7 +138,11 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
// TODO: separate random access vector values from DocIdSetIterator?
Lucene91HnswVectorsReader.OffHeapFloatVectorValues offHeapVectors =
new Lucene91HnswVectorsReader.OffHeapFloatVectorValues(
floatVectorValues.dimension(), docsWithField.cardinality(), null, vectorDataInput);
floatVectorValues.dimension(),
docsWithField.cardinality(),
null,
fieldInfo.getVectorSimilarityFunction(),
vectorDataInput);
Lucene91OnHeapHnswGraph graph =
offHeapVectors.size() == 0
? null

View File

@ -67,4 +67,9 @@ public class TestLucene91HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
public void testSortedIndexBytes() throws Exception {
// unimplemented
}
@Override
public void testByteVectorScorerIteration() {
// unimplemented
}
}

View File

@ -146,7 +146,10 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
// TODO: separate random access vector values from DocIdSetIterator?
OffHeapFloatVectorValues offHeapVectors =
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
floatVectorValues.dimension(), docsWithField.cardinality(), vectorDataInput);
floatVectorValues.dimension(),
docsWithField.cardinality(),
fieldInfo.getVectorSimilarityFunction(),
vectorDataInput);
OnHeapHnswGraph graph =
offHeapVectors.size() == 0
? null

View File

@ -57,4 +57,9 @@ public class TestLucene92HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
public void testSortedIndexBytes() throws Exception {
// unimplemented
}
@Override
public void testByteVectorScorerIteration() {
// unimplemented
}
}

View File

@ -421,6 +421,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
fieldInfo.getVectorSimilarityFunction(),
byteSize);
RandomVectorScorerSupplier scorerSupplier =
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
@ -437,6 +438,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
fieldInfo.getVectorSimilarityFunction(),
byteSize);
RandomVectorScorerSupplier scorerSupplier =
defaultFlatVectorScorer.getRandomVectorScorerSupplier(

View File

@ -70,6 +70,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
private final IndexOutput meta, vectorData, vectorIndex;
private final int M;
private final int beamWidth;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
@ -437,7 +438,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
OnHeapHnswGraph graph = null;
int[][] vectorIndexNodeOffsets = null;
if (docsWithField.cardinality() != 0) {
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
final RandomVectorScorerSupplier scorerSupplier;
switch (fieldInfo.getVectorEncoding()) {
case BYTE:
@ -448,7 +448,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
byteSize));
byteSize,
defaultFlatVectorScorer,
fieldInfo.getVectorSimilarityFunction()));
break;
case FLOAT32:
scorerSupplier =
@ -458,7 +460,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
byteSize));
byteSize,
defaultFlatVectorScorer,
fieldInfo.getVectorSimilarityFunction()));
break;
default:
throw new IllegalArgumentException(
@ -667,6 +671,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
private final HnswGraphBuilder hnswGraphBuilder;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private int lastDocID = -1;
private int node = 0;
@ -697,7 +702,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> defaultFlatVectorScorer.getRandomVectorScorerSupplier(

View File

@ -38,6 +38,7 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.BufferedChecksumIndexInput;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
@ -92,7 +93,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
assert fieldEntries.containsKey(fieldName) == false;
fieldEntries.put(
fieldName, new FieldEntry(dimension, vectorDataOffset, vectorDataLength, docIds));
fieldName,
new FieldEntry(
dimension,
vectorDataOffset,
vectorDataLength,
docIds,
readState.fieldInfos.fieldInfo(fieldName).getVectorSimilarityFunction()));
fieldNumber = readInt(in, FIELD_NUMBER);
}
SimpleTextUtil.checkFooter(in);
@ -275,7 +282,11 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
private record FieldEntry(
int dimension, long vectorDataOffset, long vectorDataLength, int[] ordToDoc) {
int dimension,
long vectorDataOffset,
long vectorDataLength,
int[] ordToDoc,
VectorSimilarityFunction similarityFunction) {
int size() {
return ordToDoc.length;
}
@ -298,6 +309,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
readAllVectors();
}
private SimpleTextFloatVectorValues(SimpleTextFloatVectorValues other) {
this.entry = other.entry;
this.in = other.in.clone();
this.values = other.values;
this.curOrd = other.curOrd;
}
@Override
public int dimension() {
return entry.dimension;
@ -340,6 +358,25 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return slowAdvance(target);
}
@Override
public VectorScorer scorer(float[] target) {
SimpleTextFloatVectorValues simpleTextFloatVectorValues =
new SimpleTextFloatVectorValues(this);
return new VectorScorer() {
@Override
public float score() throws IOException {
return entry
.similarityFunction()
.compare(simpleTextFloatVectorValues.vectorValue(), target);
}
@Override
public DocIdSetIterator iterator() {
return simpleTextFloatVectorValues;
}
};
}
private void readAllVectors() throws IOException {
for (float[] value : values) {
readVector(value);
@ -379,6 +416,15 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
readAllVectors();
}
private SimpleTextByteVectorValues(SimpleTextByteVectorValues other) {
this.entry = other.entry;
this.in = other.in.clone();
this.values = other.values;
this.binaryValue = new BytesRef(entry.dimension);
this.binaryValue.length = binaryValue.bytes.length;
this.curOrd = other.curOrd;
}
@Override
public int dimension() {
return entry.dimension;
@ -422,6 +468,24 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return slowAdvance(target);
}
@Override
public VectorScorer scorer(byte[] target) {
SimpleTextByteVectorValues simpleTextByteVectorValues = new SimpleTextByteVectorValues(this);
return new VectorScorer() {
@Override
public float score() throws IOException {
return entry
.similarityFunction()
.compare(simpleTextByteVectorValues.vectorValue(), target);
}
@Override
public DocIdSetIterator iterator() {
return simpleTextByteVectorValues;
}
};
}
private void readAllVectors() throws IOException {
for (byte[] value : values) {
readVector(value);

View File

@ -27,6 +27,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.RamUsageEstimator;
@ -159,6 +160,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
public int advance(int target) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
}
/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
@ -216,6 +222,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
public int advance(int target) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(byte[] target) {
throw new UnsupportedOperationException();
}
}
@Override
@ -354,6 +365,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
}
private static class BufferedByteVectorValues extends ByteVectorValues {
@ -414,5 +430,10 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(byte[] target) {
throw new UnsupportedOperationException();
}
}
}

View File

@ -29,6 +29,7 @@ import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.Accountable;
/** Writes vectors to an index. */
@ -188,7 +189,6 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
private final List<VectorValuesSub> subs;
private final DocIDMerger<VectorValuesSub> docIdMerger;
private final int size;
private int docId;
VectorValuesSub current;
@ -239,6 +239,11 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
public int dimension() {
return subs.get(0).values.dimension();
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
}
static class MergedByteVectorValues extends ByteVectorValues {
@ -296,6 +301,11 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
public int dimension() {
return subs.get(0).values.dimension();
}
@Override
public VectorScorer scorer(byte[] target) {
throw new UnsupportedOperationException();
}
}
}
}

View File

@ -19,13 +19,18 @@ package org.apache.lucene.codecs.lucene95;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
@ -39,14 +44,24 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
protected final byte[] binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
protected final VectorSimilarityFunction similarityFunction;
protected final FlatVectorsScorer flatVectorsScorer;
OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
OffHeapByteVectorValues(
int dimension,
int size,
IndexInput slice,
int byteSize,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction similarityFunction) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
this.byteSize = byteSize;
byteBuffer = ByteBuffer.allocate(byteSize);
binaryValue = byteBuffer.array();
this.similarityFunction = similarityFunction;
this.flatVectorsScorer = flatVectorsScorer;
}
@Override
@ -79,6 +94,8 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
}
public static OffHeapByteVectorValues load(
VectorSimilarityFunction vectorSimilarityFunction,
FlatVectorsScorer flatVectorsScorer,
OrdToDocDISIReaderConfiguration configuration,
VectorEncoding vectorEncoding,
int dimension,
@ -87,14 +104,26 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
IndexInput vectorData)
throws IOException {
if (configuration.isEmpty() || vectorEncoding != VectorEncoding.BYTE) {
return new EmptyOffHeapVectorValues(dimension);
return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction);
}
IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
if (configuration.isDense()) {
return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, dimension);
return new DenseOffHeapVectorValues(
dimension,
configuration.size,
bytesSlice,
dimension,
flatVectorsScorer,
vectorSimilarityFunction);
} else {
return new SparseOffHeapVectorValues(
configuration, vectorData, bytesSlice, dimension, dimension);
configuration,
vectorData,
bytesSlice,
dimension,
dimension,
flatVectorsScorer,
vectorSimilarityFunction);
}
}
@ -106,8 +135,14 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
private int doc = -1;
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
super(dimension, size, slice, byteSize);
public DenseOffHeapVectorValues(
int dimension,
int size,
IndexInput slice,
int byteSize,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction vectorSimilarityFunction) {
super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction);
}
@Override
@ -136,13 +171,32 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
return new DenseOffHeapVectorValues(
dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
@Override
public VectorScorer scorer(byte[] query) throws IOException {
DenseOffHeapVectorValues copy = copy();
RandomVectorScorer scorer =
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
return new VectorScorer() {
@Override
public float score() throws IOException {
return scorer.score(copy.doc);
}
@Override
public DocIdSetIterator iterator() {
return copy;
}
};
}
}
private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
@ -157,10 +211,18 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
IndexInput dataIn,
IndexInput slice,
int dimension,
int byteSize)
int byteSize,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction vectorSimilarityFunction)
throws IOException {
super(dimension, configuration.size, slice, byteSize);
super(
dimension,
configuration.size,
slice,
byteSize,
flatVectorsScorer,
vectorSimilarityFunction);
this.configuration = configuration;
final RandomAccessInput addressesData =
dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength);
@ -200,7 +262,13 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
configuration, dataIn, slice.clone(), dimension, byteSize);
configuration,
dataIn,
slice.clone(),
dimension,
byteSize,
flatVectorsScorer,
similarityFunction);
}
@Override
@ -225,12 +293,33 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
}
};
}
@Override
public VectorScorer scorer(byte[] query) throws IOException {
SparseOffHeapVectorValues copy = copy();
RandomVectorScorer scorer =
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
return new VectorScorer() {
@Override
public float score() throws IOException {
return scorer.score(copy.disi.index());
}
@Override
public DocIdSetIterator iterator() {
return copy;
}
};
}
}
private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null, 0);
public EmptyOffHeapVectorValues(
int dimension,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction vectorSimilarityFunction) {
super(dimension, 0, null, 0, flatVectorsScorer, vectorSimilarityFunction);
}
private int doc = -1;
@ -284,5 +373,10 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
public Bits getAcceptOrds(Bits acceptDocs) {
return null;
}
@Override
public VectorScorer scorer(byte[] query) {
throw new UnsupportedOperationException();
}
}
}

View File

@ -18,13 +18,18 @@
package org.apache.lucene.codecs.lucene95;
import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
@ -37,12 +42,22 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
protected final int byteSize;
protected int lastOrd = -1;
protected final float[] value;
protected final VectorSimilarityFunction similarityFunction;
protected final FlatVectorsScorer flatVectorsScorer;
OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
OffHeapFloatVectorValues(
int dimension,
int size,
IndexInput slice,
int byteSize,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction similarityFunction) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
this.byteSize = byteSize;
this.similarityFunction = similarityFunction;
this.flatVectorsScorer = flatVectorsScorer;
value = new float[dimension];
}
@ -73,6 +88,8 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
public static OffHeapFloatVectorValues load(
VectorSimilarityFunction vectorSimilarityFunction,
FlatVectorsScorer flatVectorsScorer,
OrdToDocDISIReaderConfiguration configuration,
VectorEncoding vectorEncoding,
int dimension,
@ -81,15 +98,27 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
IndexInput vectorData)
throws IOException {
if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.FLOAT32) {
return new EmptyOffHeapVectorValues(dimension);
return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction);
}
IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
int byteSize = dimension * Float.BYTES;
if (configuration.docsWithFieldOffset == -1) {
return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, byteSize);
return new DenseOffHeapVectorValues(
dimension,
configuration.size,
bytesSlice,
byteSize,
flatVectorsScorer,
vectorSimilarityFunction);
} else {
return new SparseOffHeapVectorValues(
configuration, vectorData, bytesSlice, dimension, byteSize);
configuration,
vectorData,
bytesSlice,
dimension,
byteSize,
flatVectorsScorer,
vectorSimilarityFunction);
}
}
@ -101,8 +130,14 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
private int doc = -1;
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
super(dimension, size, slice, byteSize);
public DenseOffHeapVectorValues(
int dimension,
int size,
IndexInput slice,
int byteSize,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction similarityFunction) {
super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction);
}
@Override
@ -131,13 +166,32 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
return new DenseOffHeapVectorValues(
dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
DenseOffHeapVectorValues copy = copy();
RandomVectorScorer randomVectorScorer =
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
return new VectorScorer() {
@Override
public float score() throws IOException {
return randomVectorScorer.score(copy.doc);
}
@Override
public DocIdSetIterator iterator() {
return copy;
}
};
}
}
private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues {
@ -152,10 +206,12 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
IndexInput dataIn,
IndexInput slice,
int dimension,
int byteSize)
int byteSize,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction similarityFunction)
throws IOException {
super(dimension, configuration.size, slice, byteSize);
super(dimension, configuration.size, slice, byteSize, flatVectorsScorer, similarityFunction);
this.configuration = configuration;
final RandomAccessInput addressesData =
dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength);
@ -195,7 +251,13 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
configuration, dataIn, slice.clone(), dimension, byteSize);
configuration,
dataIn,
slice.clone(),
dimension,
byteSize,
flatVectorsScorer,
similarityFunction);
}
@Override
@ -220,12 +282,33 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
};
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
SparseOffHeapVectorValues copy = copy();
RandomVectorScorer randomVectorScorer =
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
return new VectorScorer() {
@Override
public float score() throws IOException {
return randomVectorScorer.score(copy.disi.index());
}
@Override
public DocIdSetIterator iterator() {
return copy;
}
};
}
}
private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null, 0);
public EmptyOffHeapVectorValues(
int dimension,
FlatVectorsScorer flatVectorsScorer,
VectorSimilarityFunction similarityFunction) {
super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction);
}
private int doc = -1;
@ -256,17 +339,17 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public int advance(int target) throws IOException {
public int advance(int target) {
return doc = NO_MORE_DOCS;
}
@Override
public EmptyOffHeapVectorValues copy() throws IOException {
public EmptyOffHeapVectorValues copy() {
throw new UnsupportedOperationException();
}
@Override
public float[] vectorValue(int targetOrd) throws IOException {
public float[] vectorValue(int targetOrd) {
throw new UnsupportedOperationException();
}
@ -279,5 +362,10 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
public Bits getAcceptOrds(Bits acceptDocs) {
return null;
}
@Override
public VectorScorer scorer(float[] query) {
throw new UnsupportedOperationException();
}
}
}

View File

@ -185,6 +185,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
+ VectorEncoding.FLOAT32);
}
return OffHeapFloatVectorValues.load(
fieldEntry.similarityFunction,
vectorScorer,
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
@ -206,6 +208,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(
fieldEntry.similarityFunction,
vectorScorer,
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
@ -223,6 +227,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
return vectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction,
OffHeapFloatVectorValues.load(
fieldEntry.similarityFunction,
vectorScorer,
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
@ -241,6 +247,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
return vectorScorer.getRandomVectorScorer(
fieldEntry.similarityFunction,
OffHeapByteVectorValues.load(
fieldEntry.similarityFunction,
vectorScorer,
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,

View File

@ -314,14 +314,18 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
finalVectorDataInput,
fieldInfo.getVectorDimension() * Byte.BYTES));
fieldInfo.getVectorDimension() * Byte.BYTES,
vectorsScorer,
fieldInfo.getVectorSimilarityFunction()));
case FLOAT32 -> vectorsScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
finalVectorDataInput,
fieldInfo.getVectorDimension() * Float.BYTES));
fieldInfo.getVectorDimension() * Float.BYTES,
vectorsScorer,
fieldInfo.getVectorSimilarityFunction()));
};
return new FlatCloseableRandomVectorScorerSupplier(
() -> {

View File

@ -36,6 +36,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
@ -164,7 +165,24 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return rawVectorsReader.getFloatVectorValues(field);
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return null;
}
final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field);
OffHeapQuantizedByteVectorValues quantizedByteVectorValues =
OffHeapQuantizedByteVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.scalarQuantizer,
fieldEntry.similarityFunction,
vectorScorer,
fieldEntry.compress,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
quantizedVectorData);
return new QuantizedVectorValues(rawVectorValues, quantizedByteVectorValues);
}
@Override
@ -227,6 +245,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.scalarQuantizer,
fieldEntry.similarityFunction,
vectorScorer,
fieldEntry.compress,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
@ -282,6 +302,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.scalarQuantizer,
fieldEntry.similarityFunction,
vectorScorer,
fieldEntry.compress,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
@ -369,4 +391,56 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
return SHALLOW_SIZE + RamUsageEstimator.sizeOf(ordToDoc);
}
}
private static final class QuantizedVectorValues extends FloatVectorValues {
private final FloatVectorValues rawVectorValues;
private final OffHeapQuantizedByteVectorValues quantizedVectorValues;
QuantizedVectorValues(
FloatVectorValues rawVectorValues, OffHeapQuantizedByteVectorValues quantizedVectorValues) {
this.rawVectorValues = rawVectorValues;
this.quantizedVectorValues = quantizedVectorValues;
}
@Override
public int dimension() {
return rawVectorValues.dimension();
}
@Override
public int size() {
return rawVectorValues.size();
}
@Override
public float[] vectorValue() throws IOException {
return rawVectorValues.vectorValue();
}
@Override
public int docID() {
return rawVectorValues.docID();
}
@Override
public int nextDoc() throws IOException {
int rawDocId = rawVectorValues.nextDoc();
int quantizedDocId = quantizedVectorValues.nextDoc();
assert rawDocId == quantizedDocId;
return quantizedDocId;
}
@Override
public int advance(int target) throws IOException {
int rawDocId = rawVectorValues.advance(target);
int quantizedDocId = quantizedVectorValues.advance(target);
assert rawDocId == quantizedDocId;
return quantizedDocId;
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
return quantizedVectorValues.vectorScorer(query);
}
}
}

View File

@ -49,6 +49,7 @@ import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
@ -526,6 +527,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
docsWithField.cardinality(),
mergedQuantizationState,
compress,
fieldInfo.getVectorSimilarityFunction(),
vectorsScorer,
quantizationDataInput)));
} finally {
if (success == false) {
@ -890,6 +893,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
curDoc = target;
return docID();
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
}
static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
@ -1013,6 +1021,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
public float getScoreCorrectionConstant() throws IOException {
return current.values.getScoreCorrectionConstant();
}
@Override
public VectorScorer vectorScorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}
}
static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
@ -1082,6 +1095,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
return doc;
}
@Override
public VectorScorer vectorScorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}
private void quantize() throws IOException {
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
@ -1182,5 +1200,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
public int advance(int target) throws IOException {
return in.advance(target);
}
@Override
public VectorScorer vectorScorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}
}
}

View File

@ -19,10 +19,15 @@ package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
@ -39,6 +44,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
protected final int size;
protected final int numBytes;
protected final ScalarQuantizer scalarQuantizer;
protected final VectorSimilarityFunction similarityFunction;
protected final FlatVectorsScorer vectorsScorer;
protected final boolean compress;
protected final IndexInput slice;
@ -85,6 +92,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
int dimension,
int size,
ScalarQuantizer scalarQuantizer,
VectorSimilarityFunction similarityFunction,
FlatVectorsScorer vectorsScorer,
boolean compress,
IndexInput slice) {
this.dimension = dimension;
@ -100,6 +109,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
this.byteSize = this.numBytes + Float.BYTES;
byteBuffer = ByteBuffer.allocate(dimension);
binaryValue = byteBuffer.array();
this.similarityFunction = similarityFunction;
this.vectorsScorer = vectorsScorer;
}
@Override
@ -160,22 +171,39 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
int dimension,
int size,
ScalarQuantizer scalarQuantizer,
VectorSimilarityFunction similarityFunction,
FlatVectorsScorer vectorsScorer,
boolean compress,
long quantizedVectorDataOffset,
long quantizedVectorDataLength,
IndexInput vectorData)
throws IOException {
if (configuration.isEmpty()) {
return new EmptyOffHeapVectorValues(dimension);
return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer);
}
IndexInput bytesSlice =
vectorData.slice(
"quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
if (configuration.isDense()) {
return new DenseOffHeapVectorValues(dimension, size, scalarQuantizer, compress, bytesSlice);
return new DenseOffHeapVectorValues(
dimension,
size,
scalarQuantizer,
compress,
similarityFunction,
vectorsScorer,
bytesSlice);
} else {
return new SparseOffHeapVectorValues(
configuration, dimension, size, scalarQuantizer, compress, vectorData, bytesSlice);
configuration,
dimension,
size,
scalarQuantizer,
compress,
vectorData,
similarityFunction,
vectorsScorer,
bytesSlice);
}
}
@ -192,8 +220,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
int size,
ScalarQuantizer scalarQuantizer,
boolean compress,
VectorSimilarityFunction similarityFunction,
FlatVectorsScorer vectorsScorer,
IndexInput slice) {
super(dimension, size, scalarQuantizer, compress, slice);
super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice);
}
@Override
@ -223,13 +253,37 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(
dimension, size, scalarQuantizer, compress, slice.clone());
dimension,
size,
scalarQuantizer,
compress,
similarityFunction,
vectorsScorer,
slice.clone());
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
@Override
public VectorScorer vectorScorer(float[] target) throws IOException {
DenseOffHeapVectorValues copy = copy();
RandomVectorScorer vectorScorer =
vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
return new VectorScorer() {
@Override
public float score() throws IOException {
return vectorScorer.score(copy.doc);
}
@Override
public DocIdSetIterator iterator() {
return copy;
}
};
}
}
private static class SparseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
@ -246,9 +300,11 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
ScalarQuantizer scalarQuantizer,
boolean compress,
IndexInput dataIn,
VectorSimilarityFunction similarityFunction,
FlatVectorsScorer vectorsScorer,
IndexInput slice)
throws IOException {
super(dimension, size, scalarQuantizer, compress, slice);
super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice);
this.configuration = configuration;
this.dataIn = dataIn;
this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
@ -279,7 +335,15 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
configuration, dimension, size, scalarQuantizer, compress, dataIn, slice.clone());
configuration,
dimension,
size,
scalarQuantizer,
compress,
dataIn,
similarityFunction,
vectorsScorer,
slice.clone());
}
@Override
@ -304,12 +368,40 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
}
};
}
@Override
public VectorScorer vectorScorer(float[] target) throws IOException {
SparseOffHeapVectorValues copy = copy();
RandomVectorScorer vectorScorer =
vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
return new VectorScorer() {
@Override
public float score() throws IOException {
return vectorScorer.score(copy.disi.index());
}
@Override
public DocIdSetIterator iterator() {
return copy;
}
};
}
}
private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, new ScalarQuantizer(-1, 1, (byte) 7), false, null);
public EmptyOffHeapVectorValues(
int dimension,
VectorSimilarityFunction similarityFunction,
FlatVectorsScorer vectorsScorer) {
super(
dimension,
0,
new ScalarQuantizer(-1, 1, (byte) 7),
similarityFunction,
vectorsScorer,
false,
null);
}
private int doc = -1;
@ -363,5 +455,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
public Bits getAcceptOrds(Bits acceptDocs) {
return null;
}
@Override
public VectorScorer vectorScorer(float[] target) {
throw new UnsupportedOperationException();
}
}
}

View File

@ -19,6 +19,7 @@ package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
/**
* This class provides access to per-document floating point vector values indexed as {@link
@ -75,4 +76,14 @@ public abstract class ByteVectorValues extends DocIdSetIterator {
+ ")");
}
}
/**
* Return a {@link VectorScorer} for the given query vector. The iterator for the scorer is not
* the same instance as the iterator for this {@link ByteVectorValues}. It is a copy, and
* iteration over the scorer will not affect the iteration of this {@link ByteVectorValues}.
*
* @param query the query vector
* @return a {@link VectorScorer} instance
*/
public abstract VectorScorer scorer(byte[] query) throws IOException;
}

View File

@ -22,6 +22,7 @@ import org.apache.lucene.index.FilterLeafReader.FilterTerms;
import org.apache.lucene.index.FilterLeafReader.FilterTermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.CompiledAutomaton;
@ -476,6 +477,11 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
return vectorValues.size();
}
@Override
public VectorScorer scorer(float[] target) throws IOException {
return vectorValues.scorer(target);
}
/**
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
* if {@link Thread#interrupted()} returns true.
@ -543,6 +549,11 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
return vectorValues.vectorValue();
}
@Override
public VectorScorer scorer(byte[] target) throws IOException {
return vectorValues.scorer(target);
}
/**
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
* if {@link Thread#interrupted()} returns true.

View File

@ -19,6 +19,7 @@ package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
/**
* This class provides access to per-document floating point vector values indexed as {@link
@ -75,4 +76,15 @@ public abstract class FloatVectorValues extends DocIdSetIterator {
+ ")");
}
}
/**
* Return a {@link VectorScorer} for the given query vector and the current {@link
* FloatVectorValues}. The iterator for the scorer is not the same instance as the iterator for
* this {@link FloatVectorValues}. It is a copy, and iteration over the scorer will not affect the
* iteration of this {@link FloatVectorValues}.
*
* @param query the query vector
* @return a {@link VectorScorer} instance
*/
public abstract VectorScorer scorer(float[] query) throws IOException;
}

View File

@ -36,6 +36,7 @@ import org.apache.lucene.index.MultiDocValues.MultiSortedDocValues;
import org.apache.lucene.index.MultiDocValues.MultiSortedSetDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
@ -883,6 +884,11 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
public int advance(int target) throws IOException {
return mergedIterator.advance(target);
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
};
}
@ -937,6 +943,11 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
public int advance(int target) throws IOException {
return mergedIterator.advance(target);
}
@Override
public VectorScorer scorer(byte[] target) {
throw new UnsupportedOperationException();
}
};
}

View File

@ -35,6 +35,7 @@ import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOSupplier;
@ -266,6 +267,11 @@ public final class SortingCodecReader extends FilterCodecReader {
}
return docId = docsWithField.nextSetBit(target);
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
}
private static class SortingByteVectorValues extends ByteVectorValues {
@ -320,6 +326,11 @@ public final class SortingCodecReader extends FilterCodecReader {
}
return docId = docsWithField.nextSetBit(target);
}
@Override
public VectorScorer scorer(byte[] target) {
throw new UnsupportedOperationException();
}
}
/**

View File

@ -205,17 +205,17 @@ abstract class AbstractKnnVectorQuery extends Query {
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
ScoreDoc topDoc = queue.top();
DocIdSetIterator vectorIterator = vectorScorer.iterator();
DocIdSetIterator conjunction =
ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of());
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
// Mark results as partial if timeout is met
if (queryTimeout != null && queryTimeout.shouldExit()) {
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
break;
}
boolean advanced = vectorScorer.advanceExact(doc);
assert advanced;
assert vectorIterator.docID() == doc;
float score = vectorScorer.score();
if (score > topDoc.score) {
topDoc.score = score;

View File

@ -19,6 +19,7 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
@ -83,7 +84,10 @@ abstract class AbstractVectorSimilarityQuery extends Query {
VectorScorer scorer = createVectorScorer(context);
if (scorer == null) {
return Explanation.noMatch("Not indexed as the correct vector field");
} else if (scorer.advanceExact(doc)) {
}
DocIdSetIterator iterator = scorer.iterator();
int docId = iterator.advance(doc);
if (docId == doc) {
float score = scorer.score();
if (score >= resultSimilarity) {
return Explanation.match(boost * score, "Score above threshold");
@ -256,15 +260,15 @@ abstract class AbstractVectorSimilarityQuery extends Query {
DocIdSetIterator acceptDocs,
float threshold) {
float[] cachedScore = new float[1];
DocIdSetIterator vectorIterator = scorer.iterator();
DocIdSetIterator conjunction =
ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptDocs), List.of());
DocIdSetIterator iterator =
new FilteredDocIdSetIterator(acceptDocs) {
new FilteredDocIdSetIterator(conjunction) {
@Override
protected boolean match(int doc) throws IOException {
// Advance the scorer
if (!scorer.advanceExact(doc)) {
return false;
}
assert doc == vectorIterator.docID();
// Compute the dot product
float score = scorer.score();
cachedScore[0] = score * boost;

View File

@ -21,9 +21,8 @@ import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.Bits;
/**
@ -98,12 +97,11 @@ public class ByteVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
@Override
VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {
@SuppressWarnings("resource")
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorEncoding() != VectorEncoding.BYTE) {
ByteVectorValues vectorValues = context.reader().getByteVectorValues(field);
if (vectorValues == null) {
return null;
}
return VectorScorer.create(context, fi, target);
return vectorValues.scorer(target);
}
@Override

View File

@ -22,7 +22,6 @@ import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
/**
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
@ -37,26 +36,13 @@ class ByteVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
}
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
public VectorScorer getScorer(LeafReaderContext ctx) throws IOException {
final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
if (vectorValues == null) {
ByteVectorValues.checkField(ctx.reader(), fieldName);
return DoubleValues.EMPTY;
return null;
}
VectorSimilarityFunction function =
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
return function.compare(queryVector, vectorValues.vectorValue());
}
@Override
public boolean advanceExact(int doc) throws IOException {
return doc >= vectorValues.docID()
&& (vectorValues.docID() == doc || vectorValues.advance(doc) == doc);
}
};
return vectorValues.scorer(queryVector);
}
@Override

View File

@ -21,9 +21,8 @@ import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
@ -100,11 +99,11 @@ public class FloatVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
@Override
VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {
@SuppressWarnings("resource")
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field);
if (vectorValues == null) {
return null;
}
return VectorScorer.create(context, fi, target);
return vectorValues.scorer(target);
}
@Override

View File

@ -22,7 +22,6 @@ import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
/**
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
@ -44,22 +43,32 @@ class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
FloatVectorValues.checkField(ctx.reader(), fieldName);
return DoubleValues.EMPTY;
}
VectorSimilarityFunction function =
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
return new DoubleValues() {
private final VectorScorer scorer = vectorValues.scorer(queryVector);
private final DocIdSetIterator iterator = scorer.iterator();
@Override
public double doubleValue() throws IOException {
return function.compare(queryVector, vectorValues.vectorValue());
return scorer.score();
}
@Override
public boolean advanceExact(int doc) throws IOException {
return doc >= vectorValues.docID()
&& (vectorValues.docID() == doc || vectorValues.advance(doc) == doc);
return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc);
}
};
}
@Override
public VectorScorer getScorer(LeafReaderContext ctx) throws IOException {
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
if (vectorValues == null) {
FloatVectorValues.checkField(ctx.reader(), fieldName);
return null;
}
return vectorValues.scorer(queryVector);
}
@Override
public int hashCode() {
return Objects.hash(fieldName, Arrays.hashCode(queryVector));

View File

@ -98,7 +98,12 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
return VectorScorer.create(context, fi, target);
ByteVectorValues vectorValues = context.reader().getByteVectorValues(field);
if (vectorValues == null) {
ByteVectorValues.checkField(context.reader(), field);
return null;
}
return vectorValues.scorer(target);
}
@Override

View File

@ -99,7 +99,12 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
return VectorScorer.create(context, fi, target);
FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field);
if (vectorValues == null) {
FloatVectorValues.checkField(context.reader(), field);
return null;
}
return vectorValues.scorer(target);
}
@Override

View File

@ -17,117 +17,25 @@
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
/**
* Computes the similarity score between a given query vector and different document vectors. This
* is primarily used by {@link KnnFloatVectorQuery} to run an exact, exhaustive search over the
* vectors.
* is used for exact searching and scoring
*
* @lucene.experimental
*/
abstract class VectorScorer {
protected final VectorSimilarityFunction similarity;
public interface VectorScorer {
/**
* Create a new vector scorer instance.
* Compute the score for the current document ID.
*
* @param context the reader context
* @param fi the FieldInfo for the field containing document vectors
* @param query the query vector to compute the similarity for
* @return the score for the current document ID
* @throws IOException if an exception occurs during score computation
*/
static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
throws IOException {
FloatVectorValues values = context.reader().getFloatVectorValues(fi.name);
if (values == null) {
FloatVectorValues.checkField(context.reader(), fi.name);
return null;
}
final VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return new FloatVectorScorer(values, query, similarity);
}
float score() throws IOException;
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, byte[] query)
throws IOException {
ByteVectorValues values = context.reader().getByteVectorValues(fi.name);
if (values == null) {
ByteVectorValues.checkField(context.reader(), fi.name);
return null;
}
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return new ByteVectorScorer(values, query, similarity);
}
VectorScorer(VectorSimilarityFunction similarity) {
this.similarity = similarity;
}
/** Compute the similarity score for the current document. */
abstract float score() throws IOException;
abstract boolean advanceExact(int doc) throws IOException;
private static class ByteVectorScorer extends VectorScorer {
private final byte[] query;
private final ByteVectorValues values;
protected ByteVectorScorer(
ByteVectorValues values, byte[] query, VectorSimilarityFunction similarity) {
super(similarity);
this.values = values;
this.query = query;
}
/**
* Advance the instance to the given document ID and return true if there is a value for that
* document.
*/
@Override
public boolean advanceExact(int doc) throws IOException {
int vectorDoc = values.docID();
if (vectorDoc < doc) {
vectorDoc = values.advance(doc);
}
return vectorDoc == doc;
}
@Override
public float score() throws IOException {
assert values.docID() != -1 : getClass().getSimpleName() + " is not positioned";
return similarity.compare(query, values.vectorValue());
}
}
private static class FloatVectorScorer extends VectorScorer {
private final float[] query;
private final FloatVectorValues values;
protected FloatVectorScorer(
FloatVectorValues values, float[] query, VectorSimilarityFunction similarity) {
super(similarity);
this.query = query;
this.values = values;
}
/**
* Advance the instance to the given document ID and return true if there is a value for that
* document.
*/
@Override
public boolean advanceExact(int doc) throws IOException {
int vectorDoc = values.docID();
if (vectorDoc < doc) {
vectorDoc = values.advance(doc);
}
return vectorDoc == doc;
}
@Override
public float score() throws IOException {
assert values.docID() != -1 : getClass().getSimpleName() + " is not positioned";
return similarity.compare(query, values.vectorValue());
}
}
/**
* @return a {@link DocIdSetIterator} over the documents.
*/
DocIdSetIterator iterator();
}

View File

@ -33,8 +33,26 @@ abstract class VectorSimilarityValuesSource extends DoubleValuesSource {
}
@Override
public abstract DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores)
throws IOException;
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
VectorScorer scorer = getScorer(ctx);
if (scorer == null) {
return DoubleValues.EMPTY;
}
DocIdSetIterator iterator = scorer.iterator();
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
return scorer.score();
}
@Override
public boolean advanceExact(int doc) throws IOException {
return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc);
}
};
}
protected abstract VectorScorer getScorer(LeafReaderContext ctx) throws IOException;
@Override
public boolean needsScores() {

View File

@ -18,6 +18,8 @@ package org.apache.lucene.util.quantization;
import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
/**
* A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for
@ -25,6 +27,31 @@ import org.apache.lucene.index.ByteVectorValues;
*
* @lucene.experimental
*/
public abstract class QuantizedByteVectorValues extends ByteVectorValues {
public abstract class QuantizedByteVectorValues extends DocIdSetIterator {
public abstract float getScoreCorrectionConstant() throws IOException;
public abstract byte[] vectorValue() throws IOException;
/** Return the dimension of the vectors */
public abstract int dimension();
/**
* Return the number of vectors for this field.
*
* @return the number of vectors returned by this iterator
*/
public abstract int size();
@Override
public final long cost() {
return size();
}
/**
* Return a {@link VectorScorer} for the given query vector.
*
* @param query the query vector
* @return a {@link VectorScorer} instance
*/
public abstract VectorScorer vectorScorer(float[] query) throws IOException;
}

View File

@ -26,7 +26,6 @@ import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
@ -44,24 +43,22 @@ public class TestVectorScorer extends LuceneTestCase {
IndexReader reader = DirectoryReader.open(indexStore)) {
assert reader.leaves().size() == 1;
LeafReaderContext context = reader.leaves().get(0);
FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
final VectorScorer vectorScorer;
switch (encoding) {
case BYTE:
vectorScorer = VectorScorer.create(context, fieldInfo, new byte[] {1, 2});
vectorScorer = context.reader().getByteVectorValues("field").scorer(new byte[] {1, 2});
break;
case FLOAT32:
vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
vectorScorer = context.reader().getFloatVectorValues("field").scorer(new float[] {1, 2});
break;
default:
throw new IllegalArgumentException("unexpected vector encoding: " + encoding);
}
DocIdSetIterator iterator = vectorScorer.iterator();
int numDocs = 0;
for (int i = 0; i < reader.maxDoc(); i++) {
if (vectorScorer.advanceExact(i)) {
numDocs++;
}
while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
numDocs++;
}
assertEquals(3, numDocs);
}

View File

@ -72,6 +72,7 @@ import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@ -1128,6 +1129,11 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
public float[] vectorValue(int ord) {
return unitVector2d(ord / (double) size, value);
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
}
/** Returns vectors evenly distributed around the upper unit semicircle. */
@ -1193,6 +1199,11 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
return bValue;
}
@Override
public VectorScorer scorer(byte[] target) {
throw new UnsupportedOperationException();
}
}
private static float[] unitVector2d(double piRadians) {

View File

@ -23,6 +23,7 @@ import java.util.HashSet;
import java.util.Set;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.tests.util.LuceneTestCase;
public class TestScalarQuantizer extends LuceneTestCase {
@ -262,5 +263,10 @@ public class TestScalarQuantizer extends LuceneTestCase {
curDoc = target - 1;
return nextDoc();
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
}
}

View File

@ -20,10 +20,8 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.IndexSearcher;
@ -34,6 +32,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
@ -92,14 +91,10 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
return NO_RESULTS;
}
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
ParentBlockJoinByteVectorScorer vectorScorer =
new ParentBlockJoinByteVectorScorer(
byteVectorValues,
acceptIterator,
parentBitSet,
query,
fi.getVectorSimilarityFunction());
VectorScorer scorer = byteVectorValues.scorer(query);
DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer vectorScorer =
new DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer(
acceptIterator, parentBitSet, scorer);
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
@ -177,59 +172,4 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
result = 31 * result + Arrays.hashCode(query);
return result;
}
private static class ParentBlockJoinByteVectorScorer {
private final byte[] query;
private final ByteVectorValues values;
private final VectorSimilarityFunction similarity;
private final DocIdSetIterator acceptedChildrenIterator;
private final BitSet parentBitSet;
private int currentParent = -1;
private int bestChild = -1;
private float currentScore = Float.NEGATIVE_INFINITY;
protected ParentBlockJoinByteVectorScorer(
ByteVectorValues values,
DocIdSetIterator acceptedChildrenIterator,
BitSet parentBitSet,
byte[] query,
VectorSimilarityFunction similarity) {
this.query = query;
this.values = values;
this.similarity = similarity;
this.acceptedChildrenIterator = acceptedChildrenIterator;
this.parentBitSet = parentBitSet;
}
public int bestChild() {
return bestChild;
}
public int nextParent() throws IOException {
int nextChild = acceptedChildrenIterator.docID();
if (nextChild == -1) {
nextChild = acceptedChildrenIterator.nextDoc();
}
if (nextChild == DocIdSetIterator.NO_MORE_DOCS) {
currentParent = DocIdSetIterator.NO_MORE_DOCS;
return currentParent;
}
currentScore = Float.NEGATIVE_INFINITY;
currentParent = parentBitSet.nextSetBit(nextChild);
do {
values.advance(nextChild);
float score = similarity.compare(query, values.vectorValue());
if (score > currentScore) {
bestChild = nextChild;
currentScore = score;
}
} while ((nextChild = acceptedChildrenIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS
&& nextChild < currentParent);
return currentParent;
}
public float score() throws IOException {
return currentScore;
}
}
}

View File

@ -19,11 +19,9 @@ package org.apache.lucene.search.join;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.IndexSearcher;
@ -34,6 +32,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
@ -92,14 +91,9 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
return NO_RESULTS;
}
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
DiversifyingChildrenFloatVectorScorer vectorScorer =
new DiversifyingChildrenFloatVectorScorer(
floatVectorValues,
acceptIterator,
parentBitSet,
query,
fi.getVectorSimilarityFunction());
DiversifyingChildrenVectorScorer vectorScorer =
new DiversifyingChildrenVectorScorer(
acceptIterator, parentBitSet, floatVectorValues.scorer(query));
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
@ -178,26 +172,20 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
return result;
}
private static class DiversifyingChildrenFloatVectorScorer {
private final float[] query;
private final FloatVectorValues values;
private final VectorSimilarityFunction similarity;
static class DiversifyingChildrenVectorScorer {
private final VectorScorer vectorScorer;
private final DocIdSetIterator vectorIterator;
private final DocIdSetIterator acceptedChildrenIterator;
private final BitSet parentBitSet;
private int currentParent = -1;
private int bestChild = -1;
private float currentScore = Float.NEGATIVE_INFINITY;
protected DiversifyingChildrenFloatVectorScorer(
FloatVectorValues values,
DocIdSetIterator acceptedChildrenIterator,
BitSet parentBitSet,
float[] query,
VectorSimilarityFunction similarity) {
this.query = query;
this.values = values;
this.similarity = similarity;
protected DiversifyingChildrenVectorScorer(
DocIdSetIterator acceptedChildrenIterator, BitSet parentBitSet, VectorScorer vectorScorer) {
this.acceptedChildrenIterator = acceptedChildrenIterator;
this.vectorScorer = vectorScorer;
this.vectorIterator = vectorScorer.iterator();
this.parentBitSet = parentBitSet;
}
@ -217,8 +205,8 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
currentScore = Float.NEGATIVE_INFINITY;
currentParent = parentBitSet.nextSetBit(nextChild);
do {
values.advance(nextChild);
float score = similarity.compare(query, values.vectorValue());
vectorIterator.advance(nextChild);
float score = vectorScorer.score();
if (score > currentScore) {
bestChild = nextChild;
currentScore = score;

View File

@ -48,10 +48,12 @@ import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.tests.util.TestUtil;
@ -718,6 +720,116 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
public void testFloatVectorScorerIteration() throws Exception {
IndexWriterConfig iwc = newIndexWriterConfig();
if (random().nextBoolean()) {
iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
}
String fieldName = "field";
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(100);
int dimension = atLeast(10);
if (dimension % 2 != 0) {
dimension++;
}
float[][] values = new float[numDoc][];
for (int i = 0; i < numDoc; i++) {
if (random().nextInt(7) != 3) {
// usually index a vector value for a doc
values[i] = randomNormalizedVector(dimension);
}
add(iw, fieldName, i, values[i], similarityFunction);
if (random().nextInt(10) == 2) {
iw.deleteDocuments(new Term("id", Integer.toString(random().nextInt(i + 1))));
}
if (random().nextInt(10) == 3) {
iw.commit();
}
}
float[] vectorToScore = randomNormalizedVector(dimension);
try (IndexReader reader = DirectoryReader.open(iw)) {
for (LeafReaderContext ctx : reader.leaves()) {
FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
if (vectorValues == null) {
continue;
}
VectorScorer scorer = vectorValues.scorer(vectorToScore);
assertNotNull(scorer);
DocIdSetIterator iterator = scorer.iterator();
assertSame(iterator, scorer.iterator());
assertNotSame(iterator, scorer);
// verify scorer iteration scores are valid & iteration with vectorValues is consistent
while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) {
float score = scorer.score();
assertTrue(score >= 0f);
assertEquals(iterator.docID(), vectorValues.docID());
}
// verify that a new scorer can be obtained after iteration
VectorScorer newScorer = vectorValues.scorer(vectorToScore);
assertNotNull(newScorer);
assertNotSame(scorer, newScorer);
assertNotSame(iterator, newScorer.iterator());
}
}
}
}
public void testByteVectorScorerIteration() throws Exception {
IndexWriterConfig iwc = newIndexWriterConfig();
if (random().nextBoolean()) {
iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
}
String fieldName = "field";
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(100);
int dimension = atLeast(10);
if (dimension % 2 != 0) {
dimension++;
}
byte[][] values = new byte[numDoc][];
for (int i = 0; i < numDoc; i++) {
if (random().nextInt(7) != 3) {
// usually index a vector value for a doc
values[i] = randomVector8(dimension);
}
add(iw, fieldName, i, values[i], similarityFunction);
if (random().nextInt(10) == 2) {
iw.deleteDocuments(new Term("id", Integer.toString(random().nextInt(i + 1))));
}
if (random().nextInt(10) == 3) {
iw.commit();
}
}
byte[] vectorToScore = randomVector8(dimension);
try (IndexReader reader = DirectoryReader.open(iw)) {
for (LeafReaderContext ctx : reader.leaves()) {
ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
if (vectorValues == null) {
continue;
}
VectorScorer scorer = vectorValues.scorer(vectorToScore);
assertNotNull(scorer);
DocIdSetIterator iterator = scorer.iterator();
assertSame(iterator, scorer.iterator());
assertNotSame(iterator, scorer);
// verify scorer iteration scores are valid & iteration with vectorValues is consistent
while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) {
float score = scorer.score();
assertTrue(score >= 0f);
assertEquals(iterator.docID(), vectorValues.docID());
}
// verify that a new scorer can be obtained after iteration
VectorScorer newScorer = vectorValues.scorer(vectorToScore);
assertNotNull(newScorer);
assertNotSame(scorer, newScorer);
assertNotSame(iterator, newScorer.iterator());
}
}
}
}
protected VectorSimilarityFunction randomSimilarity() {
return VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length)];