From b60e86c4b91e6987492d65e51185d80ff191a478 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 9 May 2024 16:30:14 -0400 Subject: [PATCH] Add new VectorScorer interface to vector value iterators (#13181) With quantized vectors, and with current vectors, we separate out the "scoring" vs. "iteration", requiring the user to always iterate the raw vectors and provide their own similarity function. While this is flexible, it creates frustration in: - Just iterating and scoring, especially since the field already has a similarity function stored...Why can't we just know which one to use and use it! - Iterating and scoring quantized vectors. By default it would be good to be able to iterate and score quantized vectors (e.g. without going through the HNSW graph). This significantly hampers support for true exact kNN search. This commit extends the vector value iterators to be able to return a scorer given some vector value (what this PR demonstrates). The scorer contains a copy of the originating iterator and allows for iteration and scoring the most optimized way the provided codec can give. Users can still iterate vector values directly, read them on heap, and score any way they please. --- lucene/CHANGES.txt | 3 + .../lucene90/Lucene90HnswVectorsReader.java | 31 ++++- .../lucene91/Lucene91HnswVectorsReader.java | 36 +++++- .../lucene92/OffHeapFloatVectorValues.java | 79 ++++++++++-- .../lucene94/OffHeapByteVectorValues.java | 82 ++++++++++-- .../lucene94/OffHeapFloatVectorValues.java | 82 ++++++++++-- .../lucene95/Lucene95HnswVectorsReader.java | 8 ++ .../lucene90/Lucene90HnswVectorsWriter.java | 5 +- .../TestLucene90HnswVectorsFormat.java | 5 + .../lucene91/Lucene91HnswVectorsWriter.java | 6 +- .../TestLucene91HnswVectorsFormat.java | 5 + .../lucene92/Lucene92HnswVectorsWriter.java | 5 +- .../TestLucene92HnswVectorsFormat.java | 5 + .../lucene94/Lucene94HnswVectorsWriter.java | 2 + .../lucene95/Lucene95HnswVectorsWriter.java | 12 +- .../SimpleTextKnnVectorsReader.java | 68 +++++++++- .../codecs/BufferingKnnVectorsWriter.java | 21 ++++ .../lucene/codecs/KnnVectorsWriter.java | 12 +- .../lucene95/OffHeapByteVectorValues.java | 118 ++++++++++++++++-- .../lucene95/OffHeapFloatVectorValues.java | 118 +++++++++++++++--- .../lucene99/Lucene99FlatVectorsReader.java | 8 ++ .../lucene99/Lucene99FlatVectorsWriter.java | 8 +- .../Lucene99ScalarQuantizedVectorsReader.java | 76 ++++++++++- .../Lucene99ScalarQuantizedVectorsWriter.java | 23 ++++ .../OffHeapQuantizedByteVectorValues.java | 115 +++++++++++++++-- .../apache/lucene/index/ByteVectorValues.java | 11 ++ .../lucene/index/ExitableDirectoryReader.java | 11 ++ .../lucene/index/FloatVectorValues.java | 12 ++ .../SlowCompositeCodecReaderWrapper.java | 11 ++ .../lucene/index/SortingCodecReader.java | 11 ++ .../lucene/search/AbstractKnnVectorQuery.java | 10 +- .../search/AbstractVectorSimilarityQuery.java | 16 ++- .../search/ByteVectorSimilarityQuery.java | 10 +- .../ByteVectorSimilarityValuesSource.java | 20 +-- .../search/FloatVectorSimilarityQuery.java | 9 +- .../FloatVectorSimilarityValuesSource.java | 21 +++- .../lucene/search/KnnByteVectorQuery.java | 7 +- .../lucene/search/KnnFloatVectorQuery.java | 7 +- .../apache/lucene/search/VectorScorer.java | 116 ++--------------- .../search/VectorSimilarityValuesSource.java | 22 +++- .../QuantizedByteVectorValues.java | 29 ++++- .../lucene/search/TestVectorScorer.java | 13 +- .../lucene/util/hnsw/HnswGraphTestCase.java | 11 ++ .../quantization/TestScalarQuantizer.java | 6 + ...iversifyingChildrenByteKnnVectorQuery.java | 70 +---------- ...versifyingChildrenFloatKnnVectorQuery.java | 38 ++---- .../index/BaseKnnVectorsFormatTestCase.java | 112 +++++++++++++++++ 47 files changed, 1166 insertions(+), 340 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 35f52d1c796..3650c13ab39 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -258,6 +258,9 @@ New Features * GITHUB#13288: Make HNSW and Flat storage vector formats easier to extend with new FlatVectorScorer interface. Add new Hnsw format for binary quantized vectors. (Ben Trent) +* GITHUB#13181: Add new VectorScorer interface to vector value iterators. This allows for vector codecs to supply + simpler and more optimized vector scoring when iterating vector values directly. (Ben Trent) + Improvements --------------------- diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index ea63c926ac1..6518b4aeab7 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -34,7 +34,9 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; @@ -272,7 +274,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { throws IOException { IndexInput bytesSlice = vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength); - return new OffHeapFloatVectorValues(fieldEntry.dimension, fieldEntry.ordToDoc, bytesSlice); + return new OffHeapFloatVectorValues( + fieldEntry.dimension, fieldEntry.ordToDoc, fieldEntry.similarityFunction, bytesSlice); } private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) { @@ -359,14 +362,20 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { final int byteSize; int lastOrd = -1; final float[] value; + final VectorSimilarityFunction similarityFunction; int ord = -1; int doc = -1; - OffHeapFloatVectorValues(int dimension, int[] ordToDoc, IndexInput dataIn) { + OffHeapFloatVectorValues( + int dimension, + int[] ordToDoc, + VectorSimilarityFunction similarityFunction, + IndexInput dataIn) { this.dimension = dimension; this.ordToDoc = ordToDoc; this.dataIn = dataIn; + this.similarityFunction = similarityFunction; byteSize = Float.BYTES * dimension; value = new float[dimension]; @@ -420,7 +429,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { @Override public OffHeapFloatVectorValues copy() { - return new OffHeapFloatVectorValues(dimension, ordToDoc, dataIn.clone()); + return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone()); } @Override @@ -433,6 +442,22 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { lastOrd = targetOrd; return value; } + + @Override + public VectorScorer scorer(float[] target) { + OffHeapFloatVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.similarityFunction.compare(values.vectorValue(), target); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } /** Read the nearest-neighbors graph from the index input */ diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index f29c04ed10c..19675597c0a 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -35,7 +35,9 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; @@ -255,7 +257,11 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { IndexInput bytesSlice = vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength); return new OffHeapFloatVectorValues( - fieldEntry.dimension, fieldEntry.size(), fieldEntry.ordToDoc, bytesSlice); + fieldEntry.dimension, + fieldEntry.size(), + fieldEntry.ordToDoc, + fieldEntry.similarityFunction, + bytesSlice); } private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) { @@ -399,16 +405,23 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { private final IndexInput dataIn; private final int byteSize; private final float[] value; + private final VectorSimilarityFunction similarityFunction; private int ord = -1; private int doc = -1; - OffHeapFloatVectorValues(int dimension, int size, int[] ordToDoc, IndexInput dataIn) { + OffHeapFloatVectorValues( + int dimension, + int size, + int[] ordToDoc, + VectorSimilarityFunction similarityFunction, + IndexInput dataIn) { this.dimension = dimension; this.size = size; this.ordToDoc = ordToDoc; ordToDocOperator = ordToDoc == null ? IntUnaryOperator.identity() : (ord) -> ordToDoc[ord]; this.dataIn = dataIn; + this.similarityFunction = similarityFunction; byteSize = Float.BYTES * dimension; value = new float[dimension]; } @@ -468,7 +481,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { @Override public OffHeapFloatVectorValues copy() { - return new OffHeapFloatVectorValues(dimension, size, ordToDoc, dataIn.clone()); + return new OffHeapFloatVectorValues( + dimension, size, ordToDoc, similarityFunction, dataIn.clone()); } @Override @@ -477,6 +491,22 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { dataIn.readFloats(value, 0, value.length); return value; } + + @Override + public VectorScorer scorer(float[] target) { + OffHeapFloatVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.similarityFunction.compare(values.vectorValue(), target); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } /** Read the nearest-neighbors graph from the index input */ diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index ec0cbf7379a..19dc82cc46d 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -20,6 +20,9 @@ package org.apache.lucene.backward_codecs.lucene92; import java.io.IOException; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; @@ -36,13 +39,20 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues protected final int byteSize; protected int lastOrd = -1; protected final float[] value; + protected final VectorSimilarityFunction vectorSimilarityFunction; + ; - OffHeapFloatVectorValues(int dimension, int size, IndexInput slice) { + OffHeapFloatVectorValues( + int dimension, + int size, + VectorSimilarityFunction vectorSimilarityFunction, + IndexInput slice) { this.dimension = dimension; this.size = size; this.slice = slice; byteSize = Float.BYTES * dimension; value = new float[dimension]; + this.vectorSimilarityFunction = vectorSimilarityFunction; } @Override @@ -75,9 +85,11 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues vectorData.slice( "vector-data", fieldEntry.vectorDataOffset(), fieldEntry.vectorDataLength()); if (fieldEntry.docsWithFieldOffset() == -1) { - return new DenseOffHeapVectorValues(fieldEntry.dimension(), fieldEntry.size(), bytesSlice); + return new DenseOffHeapVectorValues( + fieldEntry.dimension(), fieldEntry.size(), fieldEntry.similarityFunction(), bytesSlice); } else { - return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice); + return new SparseOffHeapVectorValues( + fieldEntry, vectorData, fieldEntry.similarityFunction(), bytesSlice); } } @@ -85,8 +97,12 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) { - super(dimension, size, slice); + public DenseOffHeapVectorValues( + int dimension, + int size, + VectorSimilarityFunction vectorSimilarityFunction, + IndexInput slice) { + super(dimension, size, vectorSimilarityFunction, slice); } @Override @@ -115,13 +131,29 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public DenseOffHeapVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, slice.clone()); + return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone()); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + DenseOffHeapVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -132,10 +164,13 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues private final Lucene92HnswVectorsReader.FieldEntry fieldEntry; public SparseOffHeapVectorValues( - Lucene92HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice) + Lucene92HnswVectorsReader.FieldEntry fieldEntry, + IndexInput dataIn, + VectorSimilarityFunction vectorSimilarityFunction, + IndexInput slice) throws IOException { - super(fieldEntry.dimension(), fieldEntry.size(), slice); + super(fieldEntry.dimension(), fieldEntry.size(), vectorSimilarityFunction, slice); this.fieldEntry = fieldEntry; final RandomAccessInput addressesData = dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength()); @@ -173,8 +208,9 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - public OffHeapFloatVectorValues copy() throws IOException { - return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone()); + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + fieldEntry, dataIn, vectorSimilarityFunction, slice.clone()); } @Override @@ -199,12 +235,28 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } }; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + SparseOffHeapVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null); + super(dimension, 0, VectorSimilarityFunction.COSINE, null); } private int doc = -1; @@ -258,5 +310,10 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index b961bafabb1..0c909e3839d 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -22,6 +22,9 @@ import java.nio.ByteBuffer; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; @@ -39,12 +42,19 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues protected final byte[] binaryValue; protected final ByteBuffer byteBuffer; protected final int byteSize; + protected final VectorSimilarityFunction vectorSimilarityFunction; - OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) { + OffHeapByteVectorValues( + int dimension, + int size, + IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, + int byteSize) { this.dimension = dimension; this.size = size; this.slice = slice; this.byteSize = byteSize; + this.vectorSimilarityFunction = vectorSimilarityFunction; byteBuffer = ByteBuffer.allocate(byteSize); binaryValue = byteBuffer.array(); } @@ -85,9 +95,14 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues int byteSize = fieldEntry.dimension(); if (fieldEntry.docsWithFieldOffset() == -1) { return new DenseOffHeapVectorValues( - fieldEntry.dimension(), fieldEntry.size(), bytesSlice, byteSize); + fieldEntry.dimension(), + fieldEntry.size(), + bytesSlice, + fieldEntry.similarityFunction(), + byteSize); } else { - return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize); + return new SparseOffHeapVectorValues( + fieldEntry, vectorData, bytesSlice, fieldEntry.similarityFunction(), byteSize); } } @@ -95,8 +110,13 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) { - super(dimension, size, slice, byteSize); + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, + int byteSize) { + super(dimension, size, slice, vectorSimilarityFunction, byteSize); } @Override @@ -124,14 +144,31 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues } @Override - public OffHeapByteVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, size, slice.clone(), vectorSimilarityFunction, byteSize); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer scorer(byte[] query) throws IOException { + DenseOffHeapVectorValues copy = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return vectorSimilarityFunction.compare(copy.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues { @@ -145,10 +182,11 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, int byteSize) throws IOException { - super(fieldEntry.dimension(), fieldEntry.size(), slice, byteSize); + super(fieldEntry.dimension(), fieldEntry.size(), slice, vectorSimilarityFunction, byteSize); this.fieldEntry = fieldEntry; final RandomAccessInput addressesData = dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength()); @@ -186,8 +224,9 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues } @Override - public OffHeapByteVectorValues copy() throws IOException { - return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize); + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize); } @Override @@ -212,12 +251,28 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues } }; } + + @Override + public VectorScorer scorer(byte[] query) throws IOException { + SparseOffHeapVectorValues copy = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return vectorSimilarityFunction.compare(copy.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues { public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null, 0); + super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0); } private int doc = -1; @@ -271,5 +326,10 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer scorer(byte[] query) { + return null; + } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 95abedf2d87..91f97b8a41f 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -20,6 +20,9 @@ package org.apache.lucene.backward_codecs.lucene94; import java.io.IOException; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; @@ -36,13 +39,20 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues protected final int byteSize; protected int lastOrd = -1; protected final float[] value; + protected final VectorSimilarityFunction vectorSimilarityFunction; - OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) { + OffHeapFloatVectorValues( + int dimension, + int size, + IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, + int byteSize) { this.dimension = dimension; this.size = size; this.slice = slice; this.byteSize = byteSize; value = new float[dimension]; + this.vectorSimilarityFunction = vectorSimilarityFunction; } @Override @@ -81,9 +91,14 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues }; if (fieldEntry.docsWithFieldOffset() == -1) { return new DenseOffHeapVectorValues( - fieldEntry.dimension(), fieldEntry.size(), bytesSlice, byteSize); + fieldEntry.dimension(), + fieldEntry.size(), + bytesSlice, + fieldEntry.similarityFunction(), + byteSize); } else { - return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize); + return new SparseOffHeapVectorValues( + fieldEntry, vectorData, bytesSlice, fieldEntry.similarityFunction(), byteSize); } } @@ -91,8 +106,13 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) { - super(dimension, size, slice, byteSize); + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, + int byteSize) { + super(dimension, size, slice, vectorSimilarityFunction, byteSize); } @Override @@ -120,14 +140,31 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - public OffHeapFloatVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, size, slice.clone(), vectorSimilarityFunction, byteSize); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + DenseOffHeapVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -141,10 +178,11 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, int byteSize) throws IOException { - super(fieldEntry.dimension(), fieldEntry.size(), slice, byteSize); + super(fieldEntry.dimension(), fieldEntry.size(), slice, vectorSimilarityFunction, byteSize); this.fieldEntry = fieldEntry; final RandomAccessInput addressesData = dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength()); @@ -182,8 +220,9 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - public OffHeapFloatVectorValues copy() throws IOException { - return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize); + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize); } @Override @@ -208,12 +247,28 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } }; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + SparseOffHeapVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null, 0); + super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0); } private int doc = -1; @@ -267,5 +322,10 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java index 72b7cfc82f2..ab51f935fb0 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java @@ -253,6 +253,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements + VectorEncoding.FLOAT32); } return OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + defaultFlatVectorScorer, fieldEntry.ordToDocVectorValues, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -274,6 +276,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements + VectorEncoding.BYTE); } return OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + defaultFlatVectorScorer, fieldEntry.ordToDocVectorValues, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -295,6 +299,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + defaultFlatVectorScorer, fieldEntry.ordToDocVectorValues, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -324,6 +330,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + defaultFlatVectorScorer, fieldEntry.ordToDocVectorValues, fieldEntry.vectorEncoding, fieldEntry.dimension, diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java index ce07b254849..39828524d26 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java @@ -133,7 +133,10 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter { // build the graph using the temporary vector data Lucene90HnswVectorsReader.OffHeapFloatVectorValues offHeapVectors = new Lucene90HnswVectorsReader.OffHeapFloatVectorValues( - floatVectorValues.dimension(), docIds, vectorDataInput); + floatVectorValues.dimension(), + docIds, + fieldInfo.getVectorSimilarityFunction(), + vectorDataInput); long[] offsets = new long[docIds.length]; long vectorIndexOffset = vectorIndex.getFilePointer(); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java index f3b411cee09..b914acf3fbb 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java @@ -68,4 +68,9 @@ public class TestLucene90HnswVectorsFormat extends BaseKnnVectorsFormatTestCase public void testSortedIndexBytes() throws Exception { // unimplemented } + + @Override + public void testByteVectorScorerIteration() { + // unimplemented + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java index b58b2e21a4f..37b75250381 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -138,7 +138,11 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter { // TODO: separate random access vector values from DocIdSetIterator? Lucene91HnswVectorsReader.OffHeapFloatVectorValues offHeapVectors = new Lucene91HnswVectorsReader.OffHeapFloatVectorValues( - floatVectorValues.dimension(), docsWithField.cardinality(), null, vectorDataInput); + floatVectorValues.dimension(), + docsWithField.cardinality(), + null, + fieldInfo.getVectorSimilarityFunction(), + vectorDataInput); Lucene91OnHeapHnswGraph graph = offHeapVectors.size() == 0 ? null diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java index b8cae733722..b27a42700cb 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java @@ -67,4 +67,9 @@ public class TestLucene91HnswVectorsFormat extends BaseKnnVectorsFormatTestCase public void testSortedIndexBytes() throws Exception { // unimplemented } + + @Override + public void testByteVectorScorerIteration() { + // unimplemented + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java index 4dd8f1f3054..caa8fc3da14 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -146,7 +146,10 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter { // TODO: separate random access vector values from DocIdSetIterator? OffHeapFloatVectorValues offHeapVectors = new OffHeapFloatVectorValues.DenseOffHeapVectorValues( - floatVectorValues.dimension(), docsWithField.cardinality(), vectorDataInput); + floatVectorValues.dimension(), + docsWithField.cardinality(), + fieldInfo.getVectorSimilarityFunction(), + vectorDataInput); OnHeapHnswGraph graph = offHeapVectors.size() == 0 ? null diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java index 976191e2484..aaee5abe4ad 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java @@ -57,4 +57,9 @@ public class TestLucene92HnswVectorsFormat extends BaseKnnVectorsFormatTestCase public void testSortedIndexBytes() throws Exception { // unimplemented } + + @Override + public void testByteVectorScorerIteration() { + // unimplemented + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index d8cdb1739f7..9726a3b19e8 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -421,6 +421,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { fieldInfo.getVectorDimension(), docsWithField.cardinality(), vectorDataInput, + fieldInfo.getVectorSimilarityFunction(), byteSize); RandomVectorScorerSupplier scorerSupplier = defaultFlatVectorScorer.getRandomVectorScorerSupplier( @@ -437,6 +438,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { fieldInfo.getVectorDimension(), docsWithField.cardinality(), vectorDataInput, + fieldInfo.getVectorSimilarityFunction(), byteSize); RandomVectorScorerSupplier scorerSupplier = defaultFlatVectorScorer.getRandomVectorScorerSupplier( diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java index c4e315a72e2..c74d34fb9ad 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -70,6 +70,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { private final IndexOutput meta, vectorData, vectorIndex; private final int M; private final int beamWidth; + private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private final List> fields = new ArrayList<>(); private boolean finished; @@ -437,7 +438,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { OnHeapHnswGraph graph = null; int[][] vectorIndexNodeOffsets = null; if (docsWithField.cardinality() != 0) { - DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); final RandomVectorScorerSupplier scorerSupplier; switch (fieldInfo.getVectorEncoding()) { case BYTE: @@ -448,7 +448,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { fieldInfo.getVectorDimension(), docsWithField.cardinality(), vectorDataInput, - byteSize)); + byteSize, + defaultFlatVectorScorer, + fieldInfo.getVectorSimilarityFunction())); break; case FLOAT32: scorerSupplier = @@ -458,7 +460,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { fieldInfo.getVectorDimension(), docsWithField.cardinality(), vectorDataInput, - byteSize)); + byteSize, + defaultFlatVectorScorer, + fieldInfo.getVectorSimilarityFunction())); break; default: throw new IllegalArgumentException( @@ -667,6 +671,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { private final DocsWithFieldSet docsWithField; private final List vectors; private final HnswGraphBuilder hnswGraphBuilder; + private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private int lastDocID = -1; private int node = 0; @@ -697,7 +702,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { this.dim = fieldInfo.getVectorDimension(); this.docsWithField = new DocsWithFieldSet(); vectors = new ArrayList<>(); - DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); RandomVectorScorerSupplier scorerSupplier = switch (fieldInfo.getVectorEncoding()) { case BYTE -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 8ea9b22b35a..b8d2ad5702c 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -38,6 +38,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.BufferedChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.IOContext; @@ -92,7 +93,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { } assert fieldEntries.containsKey(fieldName) == false; fieldEntries.put( - fieldName, new FieldEntry(dimension, vectorDataOffset, vectorDataLength, docIds)); + fieldName, + new FieldEntry( + dimension, + vectorDataOffset, + vectorDataLength, + docIds, + readState.fieldInfos.fieldInfo(fieldName).getVectorSimilarityFunction())); fieldNumber = readInt(in, FIELD_NUMBER); } SimpleTextUtil.checkFooter(in); @@ -275,7 +282,11 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { } private record FieldEntry( - int dimension, long vectorDataOffset, long vectorDataLength, int[] ordToDoc) { + int dimension, + long vectorDataOffset, + long vectorDataLength, + int[] ordToDoc, + VectorSimilarityFunction similarityFunction) { int size() { return ordToDoc.length; } @@ -298,6 +309,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { readAllVectors(); } + private SimpleTextFloatVectorValues(SimpleTextFloatVectorValues other) { + this.entry = other.entry; + this.in = other.in.clone(); + this.values = other.values; + this.curOrd = other.curOrd; + } + @Override public int dimension() { return entry.dimension; @@ -340,6 +358,25 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { return slowAdvance(target); } + @Override + public VectorScorer scorer(float[] target) { + SimpleTextFloatVectorValues simpleTextFloatVectorValues = + new SimpleTextFloatVectorValues(this); + return new VectorScorer() { + @Override + public float score() throws IOException { + return entry + .similarityFunction() + .compare(simpleTextFloatVectorValues.vectorValue(), target); + } + + @Override + public DocIdSetIterator iterator() { + return simpleTextFloatVectorValues; + } + }; + } + private void readAllVectors() throws IOException { for (float[] value : values) { readVector(value); @@ -379,6 +416,15 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { readAllVectors(); } + private SimpleTextByteVectorValues(SimpleTextByteVectorValues other) { + this.entry = other.entry; + this.in = other.in.clone(); + this.values = other.values; + this.binaryValue = new BytesRef(entry.dimension); + this.binaryValue.length = binaryValue.bytes.length; + this.curOrd = other.curOrd; + } + @Override public int dimension() { return entry.dimension; @@ -422,6 +468,24 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { return slowAdvance(target); } + @Override + public VectorScorer scorer(byte[] target) { + SimpleTextByteVectorValues simpleTextByteVectorValues = new SimpleTextByteVectorValues(this); + return new VectorScorer() { + @Override + public float score() throws IOException { + return entry + .similarityFunction() + .compare(simpleTextByteVectorValues.vectorValue(), target); + } + + @Override + public DocIdSetIterator iterator() { + return simpleTextByteVectorValues; + } + }; + } + private void readAllVectors() throws IOException { for (byte[] value : values) { readVector(value); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index d79bbfd9a36..8a9b4816571 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -27,6 +27,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.RamUsageEstimator; @@ -159,6 +160,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter { public int advance(int target) throws IOException { throw new UnsupportedOperationException(); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } /** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */ @@ -216,6 +222,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter { public int advance(int target) throws IOException { throw new UnsupportedOperationException(); } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } } @Override @@ -354,6 +365,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter { public int advance(int target) { throw new UnsupportedOperationException(); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } private static class BufferedByteVectorValues extends ByteVectorValues { @@ -414,5 +430,10 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter { public int advance(int target) { throw new UnsupportedOperationException(); } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 078b0c10e69..8ae86d4c807 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -29,6 +29,7 @@ import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.Accountable; /** Writes vectors to an index. */ @@ -188,7 +189,6 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable { private final List subs; private final DocIDMerger docIdMerger; private final int size; - private int docId; VectorValuesSub current; @@ -239,6 +239,11 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable { public int dimension() { return subs.get(0).values.dimension(); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } static class MergedByteVectorValues extends ByteVectorValues { @@ -296,6 +301,11 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable { public int dimension() { return subs.get(0).values.dimension(); } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 8d98a9cd1c8..b1f0972f4f9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -19,13 +19,18 @@ package org.apache.lucene.codecs.lucene95; import java.io.IOException; import java.nio.ByteBuffer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. */ @@ -39,14 +44,24 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues protected final byte[] binaryValue; protected final ByteBuffer byteBuffer; protected final int byteSize; + protected final VectorSimilarityFunction similarityFunction; + protected final FlatVectorsScorer flatVectorsScorer; - OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) { + OffHeapByteVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction) { this.dimension = dimension; this.size = size; this.slice = slice; this.byteSize = byteSize; byteBuffer = ByteBuffer.allocate(byteSize); binaryValue = byteBuffer.array(); + this.similarityFunction = similarityFunction; + this.flatVectorsScorer = flatVectorsScorer; } @Override @@ -79,6 +94,8 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues } public static OffHeapByteVectorValues load( + VectorSimilarityFunction vectorSimilarityFunction, + FlatVectorsScorer flatVectorsScorer, OrdToDocDISIReaderConfiguration configuration, VectorEncoding vectorEncoding, int dimension, @@ -87,14 +104,26 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues IndexInput vectorData) throws IOException { if (configuration.isEmpty() || vectorEncoding != VectorEncoding.BYTE) { - return new EmptyOffHeapVectorValues(dimension); + return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); } IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); if (configuration.isDense()) { - return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, dimension); + return new DenseOffHeapVectorValues( + dimension, + configuration.size, + bytesSlice, + dimension, + flatVectorsScorer, + vectorSimilarityFunction); } else { return new SparseOffHeapVectorValues( - configuration, vectorData, bytesSlice, dimension, dimension); + configuration, + vectorData, + bytesSlice, + dimension, + dimension, + flatVectorsScorer, + vectorSimilarityFunction); } } @@ -106,8 +135,14 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) { - super(dimension, size, slice, byteSize); + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction vectorSimilarityFunction) { + super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction); } @Override @@ -136,13 +171,32 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues @Override public DenseOffHeapVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); + return new DenseOffHeapVectorValues( + dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer scorer(byte[] query) throws IOException { + DenseOffHeapVectorValues copy = copy(); + RandomVectorScorer scorer = + flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(copy.doc); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues { @@ -157,10 +211,18 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues IndexInput dataIn, IndexInput slice, int dimension, - int byteSize) + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction vectorSimilarityFunction) throws IOException { - super(dimension, configuration.size, slice, byteSize); + super( + dimension, + configuration.size, + slice, + byteSize, + flatVectorsScorer, + vectorSimilarityFunction); this.configuration = configuration; final RandomAccessInput addressesData = dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength); @@ -200,7 +262,13 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( - configuration, dataIn, slice.clone(), dimension, byteSize); + configuration, + dataIn, + slice.clone(), + dimension, + byteSize, + flatVectorsScorer, + similarityFunction); } @Override @@ -225,12 +293,33 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues } }; } + + @Override + public VectorScorer scorer(byte[] query) throws IOException { + SparseOffHeapVectorValues copy = copy(); + RandomVectorScorer scorer = + flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(copy.disi.index()); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues { - public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null, 0); + public EmptyOffHeapVectorValues( + int dimension, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction vectorSimilarityFunction) { + super(dimension, 0, null, 0, flatVectorsScorer, vectorSimilarityFunction); } private int doc = -1; @@ -284,5 +373,10 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer scorer(byte[] query) { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 0aeddaf1536..52753325f6b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -18,13 +18,18 @@ package org.apache.lucene.codecs.lucene95; import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. */ @@ -37,12 +42,22 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues protected final int byteSize; protected int lastOrd = -1; protected final float[] value; + protected final VectorSimilarityFunction similarityFunction; + protected final FlatVectorsScorer flatVectorsScorer; - OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) { + OffHeapFloatVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction) { this.dimension = dimension; this.size = size; this.slice = slice; this.byteSize = byteSize; + this.similarityFunction = similarityFunction; + this.flatVectorsScorer = flatVectorsScorer; value = new float[dimension]; } @@ -73,6 +88,8 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues } public static OffHeapFloatVectorValues load( + VectorSimilarityFunction vectorSimilarityFunction, + FlatVectorsScorer flatVectorsScorer, OrdToDocDISIReaderConfiguration configuration, VectorEncoding vectorEncoding, int dimension, @@ -81,15 +98,27 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues IndexInput vectorData) throws IOException { if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.FLOAT32) { - return new EmptyOffHeapVectorValues(dimension); + return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); } IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); int byteSize = dimension * Float.BYTES; if (configuration.docsWithFieldOffset == -1) { - return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, byteSize); + return new DenseOffHeapVectorValues( + dimension, + configuration.size, + bytesSlice, + byteSize, + flatVectorsScorer, + vectorSimilarityFunction); } else { return new SparseOffHeapVectorValues( - configuration, vectorData, bytesSlice, dimension, byteSize); + configuration, + vectorData, + bytesSlice, + dimension, + byteSize, + flatVectorsScorer, + vectorSimilarityFunction); } } @@ -101,8 +130,14 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) { - super(dimension, size, slice, byteSize); + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction) { + super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction); } @Override @@ -131,13 +166,32 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public DenseOffHeapVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); + return new DenseOffHeapVectorValues( + dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + DenseOffHeapVectorValues copy = copy(); + RandomVectorScorer randomVectorScorer = + flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(copy.doc); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -152,10 +206,12 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues IndexInput dataIn, IndexInput slice, int dimension, - int byteSize) + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction) throws IOException { - super(dimension, configuration.size, slice, byteSize); + super(dimension, configuration.size, slice, byteSize, flatVectorsScorer, similarityFunction); this.configuration = configuration; final RandomAccessInput addressesData = dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength); @@ -195,7 +251,13 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( - configuration, dataIn, slice.clone(), dimension, byteSize); + configuration, + dataIn, + slice.clone(), + dimension, + byteSize, + flatVectorsScorer, + similarityFunction); } @Override @@ -220,12 +282,33 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues } }; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + SparseOffHeapVectorValues copy = copy(); + RandomVectorScorer randomVectorScorer = + flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(copy.disi.index()); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { - public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null, 0); + public EmptyOffHeapVectorValues( + int dimension, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction) { + super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); } private int doc = -1; @@ -256,17 +339,17 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - public int advance(int target) throws IOException { + public int advance(int target) { return doc = NO_MORE_DOCS; } @Override - public EmptyOffHeapVectorValues copy() throws IOException { + public EmptyOffHeapVectorValues copy() { throw new UnsupportedOperationException(); } @Override - public float[] vectorValue(int targetOrd) throws IOException { + public float[] vectorValue(int targetOrd) { throw new UnsupportedOperationException(); } @@ -279,5 +362,10 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer scorer(float[] query) { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java index 0311a3b0cf8..3a52afd5a9f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java @@ -185,6 +185,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { + VectorEncoding.FLOAT32); } return OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.ordToDoc, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -206,6 +208,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { + VectorEncoding.BYTE); } return OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.ordToDoc, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -223,6 +227,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { return vectorScorer.getRandomVectorScorer( fieldEntry.similarityFunction, OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.ordToDoc, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -241,6 +247,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { return vectorScorer.getRandomVectorScorer( fieldEntry.similarityFunction, OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.ordToDoc, fieldEntry.vectorEncoding, fieldEntry.dimension, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index 0f4a4114e70..288a6ae6df9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -314,14 +314,18 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter { fieldInfo.getVectorDimension(), docsWithField.cardinality(), finalVectorDataInput, - fieldInfo.getVectorDimension() * Byte.BYTES)); + fieldInfo.getVectorDimension() * Byte.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction())); case FLOAT32 -> vectorsScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), new OffHeapFloatVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), finalVectorDataInput, - fieldInfo.getVectorDimension() * Float.BYTES)); + fieldInfo.getVectorDimension() * Float.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction())); }; return new FlatCloseableRandomVectorScorerSupplier( () -> { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 8aaa2cca7b5..f0ba77854c6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -36,6 +36,7 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; @@ -164,7 +165,24 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { - return rawVectorsReader.getFloatVectorValues(field); + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + return null; + } + final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field); + OffHeapQuantizedByteVectorValues quantizedByteVectorValues = + OffHeapQuantizedByteVectorValues.load( + fieldEntry.ordToDoc, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.scalarQuantizer, + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.compress, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + quantizedVectorData); + return new QuantizedVectorValues(rawVectorValues, quantizedByteVectorValues); } @Override @@ -227,6 +245,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade fieldEntry.dimension, fieldEntry.size, fieldEntry.scalarQuantizer, + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.compress, fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength, @@ -282,6 +302,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade fieldEntry.dimension, fieldEntry.size, fieldEntry.scalarQuantizer, + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.compress, fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength, @@ -369,4 +391,56 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade return SHALLOW_SIZE + RamUsageEstimator.sizeOf(ordToDoc); } } + + private static final class QuantizedVectorValues extends FloatVectorValues { + private final FloatVectorValues rawVectorValues; + private final OffHeapQuantizedByteVectorValues quantizedVectorValues; + + QuantizedVectorValues( + FloatVectorValues rawVectorValues, OffHeapQuantizedByteVectorValues quantizedVectorValues) { + this.rawVectorValues = rawVectorValues; + this.quantizedVectorValues = quantizedVectorValues; + } + + @Override + public int dimension() { + return rawVectorValues.dimension(); + } + + @Override + public int size() { + return rawVectorValues.size(); + } + + @Override + public float[] vectorValue() throws IOException { + return rawVectorValues.vectorValue(); + } + + @Override + public int docID() { + return rawVectorValues.docID(); + } + + @Override + public int nextDoc() throws IOException { + int rawDocId = rawVectorValues.nextDoc(); + int quantizedDocId = quantizedVectorValues.nextDoc(); + assert rawDocId == quantizedDocId; + return quantizedDocId; + } + + @Override + public int advance(int target) throws IOException { + int rawDocId = rawVectorValues.advance(target); + int quantizedDocId = quantizedVectorValues.advance(target); + assert rawDocId == quantizedDocId; + return quantizedDocId; + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.vectorScorer(query); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index c4067b7fd78..278e60e0976 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -49,6 +49,7 @@ import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; @@ -526,6 +527,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite docsWithField.cardinality(), mergedQuantizationState, compress, + fieldInfo.getVectorSimilarityFunction(), + vectorsScorer, quantizationDataInput))); } finally { if (success == false) { @@ -890,6 +893,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite curDoc = target; return docID(); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } static class QuantizedByteVectorValueSub extends DocIDMerger.Sub { @@ -1013,6 +1021,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite public float getScoreCorrectionConstant() throws IOException { return current.values.getScoreCorrectionConstant(); } + + @Override + public VectorScorer vectorScorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } } static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { @@ -1082,6 +1095,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite return doc; } + @Override + public VectorScorer vectorScorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + private void quantize() throws IOException { if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); @@ -1182,5 +1200,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite public int advance(int target) throws IOException { return in.advance(target); } + + @Override + public VectorScorer vectorScorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 9659eb13187..5fee4b1cb08 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -19,10 +19,15 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; import java.nio.ByteBuffer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; @@ -39,6 +44,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect protected final int size; protected final int numBytes; protected final ScalarQuantizer scalarQuantizer; + protected final VectorSimilarityFunction similarityFunction; + protected final FlatVectorsScorer vectorsScorer; protected final boolean compress; protected final IndexInput slice; @@ -85,6 +92,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect int dimension, int size, ScalarQuantizer scalarQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, boolean compress, IndexInput slice) { this.dimension = dimension; @@ -100,6 +109,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect this.byteSize = this.numBytes + Float.BYTES; byteBuffer = ByteBuffer.allocate(dimension); binaryValue = byteBuffer.array(); + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; } @Override @@ -160,22 +171,39 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect int dimension, int size, ScalarQuantizer scalarQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, boolean compress, long quantizedVectorDataOffset, long quantizedVectorDataLength, IndexInput vectorData) throws IOException { if (configuration.isEmpty()) { - return new EmptyOffHeapVectorValues(dimension); + return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); } IndexInput bytesSlice = vectorData.slice( "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); if (configuration.isDense()) { - return new DenseOffHeapVectorValues(dimension, size, scalarQuantizer, compress, bytesSlice); + return new DenseOffHeapVectorValues( + dimension, + size, + scalarQuantizer, + compress, + similarityFunction, + vectorsScorer, + bytesSlice); } else { return new SparseOffHeapVectorValues( - configuration, dimension, size, scalarQuantizer, compress, vectorData, bytesSlice); + configuration, + dimension, + size, + scalarQuantizer, + compress, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice); } } @@ -192,8 +220,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect int size, ScalarQuantizer scalarQuantizer, boolean compress, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, IndexInput slice) { - super(dimension, size, scalarQuantizer, compress, slice); + super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice); } @Override @@ -223,13 +253,37 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( - dimension, size, scalarQuantizer, compress, slice.clone()); + dimension, + size, + scalarQuantizer, + compress, + similarityFunction, + vectorsScorer, + slice.clone()); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer vectorScorer(float[] target) throws IOException { + DenseOffHeapVectorValues copy = copy(); + RandomVectorScorer vectorScorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return vectorScorer.score(copy.doc); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { @@ -246,9 +300,11 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect ScalarQuantizer scalarQuantizer, boolean compress, IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, IndexInput slice) throws IOException { - super(dimension, size, scalarQuantizer, compress, slice); + super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice); this.configuration = configuration; this.dataIn = dataIn; this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); @@ -279,7 +335,15 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( - configuration, dimension, size, scalarQuantizer, compress, dataIn, slice.clone()); + configuration, + dimension, + size, + scalarQuantizer, + compress, + dataIn, + similarityFunction, + vectorsScorer, + slice.clone()); } @Override @@ -304,12 +368,40 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect } }; } + + @Override + public VectorScorer vectorScorer(float[] target) throws IOException { + SparseOffHeapVectorValues copy = copy(); + RandomVectorScorer vectorScorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return vectorScorer.score(copy.disi.index()); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { - public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, new ScalarQuantizer(-1, 1, (byte) 7), false, null); + public EmptyOffHeapVectorValues( + int dimension, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer) { + super( + dimension, + 0, + new ScalarQuantizer(-1, 1, (byte) 7), + similarityFunction, + vectorsScorer, + false, + null); } private int doc = -1; @@ -363,5 +455,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer vectorScorer(float[] target) { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index 4792a7d8474..d04d52b0dcf 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -19,6 +19,7 @@ package org.apache.lucene.index; import java.io.IOException; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; /** * This class provides access to per-document floating point vector values indexed as {@link @@ -75,4 +76,14 @@ public abstract class ByteVectorValues extends DocIdSetIterator { + ")"); } } + + /** + * Return a {@link VectorScorer} for the given query vector. The iterator for the scorer is not + * the same instance as the iterator for this {@link ByteVectorValues}. It is a copy, and + * iteration over the scorer will not affect the iteration of this {@link ByteVectorValues}. + * + * @param query the query vector + * @return a {@link VectorScorer} instance + */ + public abstract VectorScorer scorer(byte[] query) throws IOException; } diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index 84f9868f6ef..ca2cb1a27d4 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.FilterLeafReader.FilterTerms; import org.apache.lucene.index.FilterLeafReader.FilterTermsEnum; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.automaton.CompiledAutomaton; @@ -476,6 +477,11 @@ public class ExitableDirectoryReader extends FilterDirectoryReader { return vectorValues.size(); } + @Override + public VectorScorer scorer(float[] target) throws IOException { + return vectorValues.scorer(target); + } + /** * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or * if {@link Thread#interrupted()} returns true. @@ -543,6 +549,11 @@ public class ExitableDirectoryReader extends FilterDirectoryReader { return vectorValues.vectorValue(); } + @Override + public VectorScorer scorer(byte[] target) throws IOException { + return vectorValues.scorer(target); + } + /** * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or * if {@link Thread#interrupted()} returns true. diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 61f0157ddfd..9a5bb31b0c6 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -19,6 +19,7 @@ package org.apache.lucene.index; import java.io.IOException; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; /** * This class provides access to per-document floating point vector values indexed as {@link @@ -75,4 +76,15 @@ public abstract class FloatVectorValues extends DocIdSetIterator { + ")"); } } + + /** + * Return a {@link VectorScorer} for the given query vector and the current {@link + * FloatVectorValues}. The iterator for the scorer is not the same instance as the iterator for + * this {@link FloatVectorValues}. It is a copy, and iteration over the scorer will not affect the + * iteration of this {@link FloatVectorValues}. + * + * @param query the query vector + * @return a {@link VectorScorer} instance + */ + public abstract VectorScorer scorer(float[] query) throws IOException; } diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index d67d7daa963..f01a9bb966b 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -36,6 +36,7 @@ import org.apache.lucene.index.MultiDocValues.MultiSortedDocValues; import org.apache.lucene.index.MultiDocValues.MultiSortedSetDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; @@ -883,6 +884,11 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader { public int advance(int target) throws IOException { return mergedIterator.advance(target); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } }; } @@ -937,6 +943,11 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader { public int advance(int target) throws IOException { return mergedIterator.advance(target); } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } }; } diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index b6bc6bd5234..ae943735980 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -35,6 +35,7 @@ import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IOSupplier; @@ -266,6 +267,11 @@ public final class SortingCodecReader extends FilterCodecReader { } return docId = docsWithField.nextSetBit(target); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } private static class SortingByteVectorValues extends ByteVectorValues { @@ -320,6 +326,11 @@ public final class SortingCodecReader extends FilterCodecReader { } return docId = docsWithField.nextSetBit(target); } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } } /** diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 1dec9da3042..7a4ac9d3f40 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -205,17 +205,17 @@ abstract class AbstractKnnVectorQuery extends Query { HitQueue queue = new HitQueue(queueSize, true); TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; ScoreDoc topDoc = queue.top(); + DocIdSetIterator vectorIterator = vectorScorer.iterator(); + DocIdSetIterator conjunction = + ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of()); int doc; - while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { // Mark results as partial if timeout is met if (queryTimeout != null && queryTimeout.shouldExit()) { relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; break; } - - boolean advanced = vectorScorer.advanceExact(doc); - assert advanced; - + assert vectorIterator.docID() == doc; float score = vectorScorer.score(); if (score > topDoc.score) { topDoc.score = score; diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java index 4aea90cee38..7bd21d99b30 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java @@ -19,6 +19,7 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.Arrays; import java.util.Comparator; +import java.util.List; import java.util.Objects; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -83,7 +84,10 @@ abstract class AbstractVectorSimilarityQuery extends Query { VectorScorer scorer = createVectorScorer(context); if (scorer == null) { return Explanation.noMatch("Not indexed as the correct vector field"); - } else if (scorer.advanceExact(doc)) { + } + DocIdSetIterator iterator = scorer.iterator(); + int docId = iterator.advance(doc); + if (docId == doc) { float score = scorer.score(); if (score >= resultSimilarity) { return Explanation.match(boost * score, "Score above threshold"); @@ -256,15 +260,15 @@ abstract class AbstractVectorSimilarityQuery extends Query { DocIdSetIterator acceptDocs, float threshold) { float[] cachedScore = new float[1]; + DocIdSetIterator vectorIterator = scorer.iterator(); + DocIdSetIterator conjunction = + ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptDocs), List.of()); DocIdSetIterator iterator = - new FilteredDocIdSetIterator(acceptDocs) { + new FilteredDocIdSetIterator(conjunction) { @Override protected boolean match(int doc) throws IOException { // Advance the scorer - if (!scorer.advanceExact(doc)) { - return false; - } - + assert doc == vectorIterator.docID(); // Compute the dot product float score = scorer.score(); cachedScore[0] = score * boost; diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java index e410ad06343..bd2190121ab 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java @@ -21,9 +21,8 @@ import java.util.Arrays; import java.util.Locale; import java.util.Objects; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.Bits; /** @@ -98,12 +97,11 @@ public class ByteVectorSimilarityQuery extends AbstractVectorSimilarityQuery { @Override VectorScorer createVectorScorer(LeafReaderContext context) throws IOException { - @SuppressWarnings("resource") - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorEncoding() != VectorEncoding.BYTE) { + ByteVectorValues vectorValues = context.reader().getByteVectorValues(field); + if (vectorValues == null) { return null; } - return VectorScorer.create(context, fi, target); + return vectorValues.scorer(target); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java index 89d029ed140..b90f12bb4d3 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java @@ -22,7 +22,6 @@ import java.util.Arrays; import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorSimilarityFunction; /** * A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector @@ -37,26 +36,13 @@ class ByteVectorSimilarityValuesSource extends VectorSimilarityValuesSource { } @Override - public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { + public VectorScorer getScorer(LeafReaderContext ctx) throws IOException { final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName); if (vectorValues == null) { ByteVectorValues.checkField(ctx.reader(), fieldName); - return DoubleValues.EMPTY; + return null; } - VectorSimilarityFunction function = - ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction(); - return new DoubleValues() { - @Override - public double doubleValue() throws IOException { - return function.compare(queryVector, vectorValues.vectorValue()); - } - - @Override - public boolean advanceExact(int doc) throws IOException { - return doc >= vectorValues.docID() - && (vectorValues.docID() == doc || vectorValues.advance(doc) == doc); - } - }; + return vectorValues.scorer(queryVector); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java index 44d06c163c0..3dc92482a77 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java @@ -21,9 +21,8 @@ import java.util.Arrays; import java.util.Locale; import java.util.Objects; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; @@ -100,11 +99,11 @@ public class FloatVectorSimilarityQuery extends AbstractVectorSimilarityQuery { @Override VectorScorer createVectorScorer(LeafReaderContext context) throws IOException { @SuppressWarnings("resource") - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorEncoding() != VectorEncoding.FLOAT32) { + FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field); + if (vectorValues == null) { return null; } - return VectorScorer.create(context, fi, target); + return vectorValues.scorer(target); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java index 8198467fc50..3bf4c0a1887 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java @@ -22,7 +22,6 @@ import java.util.Arrays; import java.util.Objects; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorSimilarityFunction; /** * A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector @@ -44,22 +43,32 @@ class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource { FloatVectorValues.checkField(ctx.reader(), fieldName); return DoubleValues.EMPTY; } - VectorSimilarityFunction function = - ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction(); return new DoubleValues() { + private final VectorScorer scorer = vectorValues.scorer(queryVector); + private final DocIdSetIterator iterator = scorer.iterator(); + @Override public double doubleValue() throws IOException { - return function.compare(queryVector, vectorValues.vectorValue()); + return scorer.score(); } @Override public boolean advanceExact(int doc) throws IOException { - return doc >= vectorValues.docID() - && (vectorValues.docID() == doc || vectorValues.advance(doc) == doc); + return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc); } }; } + @Override + public VectorScorer getScorer(LeafReaderContext ctx) throws IOException { + final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName); + if (vectorValues == null) { + FloatVectorValues.checkField(ctx.reader(), fieldName); + return null; + } + return vectorValues.scorer(queryVector); + } + @Override public int hashCode() { return Objects.hash(fieldName, Arrays.hashCode(queryVector)); diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index ba94a243e07..e3d733e516f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -98,7 +98,12 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { @Override VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException { - return VectorScorer.create(context, fi, target); + ByteVectorValues vectorValues = context.reader().getByteVectorValues(field); + if (vectorValues == null) { + ByteVectorValues.checkField(context.reader(), field); + return null; + } + return vectorValues.scorer(target); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index e6e38192e74..91cf4474e1c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -99,7 +99,12 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { @Override VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException { - return VectorScorer.create(context, fi, target); + FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field); + if (vectorValues == null) { + FloatVectorValues.checkField(context.reader(), field); + return null; + } + return vectorValues.scorer(target); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java b/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java index 249c7e88d87..c9b8362ae39 100644 --- a/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java @@ -17,117 +17,25 @@ package org.apache.lucene.search; import java.io.IOException; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorSimilarityFunction; /** * Computes the similarity score between a given query vector and different document vectors. This - * is primarily used by {@link KnnFloatVectorQuery} to run an exact, exhaustive search over the - * vectors. + * is used for exact searching and scoring + * + * @lucene.experimental */ -abstract class VectorScorer { - protected final VectorSimilarityFunction similarity; +public interface VectorScorer { /** - * Create a new vector scorer instance. + * Compute the score for the current document ID. * - * @param context the reader context - * @param fi the FieldInfo for the field containing document vectors - * @param query the query vector to compute the similarity for + * @return the score for the current document ID + * @throws IOException if an exception occurs during score computation */ - static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query) - throws IOException { - FloatVectorValues values = context.reader().getFloatVectorValues(fi.name); - if (values == null) { - FloatVectorValues.checkField(context.reader(), fi.name); - return null; - } - final VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction(); - return new FloatVectorScorer(values, query, similarity); - } + float score() throws IOException; - static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, byte[] query) - throws IOException { - ByteVectorValues values = context.reader().getByteVectorValues(fi.name); - if (values == null) { - ByteVectorValues.checkField(context.reader(), fi.name); - return null; - } - VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction(); - return new ByteVectorScorer(values, query, similarity); - } - - VectorScorer(VectorSimilarityFunction similarity) { - this.similarity = similarity; - } - - /** Compute the similarity score for the current document. */ - abstract float score() throws IOException; - - abstract boolean advanceExact(int doc) throws IOException; - - private static class ByteVectorScorer extends VectorScorer { - private final byte[] query; - private final ByteVectorValues values; - - protected ByteVectorScorer( - ByteVectorValues values, byte[] query, VectorSimilarityFunction similarity) { - super(similarity); - this.values = values; - this.query = query; - } - - /** - * Advance the instance to the given document ID and return true if there is a value for that - * document. - */ - @Override - public boolean advanceExact(int doc) throws IOException { - int vectorDoc = values.docID(); - if (vectorDoc < doc) { - vectorDoc = values.advance(doc); - } - return vectorDoc == doc; - } - - @Override - public float score() throws IOException { - assert values.docID() != -1 : getClass().getSimpleName() + " is not positioned"; - return similarity.compare(query, values.vectorValue()); - } - } - - private static class FloatVectorScorer extends VectorScorer { - private final float[] query; - private final FloatVectorValues values; - - protected FloatVectorScorer( - FloatVectorValues values, float[] query, VectorSimilarityFunction similarity) { - super(similarity); - this.query = query; - this.values = values; - } - - /** - * Advance the instance to the given document ID and return true if there is a value for that - * document. - */ - @Override - public boolean advanceExact(int doc) throws IOException { - int vectorDoc = values.docID(); - if (vectorDoc < doc) { - vectorDoc = values.advance(doc); - } - return vectorDoc == doc; - } - - @Override - public float score() throws IOException { - assert values.docID() != -1 : getClass().getSimpleName() + " is not positioned"; - return similarity.compare(query, values.vectorValue()); - } - } + /** + * @return a {@link DocIdSetIterator} over the documents. + */ + DocIdSetIterator iterator(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java index 639e225d665..4e53d580cda 100644 --- a/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java @@ -33,8 +33,26 @@ abstract class VectorSimilarityValuesSource extends DoubleValuesSource { } @Override - public abstract DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) - throws IOException; + public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { + VectorScorer scorer = getScorer(ctx); + if (scorer == null) { + return DoubleValues.EMPTY; + } + DocIdSetIterator iterator = scorer.iterator(); + return new DoubleValues() { + @Override + public double doubleValue() throws IOException { + return scorer.score(); + } + + @Override + public boolean advanceExact(int doc) throws IOException { + return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc); + } + }; + } + + protected abstract VectorScorer getScorer(LeafReaderContext ctx) throws IOException; @Override public boolean needsScores() { diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java index fa029b0f5ae..c277abfc634 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java @@ -18,6 +18,8 @@ package org.apache.lucene.util.quantization; import java.io.IOException; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; /** * A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for @@ -25,6 +27,31 @@ import org.apache.lucene.index.ByteVectorValues; * * @lucene.experimental */ -public abstract class QuantizedByteVectorValues extends ByteVectorValues { +public abstract class QuantizedByteVectorValues extends DocIdSetIterator { public abstract float getScoreCorrectionConstant() throws IOException; + + public abstract byte[] vectorValue() throws IOException; + + /** Return the dimension of the vectors */ + public abstract int dimension(); + + /** + * Return the number of vectors for this field. + * + * @return the number of vectors returned by this iterator + */ + public abstract int size(); + + @Override + public final long cost() { + return size(); + } + + /** + * Return a {@link VectorScorer} for the given query vector. + * + * @param query the query vector + * @return a {@link VectorScorer} instance + */ + public abstract VectorScorer vectorScorer(float[] query) throws IOException; } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java index 6097a49151d..5432879a134 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java @@ -26,7 +26,6 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.document.StringField; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorEncoding; @@ -44,24 +43,22 @@ public class TestVectorScorer extends LuceneTestCase { IndexReader reader = DirectoryReader.open(indexStore)) { assert reader.leaves().size() == 1; LeafReaderContext context = reader.leaves().get(0); - FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field"); final VectorScorer vectorScorer; switch (encoding) { case BYTE: - vectorScorer = VectorScorer.create(context, fieldInfo, new byte[] {1, 2}); + vectorScorer = context.reader().getByteVectorValues("field").scorer(new byte[] {1, 2}); break; case FLOAT32: - vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2}); + vectorScorer = context.reader().getFloatVectorValues("field").scorer(new float[] {1, 2}); break; default: throw new IllegalArgumentException("unexpected vector encoding: " + encoding); } + DocIdSetIterator iterator = vectorScorer.iterator(); int numDocs = 0; - for (int i = 0; i < reader.maxDoc(); i++) { - if (vectorScorer.advanceExact(i)) { - numDocs++; - } + while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + numDocs++; } assertEquals(3, numDocs); } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 4770bdf98ab..7b50a4e1ce4 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -72,6 +72,7 @@ import org.apache.lucene.search.SortField; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopKnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -1128,6 +1129,11 @@ abstract class HnswGraphTestCase extends LuceneTestCase { public float[] vectorValue(int ord) { return unitVector2d(ord / (double) size, value); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } /** Returns vectors evenly distributed around the upper unit semicircle. */ @@ -1193,6 +1199,11 @@ abstract class HnswGraphTestCase extends LuceneTestCase { } return bValue; } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } } private static float[] unitVector2d(double piRadians) { diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 9d150373070..d1930569f6c 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.Set; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.tests.util.LuceneTestCase; public class TestScalarQuantizer extends LuceneTestCase { @@ -262,5 +263,10 @@ public class TestScalarQuantizer extends LuceneTestCase { curDoc = target - 1; return nextDoc(); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index 735aa5ac10f..8c6f0f98470 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -20,10 +20,8 @@ import java.io.IOException; import java.util.Arrays; import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.IndexSearcher; @@ -34,6 +32,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; @@ -92,14 +91,10 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery { return NO_RESULTS; } - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - ParentBlockJoinByteVectorScorer vectorScorer = - new ParentBlockJoinByteVectorScorer( - byteVectorValues, - acceptIterator, - parentBitSet, - query, - fi.getVectorSimilarityFunction()); + VectorScorer scorer = byteVectorValues.scorer(query); + DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer vectorScorer = + new DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer( + acceptIterator, parentBitSet, scorer); final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost())); HitQueue queue = new HitQueue(queueSize, true); TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; @@ -177,59 +172,4 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery { result = 31 * result + Arrays.hashCode(query); return result; } - - private static class ParentBlockJoinByteVectorScorer { - private final byte[] query; - private final ByteVectorValues values; - private final VectorSimilarityFunction similarity; - private final DocIdSetIterator acceptedChildrenIterator; - private final BitSet parentBitSet; - private int currentParent = -1; - private int bestChild = -1; - private float currentScore = Float.NEGATIVE_INFINITY; - - protected ParentBlockJoinByteVectorScorer( - ByteVectorValues values, - DocIdSetIterator acceptedChildrenIterator, - BitSet parentBitSet, - byte[] query, - VectorSimilarityFunction similarity) { - this.query = query; - this.values = values; - this.similarity = similarity; - this.acceptedChildrenIterator = acceptedChildrenIterator; - this.parentBitSet = parentBitSet; - } - - public int bestChild() { - return bestChild; - } - - public int nextParent() throws IOException { - int nextChild = acceptedChildrenIterator.docID(); - if (nextChild == -1) { - nextChild = acceptedChildrenIterator.nextDoc(); - } - if (nextChild == DocIdSetIterator.NO_MORE_DOCS) { - currentParent = DocIdSetIterator.NO_MORE_DOCS; - return currentParent; - } - currentScore = Float.NEGATIVE_INFINITY; - currentParent = parentBitSet.nextSetBit(nextChild); - do { - values.advance(nextChild); - float score = similarity.compare(query, values.vectorValue()); - if (score > currentScore) { - bestChild = nextChild; - currentScore = score; - } - } while ((nextChild = acceptedChildrenIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS - && nextChild < currentParent); - return currentParent; - } - - public float score() throws IOException { - return currentScore; - } - } } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index 8a427f467dd..513da619cae 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -19,11 +19,9 @@ package org.apache.lucene.search.join; import java.io.IOException; import java.util.Arrays; import java.util.Objects; -import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.IndexSearcher; @@ -34,6 +32,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; @@ -92,14 +91,9 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery return NO_RESULTS; } - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - DiversifyingChildrenFloatVectorScorer vectorScorer = - new DiversifyingChildrenFloatVectorScorer( - floatVectorValues, - acceptIterator, - parentBitSet, - query, - fi.getVectorSimilarityFunction()); + DiversifyingChildrenVectorScorer vectorScorer = + new DiversifyingChildrenVectorScorer( + acceptIterator, parentBitSet, floatVectorValues.scorer(query)); final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost())); HitQueue queue = new HitQueue(queueSize, true); TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; @@ -178,26 +172,20 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery return result; } - private static class DiversifyingChildrenFloatVectorScorer { - private final float[] query; - private final FloatVectorValues values; - private final VectorSimilarityFunction similarity; + static class DiversifyingChildrenVectorScorer { + private final VectorScorer vectorScorer; + private final DocIdSetIterator vectorIterator; private final DocIdSetIterator acceptedChildrenIterator; private final BitSet parentBitSet; private int currentParent = -1; private int bestChild = -1; private float currentScore = Float.NEGATIVE_INFINITY; - protected DiversifyingChildrenFloatVectorScorer( - FloatVectorValues values, - DocIdSetIterator acceptedChildrenIterator, - BitSet parentBitSet, - float[] query, - VectorSimilarityFunction similarity) { - this.query = query; - this.values = values; - this.similarity = similarity; + protected DiversifyingChildrenVectorScorer( + DocIdSetIterator acceptedChildrenIterator, BitSet parentBitSet, VectorScorer vectorScorer) { this.acceptedChildrenIterator = acceptedChildrenIterator; + this.vectorScorer = vectorScorer; + this.vectorIterator = vectorScorer.iterator(); this.parentBitSet = parentBitSet; } @@ -217,8 +205,8 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery currentScore = Float.NEGATIVE_INFINITY; currentParent = parentBitSet.nextSetBit(nextChild); do { - values.advance(nextChild); - float score = similarity.compare(query, values.vectorValue()); + vectorIterator.advance(nextChild); + float score = vectorScorer.score(); if (score > currentScore) { bestChild = nextChild; currentScore = score; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 4fb5e95247d..75cb1a52e0d 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -48,10 +48,12 @@ import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.tests.util.TestUtil; @@ -718,6 +720,116 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe } } + public void testFloatVectorScorerIteration() throws Exception { + IndexWriterConfig iwc = newIndexWriterConfig(); + if (random().nextBoolean()) { + iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT))); + } + String fieldName = "field"; + try (Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, iwc)) { + int numDoc = atLeast(100); + int dimension = atLeast(10); + if (dimension % 2 != 0) { + dimension++; + } + float[][] values = new float[numDoc][]; + for (int i = 0; i < numDoc; i++) { + if (random().nextInt(7) != 3) { + // usually index a vector value for a doc + values[i] = randomNormalizedVector(dimension); + } + add(iw, fieldName, i, values[i], similarityFunction); + if (random().nextInt(10) == 2) { + iw.deleteDocuments(new Term("id", Integer.toString(random().nextInt(i + 1)))); + } + if (random().nextInt(10) == 3) { + iw.commit(); + } + } + float[] vectorToScore = randomNormalizedVector(dimension); + try (IndexReader reader = DirectoryReader.open(iw)) { + for (LeafReaderContext ctx : reader.leaves()) { + FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName); + if (vectorValues == null) { + continue; + } + VectorScorer scorer = vectorValues.scorer(vectorToScore); + assertNotNull(scorer); + DocIdSetIterator iterator = scorer.iterator(); + assertSame(iterator, scorer.iterator()); + assertNotSame(iterator, scorer); + // verify scorer iteration scores are valid & iteration with vectorValues is consistent + while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) { + float score = scorer.score(); + assertTrue(score >= 0f); + assertEquals(iterator.docID(), vectorValues.docID()); + } + // verify that a new scorer can be obtained after iteration + VectorScorer newScorer = vectorValues.scorer(vectorToScore); + assertNotNull(newScorer); + assertNotSame(scorer, newScorer); + assertNotSame(iterator, newScorer.iterator()); + } + } + } + } + + public void testByteVectorScorerIteration() throws Exception { + IndexWriterConfig iwc = newIndexWriterConfig(); + if (random().nextBoolean()) { + iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT))); + } + String fieldName = "field"; + try (Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, iwc)) { + int numDoc = atLeast(100); + int dimension = atLeast(10); + if (dimension % 2 != 0) { + dimension++; + } + byte[][] values = new byte[numDoc][]; + for (int i = 0; i < numDoc; i++) { + if (random().nextInt(7) != 3) { + // usually index a vector value for a doc + values[i] = randomVector8(dimension); + } + add(iw, fieldName, i, values[i], similarityFunction); + if (random().nextInt(10) == 2) { + iw.deleteDocuments(new Term("id", Integer.toString(random().nextInt(i + 1)))); + } + if (random().nextInt(10) == 3) { + iw.commit(); + } + } + byte[] vectorToScore = randomVector8(dimension); + try (IndexReader reader = DirectoryReader.open(iw)) { + for (LeafReaderContext ctx : reader.leaves()) { + ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName); + if (vectorValues == null) { + continue; + } + VectorScorer scorer = vectorValues.scorer(vectorToScore); + assertNotNull(scorer); + DocIdSetIterator iterator = scorer.iterator(); + assertSame(iterator, scorer.iterator()); + assertNotSame(iterator, scorer); + // verify scorer iteration scores are valid & iteration with vectorValues is consistent + while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) { + float score = scorer.score(); + assertTrue(score >= 0f); + assertEquals(iterator.docID(), vectorValues.docID()); + } + // verify that a new scorer can be obtained after iteration + VectorScorer newScorer = vectorValues.scorer(vectorToScore); + assertNotNull(newScorer); + assertNotSame(scorer, newScorer); + assertNotSame(iterator, newScorer.iterator()); + } + } + } + } + protected VectorSimilarityFunction randomSimilarity() { return VectorSimilarityFunction.values()[ random().nextInt(VectorSimilarityFunction.values().length)];