diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 35f52d1c796..3650c13ab39 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -258,6 +258,9 @@ New Features * GITHUB#13288: Make HNSW and Flat storage vector formats easier to extend with new FlatVectorScorer interface. Add new Hnsw format for binary quantized vectors. (Ben Trent) +* GITHUB#13181: Add new VectorScorer interface to vector value iterators. This allows vector codecs to supply + simpler and more optimized vector scoring when iterating vector values directly. (Ben Trent) + Improvements --------------------- diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index ea63c926ac1..6518b4aeab7 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -34,7 +34,9 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; @@ -272,7 +274,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { throws IOException { IndexInput bytesSlice = vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength); - return new OffHeapFloatVectorValues(fieldEntry.dimension, fieldEntry.ordToDoc, bytesSlice); + return new OffHeapFloatVectorValues( + fieldEntry.dimension, fieldEntry.ordToDoc, 
fieldEntry.similarityFunction, bytesSlice); } private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) { @@ -359,14 +362,20 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { final int byteSize; int lastOrd = -1; final float[] value; + final VectorSimilarityFunction similarityFunction; int ord = -1; int doc = -1; - OffHeapFloatVectorValues(int dimension, int[] ordToDoc, IndexInput dataIn) { + OffHeapFloatVectorValues( + int dimension, + int[] ordToDoc, + VectorSimilarityFunction similarityFunction, + IndexInput dataIn) { this.dimension = dimension; this.ordToDoc = ordToDoc; this.dataIn = dataIn; + this.similarityFunction = similarityFunction; byteSize = Float.BYTES * dimension; value = new float[dimension]; @@ -420,7 +429,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { @Override public OffHeapFloatVectorValues copy() { - return new OffHeapFloatVectorValues(dimension, ordToDoc, dataIn.clone()); + return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone()); } @Override @@ -433,6 +442,22 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { lastOrd = targetOrd; return value; } + + @Override + public VectorScorer scorer(float[] target) { + OffHeapFloatVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.similarityFunction.compare(values.vectorValue(), target); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } /** Read the nearest-neighbors graph from the index input */ diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index f29c04ed10c..19675597c0a 100644 --- 
a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -35,7 +35,9 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; @@ -255,7 +257,11 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { IndexInput bytesSlice = vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength); return new OffHeapFloatVectorValues( - fieldEntry.dimension, fieldEntry.size(), fieldEntry.ordToDoc, bytesSlice); + fieldEntry.dimension, + fieldEntry.size(), + fieldEntry.ordToDoc, + fieldEntry.similarityFunction, + bytesSlice); } private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) { @@ -399,16 +405,23 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { private final IndexInput dataIn; private final int byteSize; private final float[] value; + private final VectorSimilarityFunction similarityFunction; private int ord = -1; private int doc = -1; - OffHeapFloatVectorValues(int dimension, int size, int[] ordToDoc, IndexInput dataIn) { + OffHeapFloatVectorValues( + int dimension, + int size, + int[] ordToDoc, + VectorSimilarityFunction similarityFunction, + IndexInput dataIn) { this.dimension = dimension; this.size = size; this.ordToDoc = ordToDoc; ordToDocOperator = ordToDoc == null ? 
IntUnaryOperator.identity() : (ord) -> ordToDoc[ord]; this.dataIn = dataIn; + this.similarityFunction = similarityFunction; byteSize = Float.BYTES * dimension; value = new float[dimension]; } @@ -468,7 +481,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { @Override public OffHeapFloatVectorValues copy() { - return new OffHeapFloatVectorValues(dimension, size, ordToDoc, dataIn.clone()); + return new OffHeapFloatVectorValues( + dimension, size, ordToDoc, similarityFunction, dataIn.clone()); } @Override @@ -477,6 +491,22 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { dataIn.readFloats(value, 0, value.length); return value; } + + @Override + public VectorScorer scorer(float[] target) { + OffHeapFloatVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.similarityFunction.compare(values.vectorValue(), target); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } /** Read the nearest-neighbors graph from the index input */ diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index ec0cbf7379a..19dc82cc46d 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -20,6 +20,9 @@ package org.apache.lucene.backward_codecs.lucene92; import java.io.IOException; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import 
org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; @@ -36,13 +39,20 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues protected final int byteSize; protected int lastOrd = -1; protected final float[] value; + protected final VectorSimilarityFunction vectorSimilarityFunction; + ; - OffHeapFloatVectorValues(int dimension, int size, IndexInput slice) { + OffHeapFloatVectorValues( + int dimension, + int size, + VectorSimilarityFunction vectorSimilarityFunction, + IndexInput slice) { this.dimension = dimension; this.size = size; this.slice = slice; byteSize = Float.BYTES * dimension; value = new float[dimension]; + this.vectorSimilarityFunction = vectorSimilarityFunction; } @Override @@ -75,9 +85,11 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues vectorData.slice( "vector-data", fieldEntry.vectorDataOffset(), fieldEntry.vectorDataLength()); if (fieldEntry.docsWithFieldOffset() == -1) { - return new DenseOffHeapVectorValues(fieldEntry.dimension(), fieldEntry.size(), bytesSlice); + return new DenseOffHeapVectorValues( + fieldEntry.dimension(), fieldEntry.size(), fieldEntry.similarityFunction(), bytesSlice); } else { - return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice); + return new SparseOffHeapVectorValues( + fieldEntry, vectorData, fieldEntry.similarityFunction(), bytesSlice); } } @@ -85,8 +97,12 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) { - super(dimension, size, slice); + public DenseOffHeapVectorValues( + int dimension, + int size, + VectorSimilarityFunction vectorSimilarityFunction, + IndexInput slice) { + super(dimension, size, vectorSimilarityFunction, slice); } @Override @@ -115,13 +131,29 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public DenseOffHeapVectorValues copy() throws IOException { - return new 
DenseOffHeapVectorValues(dimension, size, slice.clone()); + return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone()); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + DenseOffHeapVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -132,10 +164,13 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues private final Lucene92HnswVectorsReader.FieldEntry fieldEntry; public SparseOffHeapVectorValues( - Lucene92HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice) + Lucene92HnswVectorsReader.FieldEntry fieldEntry, + IndexInput dataIn, + VectorSimilarityFunction vectorSimilarityFunction, + IndexInput slice) throws IOException { - super(fieldEntry.dimension(), fieldEntry.size(), slice); + super(fieldEntry.dimension(), fieldEntry.size(), vectorSimilarityFunction, slice); this.fieldEntry = fieldEntry; final RandomAccessInput addressesData = dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength()); @@ -173,8 +208,9 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - public OffHeapFloatVectorValues copy() throws IOException { - return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone()); + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + fieldEntry, dataIn, vectorSimilarityFunction, slice.clone()); } @Override @@ -199,12 +235,28 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } }; } + + @Override + public VectorScorer scorer(float[] 
query) throws IOException { + SparseOffHeapVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null); + super(dimension, 0, VectorSimilarityFunction.COSINE, null); } private int doc = -1; @@ -258,5 +310,10 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index b961bafabb1..0c909e3839d 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -22,6 +22,9 @@ import java.nio.ByteBuffer; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; @@ -39,12 +42,19 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues protected final byte[] binaryValue; protected final ByteBuffer byteBuffer; protected final int byteSize; + protected final 
VectorSimilarityFunction vectorSimilarityFunction; - OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) { + OffHeapByteVectorValues( + int dimension, + int size, + IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, + int byteSize) { this.dimension = dimension; this.size = size; this.slice = slice; this.byteSize = byteSize; + this.vectorSimilarityFunction = vectorSimilarityFunction; byteBuffer = ByteBuffer.allocate(byteSize); binaryValue = byteBuffer.array(); } @@ -85,9 +95,14 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues int byteSize = fieldEntry.dimension(); if (fieldEntry.docsWithFieldOffset() == -1) { return new DenseOffHeapVectorValues( - fieldEntry.dimension(), fieldEntry.size(), bytesSlice, byteSize); + fieldEntry.dimension(), + fieldEntry.size(), + bytesSlice, + fieldEntry.similarityFunction(), + byteSize); } else { - return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize); + return new SparseOffHeapVectorValues( + fieldEntry, vectorData, bytesSlice, fieldEntry.similarityFunction(), byteSize); } } @@ -95,8 +110,13 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) { - super(dimension, size, slice, byteSize); + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, + int byteSize) { + super(dimension, size, slice, vectorSimilarityFunction, byteSize); } @Override @@ -124,14 +144,31 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues } @Override - public OffHeapByteVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, size, slice.clone(), vectorSimilarityFunction, 
byteSize); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer scorer(byte[] query) throws IOException { + DenseOffHeapVectorValues copy = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return vectorSimilarityFunction.compare(copy.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues { @@ -145,10 +182,11 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, int byteSize) throws IOException { - super(fieldEntry.dimension(), fieldEntry.size(), slice, byteSize); + super(fieldEntry.dimension(), fieldEntry.size(), slice, vectorSimilarityFunction, byteSize); this.fieldEntry = fieldEntry; final RandomAccessInput addressesData = dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength()); @@ -186,8 +224,9 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues } @Override - public OffHeapByteVectorValues copy() throws IOException { - return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize); + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize); } @Override @@ -212,12 +251,28 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues } }; } + + @Override + public VectorScorer scorer(byte[] query) throws IOException { + SparseOffHeapVectorValues copy = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return vectorSimilarityFunction.compare(copy.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } 
+ }; + } } private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues { public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null, 0); + super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0); } private int doc = -1; @@ -271,5 +326,10 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer scorer(byte[] query) { + return null; + } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 95abedf2d87..91f97b8a41f 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -20,6 +20,9 @@ package org.apache.lucene.backward_codecs.lucene94; import java.io.IOException; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; @@ -36,13 +39,20 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues protected final int byteSize; protected int lastOrd = -1; protected final float[] value; + protected final VectorSimilarityFunction vectorSimilarityFunction; - OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) { + OffHeapFloatVectorValues( + int dimension, + int size, + IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, + int byteSize) { this.dimension = dimension; this.size = size; this.slice = 
slice; this.byteSize = byteSize; value = new float[dimension]; + this.vectorSimilarityFunction = vectorSimilarityFunction; } @Override @@ -81,9 +91,14 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues }; if (fieldEntry.docsWithFieldOffset() == -1) { return new DenseOffHeapVectorValues( - fieldEntry.dimension(), fieldEntry.size(), bytesSlice, byteSize); + fieldEntry.dimension(), + fieldEntry.size(), + bytesSlice, + fieldEntry.similarityFunction(), + byteSize); } else { - return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize); + return new SparseOffHeapVectorValues( + fieldEntry, vectorData, bytesSlice, fieldEntry.similarityFunction(), byteSize); } } @@ -91,8 +106,13 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) { - super(dimension, size, slice, byteSize); + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, + int byteSize) { + super(dimension, size, slice, vectorSimilarityFunction, byteSize); } @Override @@ -120,14 +140,31 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - public OffHeapFloatVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, size, slice.clone(), vectorSimilarityFunction, byteSize); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + DenseOffHeapVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + } + + @Override + public 
DocIdSetIterator iterator() { + return values; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -141,10 +178,11 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice, + VectorSimilarityFunction vectorSimilarityFunction, int byteSize) throws IOException { - super(fieldEntry.dimension(), fieldEntry.size(), slice, byteSize); + super(fieldEntry.dimension(), fieldEntry.size(), slice, vectorSimilarityFunction, byteSize); this.fieldEntry = fieldEntry; final RandomAccessInput addressesData = dataIn.randomAccessSlice(fieldEntry.addressesOffset(), fieldEntry.addressesLength()); @@ -182,8 +220,9 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - public OffHeapFloatVectorValues copy() throws IOException { - return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize); + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize); } @Override @@ -208,12 +247,28 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } }; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + SparseOffHeapVectorValues values = this.copy(); + return new VectorScorer() { + @Override + public float score() throws IOException { + return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + } + + @Override + public DocIdSetIterator iterator() { + return values; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null, 0); + super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0); } private int doc = -1; @@ -267,5 +322,10 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues public Bits 
getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java index 72b7cfc82f2..ab51f935fb0 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java @@ -253,6 +253,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements + VectorEncoding.FLOAT32); } return OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + defaultFlatVectorScorer, fieldEntry.ordToDocVectorValues, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -274,6 +276,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements + VectorEncoding.BYTE); } return OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + defaultFlatVectorScorer, fieldEntry.ordToDocVectorValues, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -295,6 +299,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + defaultFlatVectorScorer, fieldEntry.ordToDocVectorValues, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -324,6 +330,8 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + defaultFlatVectorScorer, fieldEntry.ordToDocVectorValues, fieldEntry.vectorEncoding, fieldEntry.dimension, diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java 
b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java index ce07b254849..39828524d26 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java @@ -133,7 +133,10 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter { // build the graph using the temporary vector data Lucene90HnswVectorsReader.OffHeapFloatVectorValues offHeapVectors = new Lucene90HnswVectorsReader.OffHeapFloatVectorValues( - floatVectorValues.dimension(), docIds, vectorDataInput); + floatVectorValues.dimension(), + docIds, + fieldInfo.getVectorSimilarityFunction(), + vectorDataInput); long[] offsets = new long[docIds.length]; long vectorIndexOffset = vectorIndex.getFilePointer(); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java index f3b411cee09..b914acf3fbb 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java @@ -68,4 +68,9 @@ public class TestLucene90HnswVectorsFormat extends BaseKnnVectorsFormatTestCase public void testSortedIndexBytes() throws Exception { // unimplemented } + + @Override + public void testByteVectorScorerIteration() { + // unimplemented + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java index b58b2e21a4f..37b75250381 100644 --- 
a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -138,7 +138,11 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter { // TODO: separate random access vector values from DocIdSetIterator? Lucene91HnswVectorsReader.OffHeapFloatVectorValues offHeapVectors = new Lucene91HnswVectorsReader.OffHeapFloatVectorValues( - floatVectorValues.dimension(), docsWithField.cardinality(), null, vectorDataInput); + floatVectorValues.dimension(), + docsWithField.cardinality(), + null, + fieldInfo.getVectorSimilarityFunction(), + vectorDataInput); Lucene91OnHeapHnswGraph graph = offHeapVectors.size() == 0 ? null diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java index b8cae733722..b27a42700cb 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/TestLucene91HnswVectorsFormat.java @@ -67,4 +67,9 @@ public class TestLucene91HnswVectorsFormat extends BaseKnnVectorsFormatTestCase public void testSortedIndexBytes() throws Exception { // unimplemented } + + @Override + public void testByteVectorScorerIteration() { + // unimplemented + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java index 4dd8f1f3054..caa8fc3da14 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java +++ 
b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -146,7 +146,10 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter { // TODO: separate random access vector values from DocIdSetIterator? OffHeapFloatVectorValues offHeapVectors = new OffHeapFloatVectorValues.DenseOffHeapVectorValues( - floatVectorValues.dimension(), docsWithField.cardinality(), vectorDataInput); + floatVectorValues.dimension(), + docsWithField.cardinality(), + fieldInfo.getVectorSimilarityFunction(), + vectorDataInput); OnHeapHnswGraph graph = offHeapVectors.size() == 0 ? null diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java index 976191e2484..aaee5abe4ad 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/TestLucene92HnswVectorsFormat.java @@ -57,4 +57,9 @@ public class TestLucene92HnswVectorsFormat extends BaseKnnVectorsFormatTestCase public void testSortedIndexBytes() throws Exception { // unimplemented } + + @Override + public void testByteVectorScorerIteration() { + // unimplemented + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index d8cdb1739f7..9726a3b19e8 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -421,6 +421,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { 
fieldInfo.getVectorDimension(), docsWithField.cardinality(), vectorDataInput, + fieldInfo.getVectorSimilarityFunction(), byteSize); RandomVectorScorerSupplier scorerSupplier = defaultFlatVectorScorer.getRandomVectorScorerSupplier( @@ -437,6 +438,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { fieldInfo.getVectorDimension(), docsWithField.cardinality(), vectorDataInput, + fieldInfo.getVectorSimilarityFunction(), byteSize); RandomVectorScorerSupplier scorerSupplier = defaultFlatVectorScorer.getRandomVectorScorerSupplier( diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java index c4e315a72e2..c74d34fb9ad 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -70,6 +70,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { private final IndexOutput meta, vectorData, vectorIndex; private final int M; private final int beamWidth; + private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private final List> fields = new ArrayList<>(); private boolean finished; @@ -437,7 +438,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { OnHeapHnswGraph graph = null; int[][] vectorIndexNodeOffsets = null; if (docsWithField.cardinality() != 0) { - DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); final RandomVectorScorerSupplier scorerSupplier; switch (fieldInfo.getVectorEncoding()) { case BYTE: @@ -448,7 +448,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { fieldInfo.getVectorDimension(), docsWithField.cardinality(), vectorDataInput, - byteSize)); + byteSize, + 
defaultFlatVectorScorer, + fieldInfo.getVectorSimilarityFunction())); break; case FLOAT32: scorerSupplier = @@ -458,7 +460,9 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { fieldInfo.getVectorDimension(), docsWithField.cardinality(), vectorDataInput, - byteSize)); + byteSize, + defaultFlatVectorScorer, + fieldInfo.getVectorSimilarityFunction())); break; default: throw new IllegalArgumentException( @@ -667,6 +671,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { private final DocsWithFieldSet docsWithField; private final List vectors; private final HnswGraphBuilder hnswGraphBuilder; + private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private int lastDocID = -1; private int node = 0; @@ -697,7 +702,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { this.dim = fieldInfo.getVectorDimension(); this.docsWithField = new DocsWithFieldSet(); vectors = new ArrayList<>(); - DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); RandomVectorScorerSupplier scorerSupplier = switch (fieldInfo.getVectorEncoding()) { case BYTE -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 8ea9b22b35a..b8d2ad5702c 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -38,6 +38,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.BufferedChecksumIndexInput; import 
org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.IOContext; @@ -92,7 +93,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { } assert fieldEntries.containsKey(fieldName) == false; fieldEntries.put( - fieldName, new FieldEntry(dimension, vectorDataOffset, vectorDataLength, docIds)); + fieldName, + new FieldEntry( + dimension, + vectorDataOffset, + vectorDataLength, + docIds, + readState.fieldInfos.fieldInfo(fieldName).getVectorSimilarityFunction())); fieldNumber = readInt(in, FIELD_NUMBER); } SimpleTextUtil.checkFooter(in); @@ -275,7 +282,11 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { } private record FieldEntry( - int dimension, long vectorDataOffset, long vectorDataLength, int[] ordToDoc) { + int dimension, + long vectorDataOffset, + long vectorDataLength, + int[] ordToDoc, + VectorSimilarityFunction similarityFunction) { int size() { return ordToDoc.length; } @@ -298,6 +309,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { readAllVectors(); } + private SimpleTextFloatVectorValues(SimpleTextFloatVectorValues other) { + this.entry = other.entry; + this.in = other.in.clone(); + this.values = other.values; + this.curOrd = other.curOrd; + } + @Override public int dimension() { return entry.dimension; @@ -340,6 +358,25 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { return slowAdvance(target); } + @Override + public VectorScorer scorer(float[] target) { + SimpleTextFloatVectorValues simpleTextFloatVectorValues = + new SimpleTextFloatVectorValues(this); + return new VectorScorer() { + @Override + public float score() throws IOException { + return entry + .similarityFunction() + .compare(simpleTextFloatVectorValues.vectorValue(), target); + } + + @Override + public DocIdSetIterator iterator() { + return simpleTextFloatVectorValues; + } + }; + } + private void readAllVectors() throws IOException { for (float[] value : values) { readVector(value); @@ -379,6 
+416,15 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { readAllVectors(); } + private SimpleTextByteVectorValues(SimpleTextByteVectorValues other) { + this.entry = other.entry; + this.in = other.in.clone(); + this.values = other.values; + this.binaryValue = new BytesRef(entry.dimension); + this.binaryValue.length = binaryValue.bytes.length; + this.curOrd = other.curOrd; + } + @Override public int dimension() { return entry.dimension; @@ -422,6 +468,24 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { return slowAdvance(target); } + @Override + public VectorScorer scorer(byte[] target) { + SimpleTextByteVectorValues simpleTextByteVectorValues = new SimpleTextByteVectorValues(this); + return new VectorScorer() { + @Override + public float score() throws IOException { + return entry + .similarityFunction() + .compare(simpleTextByteVectorValues.vectorValue(), target); + } + + @Override + public DocIdSetIterator iterator() { + return simpleTextByteVectorValues; + } + }; + } + private void readAllVectors() throws IOException { for (byte[] value : values) { readVector(value); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index d79bbfd9a36..8a9b4816571 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -27,6 +27,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.RamUsageEstimator; @@ -159,6 +160,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter { public int advance(int target) throws IOException { throw new 
UnsupportedOperationException(); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } /** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */ @@ -216,6 +222,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter { public int advance(int target) throws IOException { throw new UnsupportedOperationException(); } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } } @Override @@ -354,6 +365,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter { public int advance(int target) { throw new UnsupportedOperationException(); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } private static class BufferedByteVectorValues extends ByteVectorValues { @@ -414,5 +430,10 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter { public int advance(int target) { throw new UnsupportedOperationException(); } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 078b0c10e69..8ae86d4c807 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -29,6 +29,7 @@ import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.Accountable; /** Writes vectors to an index. 
*/ @@ -188,7 +189,6 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable { private final List subs; private final DocIDMerger docIdMerger; private final int size; - private int docId; VectorValuesSub current; @@ -239,6 +239,11 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable { public int dimension() { return subs.get(0).values.dimension(); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } static class MergedByteVectorValues extends ByteVectorValues { @@ -296,6 +301,11 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable { public int dimension() { return subs.get(0).values.dimension(); } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 8d98a9cd1c8..b1f0972f4f9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -19,13 +19,18 @@ package org.apache.lucene.codecs.lucene95; import java.io.IOException; import java.nio.ByteBuffer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import 
org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. */ @@ -39,14 +44,24 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues protected final byte[] binaryValue; protected final ByteBuffer byteBuffer; protected final int byteSize; + protected final VectorSimilarityFunction similarityFunction; + protected final FlatVectorsScorer flatVectorsScorer; - OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) { + OffHeapByteVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction) { this.dimension = dimension; this.size = size; this.slice = slice; this.byteSize = byteSize; byteBuffer = ByteBuffer.allocate(byteSize); binaryValue = byteBuffer.array(); + this.similarityFunction = similarityFunction; + this.flatVectorsScorer = flatVectorsScorer; } @Override @@ -79,6 +94,8 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues } public static OffHeapByteVectorValues load( + VectorSimilarityFunction vectorSimilarityFunction, + FlatVectorsScorer flatVectorsScorer, OrdToDocDISIReaderConfiguration configuration, VectorEncoding vectorEncoding, int dimension, @@ -87,14 +104,26 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues IndexInput vectorData) throws IOException { if (configuration.isEmpty() || vectorEncoding != VectorEncoding.BYTE) { - return new EmptyOffHeapVectorValues(dimension); + return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); } IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); if (configuration.isDense()) { - return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, dimension); + return new DenseOffHeapVectorValues( + dimension, + configuration.size, + bytesSlice, + dimension, 
+ flatVectorsScorer, + vectorSimilarityFunction); } else { return new SparseOffHeapVectorValues( - configuration, vectorData, bytesSlice, dimension, dimension); + configuration, + vectorData, + bytesSlice, + dimension, + dimension, + flatVectorsScorer, + vectorSimilarityFunction); } } @@ -106,8 +135,14 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) { - super(dimension, size, slice, byteSize); + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction vectorSimilarityFunction) { + super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction); } @Override @@ -136,13 +171,32 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues @Override public DenseOffHeapVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); + return new DenseOffHeapVectorValues( + dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer scorer(byte[] query) throws IOException { + DenseOffHeapVectorValues copy = copy(); + RandomVectorScorer scorer = + flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(copy.doc); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues { @@ -157,10 +211,18 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues IndexInput dataIn, IndexInput slice, int dimension, - int byteSize) + int byteSize, + FlatVectorsScorer 
flatVectorsScorer, + VectorSimilarityFunction vectorSimilarityFunction) throws IOException { - super(dimension, configuration.size, slice, byteSize); + super( + dimension, + configuration.size, + slice, + byteSize, + flatVectorsScorer, + vectorSimilarityFunction); this.configuration = configuration; final RandomAccessInput addressesData = dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength); @@ -200,7 +262,13 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( - configuration, dataIn, slice.clone(), dimension, byteSize); + configuration, + dataIn, + slice.clone(), + dimension, + byteSize, + flatVectorsScorer, + similarityFunction); } @Override @@ -225,12 +293,33 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues } }; } + + @Override + public VectorScorer scorer(byte[] query) throws IOException { + SparseOffHeapVectorValues copy = copy(); + RandomVectorScorer scorer = + flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(copy.disi.index()); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues { - public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null, 0); + public EmptyOffHeapVectorValues( + int dimension, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction vectorSimilarityFunction) { + super(dimension, 0, null, 0, flatVectorsScorer, vectorSimilarityFunction); } private int doc = -1; @@ -284,5 +373,10 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer scorer(byte[] query) { + throw new 
UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 0aeddaf1536..52753325f6b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -18,13 +18,18 @@ package org.apache.lucene.codecs.lucene95; import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. 
*/ @@ -37,12 +42,22 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues protected final int byteSize; protected int lastOrd = -1; protected final float[] value; + protected final VectorSimilarityFunction similarityFunction; + protected final FlatVectorsScorer flatVectorsScorer; - OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) { + OffHeapFloatVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction) { this.dimension = dimension; this.size = size; this.slice = slice; this.byteSize = byteSize; + this.similarityFunction = similarityFunction; + this.flatVectorsScorer = flatVectorsScorer; value = new float[dimension]; } @@ -73,6 +88,8 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues } public static OffHeapFloatVectorValues load( + VectorSimilarityFunction vectorSimilarityFunction, + FlatVectorsScorer flatVectorsScorer, OrdToDocDISIReaderConfiguration configuration, VectorEncoding vectorEncoding, int dimension, @@ -81,15 +98,27 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues IndexInput vectorData) throws IOException { if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.FLOAT32) { - return new EmptyOffHeapVectorValues(dimension); + return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); } IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); int byteSize = dimension * Float.BYTES; if (configuration.docsWithFieldOffset == -1) { - return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, byteSize); + return new DenseOffHeapVectorValues( + dimension, + configuration.size, + bytesSlice, + byteSize, + flatVectorsScorer, + vectorSimilarityFunction); } else { return new SparseOffHeapVectorValues( - configuration, vectorData, bytesSlice, 
dimension, byteSize); + configuration, + vectorData, + bytesSlice, + dimension, + byteSize, + flatVectorsScorer, + vectorSimilarityFunction); } } @@ -101,8 +130,14 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) { - super(dimension, size, slice, byteSize); + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction) { + super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction); } @Override @@ -131,13 +166,32 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public DenseOffHeapVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); + return new DenseOffHeapVectorValues( + dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + DenseOffHeapVectorValues copy = copy(); + RandomVectorScorer randomVectorScorer = + flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(copy.doc); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -152,10 +206,12 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues IndexInput dataIn, IndexInput slice, int dimension, - int byteSize) + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction) throws IOException { - super(dimension, configuration.size, slice, 
byteSize); + super(dimension, configuration.size, slice, byteSize, flatVectorsScorer, similarityFunction); this.configuration = configuration; final RandomAccessInput addressesData = dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength); @@ -195,7 +251,13 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( - configuration, dataIn, slice.clone(), dimension, byteSize); + configuration, + dataIn, + slice.clone(), + dimension, + byteSize, + flatVectorsScorer, + similarityFunction); } @Override @@ -220,12 +282,33 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues } }; } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + SparseOffHeapVectorValues copy = copy(); + RandomVectorScorer randomVectorScorer = + flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(copy.disi.index()); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { - public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null, 0); + public EmptyOffHeapVectorValues( + int dimension, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction) { + super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); } private int doc = -1; @@ -256,17 +339,17 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - public int advance(int target) throws IOException { + public int advance(int target) { return doc = NO_MORE_DOCS; } @Override - public EmptyOffHeapVectorValues copy() throws IOException { + public EmptyOffHeapVectorValues copy() { throw new 
UnsupportedOperationException(); } @Override - public float[] vectorValue(int targetOrd) throws IOException { + public float[] vectorValue(int targetOrd) { throw new UnsupportedOperationException(); } @@ -279,5 +362,10 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer scorer(float[] query) { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java index 0311a3b0cf8..3a52afd5a9f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java @@ -185,6 +185,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { + VectorEncoding.FLOAT32); } return OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.ordToDoc, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -206,6 +208,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { + VectorEncoding.BYTE); } return OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.ordToDoc, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -223,6 +227,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { return vectorScorer.getRandomVectorScorer( fieldEntry.similarityFunction, OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.ordToDoc, fieldEntry.vectorEncoding, fieldEntry.dimension, @@ -241,6 +247,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { return vectorScorer.getRandomVectorScorer( fieldEntry.similarityFunction, OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, 
fieldEntry.ordToDoc, fieldEntry.vectorEncoding, fieldEntry.dimension, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index 0f4a4114e70..288a6ae6df9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -314,14 +314,18 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter { fieldInfo.getVectorDimension(), docsWithField.cardinality(), finalVectorDataInput, - fieldInfo.getVectorDimension() * Byte.BYTES)); + fieldInfo.getVectorDimension() * Byte.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction())); case FLOAT32 -> vectorsScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), new OffHeapFloatVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), finalVectorDataInput, - fieldInfo.getVectorDimension() * Float.BYTES)); + fieldInfo.getVectorDimension() * Float.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction())); }; return new FlatCloseableRandomVectorScorerSupplier( () -> { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 8aaa2cca7b5..f0ba77854c6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -36,6 +36,7 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.VectorScorer; 
import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; @@ -164,7 +165,24 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { - return rawVectorsReader.getFloatVectorValues(field); + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + return null; + } + final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field); + OffHeapQuantizedByteVectorValues quantizedByteVectorValues = + OffHeapQuantizedByteVectorValues.load( + fieldEntry.ordToDoc, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.scalarQuantizer, + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.compress, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + quantizedVectorData); + return new QuantizedVectorValues(rawVectorValues, quantizedByteVectorValues); } @Override @@ -227,6 +245,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade fieldEntry.dimension, fieldEntry.size, fieldEntry.scalarQuantizer, + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.compress, fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength, @@ -282,6 +302,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade fieldEntry.dimension, fieldEntry.size, fieldEntry.scalarQuantizer, + fieldEntry.similarityFunction, + vectorScorer, fieldEntry.compress, fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength, @@ -369,4 +391,56 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade return SHALLOW_SIZE + RamUsageEstimator.sizeOf(ordToDoc); } } + + private static final class QuantizedVectorValues extends FloatVectorValues { + private final FloatVectorValues rawVectorValues; + private final 
OffHeapQuantizedByteVectorValues quantizedVectorValues; + + QuantizedVectorValues( + FloatVectorValues rawVectorValues, OffHeapQuantizedByteVectorValues quantizedVectorValues) { + this.rawVectorValues = rawVectorValues; + this.quantizedVectorValues = quantizedVectorValues; + } + + @Override + public int dimension() { + return rawVectorValues.dimension(); + } + + @Override + public int size() { + return rawVectorValues.size(); + } + + @Override + public float[] vectorValue() throws IOException { + return rawVectorValues.vectorValue(); + } + + @Override + public int docID() { + return rawVectorValues.docID(); + } + + @Override + public int nextDoc() throws IOException { + int rawDocId = rawVectorValues.nextDoc(); + int quantizedDocId = quantizedVectorValues.nextDoc(); + assert rawDocId == quantizedDocId; + return quantizedDocId; + } + + @Override + public int advance(int target) throws IOException { + int rawDocId = rawVectorValues.advance(target); + int quantizedDocId = quantizedVectorValues.advance(target); + assert rawDocId == quantizedDocId; + return quantizedDocId; + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.vectorScorer(query); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index c4067b7fd78..278e60e0976 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -49,6 +49,7 @@ import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import 
org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; @@ -526,6 +527,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite docsWithField.cardinality(), mergedQuantizationState, compress, + fieldInfo.getVectorSimilarityFunction(), + vectorsScorer, quantizationDataInput))); } finally { if (success == false) { @@ -890,6 +893,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite curDoc = target; return docID(); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } static class QuantizedByteVectorValueSub extends DocIDMerger.Sub { @@ -1013,6 +1021,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite public float getScoreCorrectionConstant() throws IOException { return current.values.getScoreCorrectionConstant(); } + + @Override + public VectorScorer vectorScorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } } static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { @@ -1082,6 +1095,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite return doc; } + @Override + public VectorScorer vectorScorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + private void quantize() throws IOException { if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); @@ -1182,5 +1200,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite public int advance(int target) throws IOException { return in.advance(target); } + + @Override + public VectorScorer vectorScorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 9659eb13187..5fee4b1cb08 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -19,10 +19,15 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; import java.nio.ByteBuffer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; @@ -39,6 +44,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect protected final int size; protected final int numBytes; protected final ScalarQuantizer scalarQuantizer; + protected final VectorSimilarityFunction similarityFunction; + protected final FlatVectorsScorer vectorsScorer; protected final boolean compress; protected final IndexInput slice; @@ -85,6 +92,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect int dimension, int size, ScalarQuantizer scalarQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, boolean compress, IndexInput slice) { this.dimension = dimension; @@ -100,6 +109,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect this.byteSize = this.numBytes + Float.BYTES; byteBuffer = 
ByteBuffer.allocate(dimension); binaryValue = byteBuffer.array(); + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; } @Override @@ -160,22 +171,39 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect int dimension, int size, ScalarQuantizer scalarQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, boolean compress, long quantizedVectorDataOffset, long quantizedVectorDataLength, IndexInput vectorData) throws IOException { if (configuration.isEmpty()) { - return new EmptyOffHeapVectorValues(dimension); + return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); } IndexInput bytesSlice = vectorData.slice( "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); if (configuration.isDense()) { - return new DenseOffHeapVectorValues(dimension, size, scalarQuantizer, compress, bytesSlice); + return new DenseOffHeapVectorValues( + dimension, + size, + scalarQuantizer, + compress, + similarityFunction, + vectorsScorer, + bytesSlice); } else { return new SparseOffHeapVectorValues( - configuration, dimension, size, scalarQuantizer, compress, vectorData, bytesSlice); + configuration, + dimension, + size, + scalarQuantizer, + compress, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice); } } @@ -192,8 +220,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect int size, ScalarQuantizer scalarQuantizer, boolean compress, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, IndexInput slice) { - super(dimension, size, scalarQuantizer, compress, slice); + super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice); } @Override @@ -223,13 +253,37 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect @Override public DenseOffHeapVectorValues copy() throws IOException { return new 
DenseOffHeapVectorValues( - dimension, size, scalarQuantizer, compress, slice.clone()); + dimension, + size, + scalarQuantizer, + compress, + similarityFunction, + vectorsScorer, + slice.clone()); } @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + @Override + public VectorScorer vectorScorer(float[] target) throws IOException { + DenseOffHeapVectorValues copy = copy(); + RandomVectorScorer vectorScorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return vectorScorer.score(copy.doc); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class SparseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { @@ -246,9 +300,11 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect ScalarQuantizer scalarQuantizer, boolean compress, IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, IndexInput slice) throws IOException { - super(dimension, size, scalarQuantizer, compress, slice); + super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice); this.configuration = configuration; this.dataIn = dataIn; this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); @@ -279,7 +335,15 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( - configuration, dimension, size, scalarQuantizer, compress, dataIn, slice.clone()); + configuration, + dimension, + size, + scalarQuantizer, + compress, + dataIn, + similarityFunction, + vectorsScorer, + slice.clone()); } @Override @@ -304,12 +368,40 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect } }; } + + @Override + public VectorScorer vectorScorer(float[] 
target) throws IOException { + SparseOffHeapVectorValues copy = copy(); + RandomVectorScorer vectorScorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return vectorScorer.score(copy.disi.index()); + } + + @Override + public DocIdSetIterator iterator() { + return copy; + } + }; + } } private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { - public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, new ScalarQuantizer(-1, 1, (byte) 7), false, null); + public EmptyOffHeapVectorValues( + int dimension, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer) { + super( + dimension, + 0, + new ScalarQuantizer(-1, 1, (byte) 7), + similarityFunction, + vectorsScorer, + false, + null); } private int doc = -1; @@ -363,5 +455,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect public Bits getAcceptOrds(Bits acceptDocs) { return null; } + + @Override + public VectorScorer vectorScorer(float[] target) { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index 4792a7d8474..d04d52b0dcf 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -19,6 +19,7 @@ package org.apache.lucene.index; import java.io.IOException; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; /** * This class provides access to per-document floating point vector values indexed as {@link @@ -75,4 +76,14 @@ public abstract class ByteVectorValues extends DocIdSetIterator { + ")"); } } + + /** + * Return a {@link VectorScorer} for the given query 
vector. The iterator for the scorer is not + * the same instance as the iterator for this {@link ByteVectorValues}. It is a copy, and + * iteration over the scorer will not affect the iteration of this {@link ByteVectorValues}. + * + * @param query the query vector + * @return a {@link VectorScorer} instance + */ + public abstract VectorScorer scorer(byte[] query) throws IOException; } diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index 84f9868f6ef..ca2cb1a27d4 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.FilterLeafReader.FilterTerms; import org.apache.lucene.index.FilterLeafReader.FilterTermsEnum; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.automaton.CompiledAutomaton; @@ -476,6 +477,11 @@ public class ExitableDirectoryReader extends FilterDirectoryReader { return vectorValues.size(); } + @Override + public VectorScorer scorer(float[] target) throws IOException { + return vectorValues.scorer(target); + } + /** * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or * if {@link Thread#interrupted()} returns true. @@ -543,6 +549,11 @@ public class ExitableDirectoryReader extends FilterDirectoryReader { return vectorValues.vectorValue(); } + @Override + public VectorScorer scorer(byte[] target) throws IOException { + return vectorValues.scorer(target); + } + /** * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or * if {@link Thread#interrupted()} returns true. 
diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 61f0157ddfd..9a5bb31b0c6 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -19,6 +19,7 @@ package org.apache.lucene.index; import java.io.IOException; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; /** * This class provides access to per-document floating point vector values indexed as {@link @@ -75,4 +76,15 @@ public abstract class FloatVectorValues extends DocIdSetIterator { + ")"); } } + + /** + * Return a {@link VectorScorer} for the given query vector and the current {@link + * FloatVectorValues}. The iterator for the scorer is not the same instance as the iterator for + * this {@link FloatVectorValues}. It is a copy, and iteration over the scorer will not affect the + * iteration of this {@link FloatVectorValues}. 
+ * + * @param query the query vector + * @return a {@link VectorScorer} instance + */ + public abstract VectorScorer scorer(float[] query) throws IOException; } diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index d67d7daa963..f01a9bb966b 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -36,6 +36,7 @@ import org.apache.lucene.index.MultiDocValues.MultiSortedDocValues; import org.apache.lucene.index.MultiDocValues.MultiSortedSetDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; @@ -883,6 +884,11 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader { public int advance(int target) throws IOException { return mergedIterator.advance(target); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } }; } @@ -937,6 +943,11 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader { public int advance(int target) throws IOException { return mergedIterator.advance(target); } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } }; } diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index b6bc6bd5234..ae943735980 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -35,6 +35,7 @@ import org.apache.lucene.codecs.TermVectorsReader; import 
org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IOSupplier; @@ -266,6 +267,11 @@ public final class SortingCodecReader extends FilterCodecReader { } return docId = docsWithField.nextSetBit(target); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } private static class SortingByteVectorValues extends ByteVectorValues { @@ -320,6 +326,11 @@ public final class SortingCodecReader extends FilterCodecReader { } return docId = docsWithField.nextSetBit(target); } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } } /** diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 1dec9da3042..7a4ac9d3f40 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -205,17 +205,17 @@ abstract class AbstractKnnVectorQuery extends Query { HitQueue queue = new HitQueue(queueSize, true); TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; ScoreDoc topDoc = queue.top(); + DocIdSetIterator vectorIterator = vectorScorer.iterator(); + DocIdSetIterator conjunction = + ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of()); int doc; - while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { // Mark results as partial if timeout is met if (queryTimeout != null && queryTimeout.shouldExit()) { relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; break; } - - boolean advanced = 
vectorScorer.advanceExact(doc); - assert advanced; - + assert vectorIterator.docID() == doc; float score = vectorScorer.score(); if (score > topDoc.score) { topDoc.score = score; diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java index 4aea90cee38..7bd21d99b30 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java @@ -19,6 +19,7 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.Arrays; import java.util.Comparator; +import java.util.List; import java.util.Objects; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -83,7 +84,10 @@ abstract class AbstractVectorSimilarityQuery extends Query { VectorScorer scorer = createVectorScorer(context); if (scorer == null) { return Explanation.noMatch("Not indexed as the correct vector field"); - } else if (scorer.advanceExact(doc)) { + } + DocIdSetIterator iterator = scorer.iterator(); + int docId = iterator.advance(doc); + if (docId == doc) { float score = scorer.score(); if (score >= resultSimilarity) { return Explanation.match(boost * score, "Score above threshold"); @@ -256,15 +260,15 @@ abstract class AbstractVectorSimilarityQuery extends Query { DocIdSetIterator acceptDocs, float threshold) { float[] cachedScore = new float[1]; + DocIdSetIterator vectorIterator = scorer.iterator(); + DocIdSetIterator conjunction = + ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptDocs), List.of()); DocIdSetIterator iterator = - new FilteredDocIdSetIterator(acceptDocs) { + new FilteredDocIdSetIterator(conjunction) { @Override protected boolean match(int doc) throws IOException { // Advance the scorer - if (!scorer.advanceExact(doc)) { - return false; - } - + assert doc == vectorIterator.docID(); // 
Compute the dot product float score = scorer.score(); cachedScore[0] = score * boost; diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java index e410ad06343..bd2190121ab 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java @@ -21,9 +21,8 @@ import java.util.Arrays; import java.util.Locale; import java.util.Objects; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.Bits; /** @@ -98,12 +97,11 @@ public class ByteVectorSimilarityQuery extends AbstractVectorSimilarityQuery { @Override VectorScorer createVectorScorer(LeafReaderContext context) throws IOException { - @SuppressWarnings("resource") - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorEncoding() != VectorEncoding.BYTE) { + ByteVectorValues vectorValues = context.reader().getByteVectorValues(field); + if (vectorValues == null) { return null; } - return VectorScorer.create(context, fi, target); + return vectorValues.scorer(target); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java index 89d029ed140..b90f12bb4d3 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java @@ -22,7 +22,6 @@ import java.util.Arrays; import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.LeafReaderContext; -import 
org.apache.lucene.index.VectorSimilarityFunction; /** * A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector @@ -37,26 +36,13 @@ class ByteVectorSimilarityValuesSource extends VectorSimilarityValuesSource { } @Override - public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { + public VectorScorer getScorer(LeafReaderContext ctx) throws IOException { final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName); if (vectorValues == null) { ByteVectorValues.checkField(ctx.reader(), fieldName); - return DoubleValues.EMPTY; + return null; } - VectorSimilarityFunction function = - ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction(); - return new DoubleValues() { - @Override - public double doubleValue() throws IOException { - return function.compare(queryVector, vectorValues.vectorValue()); - } - - @Override - public boolean advanceExact(int doc) throws IOException { - return doc >= vectorValues.docID() - && (vectorValues.docID() == doc || vectorValues.advance(doc) == doc); - } - }; + return vectorValues.scorer(queryVector); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java index 44d06c163c0..3dc92482a77 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java @@ -21,9 +21,8 @@ import java.util.Arrays; import java.util.Locale; import java.util.Objects; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; @@ -100,11 +99,11 @@ 
public class FloatVectorSimilarityQuery extends AbstractVectorSimilarityQuery { @Override VectorScorer createVectorScorer(LeafReaderContext context) throws IOException { @SuppressWarnings("resource") - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorEncoding() != VectorEncoding.FLOAT32) { + FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field); + if (vectorValues == null) { return null; } - return VectorScorer.create(context, fi, target); + return vectorValues.scorer(target); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java index 8198467fc50..3bf4c0a1887 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java @@ -22,7 +22,6 @@ import java.util.Arrays; import java.util.Objects; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorSimilarityFunction; /** * A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector @@ -44,22 +43,32 @@ class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource { FloatVectorValues.checkField(ctx.reader(), fieldName); return DoubleValues.EMPTY; } - VectorSimilarityFunction function = - ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction(); return new DoubleValues() { + private final VectorScorer scorer = vectorValues.scorer(queryVector); + private final DocIdSetIterator iterator = scorer.iterator(); + @Override public double doubleValue() throws IOException { - return function.compare(queryVector, vectorValues.vectorValue()); + return scorer.score(); } @Override public boolean advanceExact(int doc) throws IOException { - 
return doc >= vectorValues.docID() - && (vectorValues.docID() == doc || vectorValues.advance(doc) == doc); + return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc); } }; } + @Override + public VectorScorer getScorer(LeafReaderContext ctx) throws IOException { + final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName); + if (vectorValues == null) { + FloatVectorValues.checkField(ctx.reader(), fieldName); + return null; + } + return vectorValues.scorer(queryVector); + } + @Override public int hashCode() { return Objects.hash(fieldName, Arrays.hashCode(queryVector)); diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index ba94a243e07..e3d733e516f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -98,7 +98,12 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { @Override VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException { - return VectorScorer.create(context, fi, target); + ByteVectorValues vectorValues = context.reader().getByteVectorValues(field); + if (vectorValues == null) { + ByteVectorValues.checkField(context.reader(), field); + return null; + } + return vectorValues.scorer(target); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index e6e38192e74..91cf4474e1c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -99,7 +99,12 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { @Override VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException { - return 
VectorScorer.create(context, fi, target); + FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field); + if (vectorValues == null) { + FloatVectorValues.checkField(context.reader(), field); + return null; + } + return vectorValues.scorer(target); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java b/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java index 249c7e88d87..c9b8362ae39 100644 --- a/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java @@ -17,117 +17,25 @@ package org.apache.lucene.search; import java.io.IOException; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorSimilarityFunction; /** * Computes the similarity score between a given query vector and different document vectors. This - * is primarily used by {@link KnnFloatVectorQuery} to run an exact, exhaustive search over the - * vectors. + * is used for exact searching and scoring. + * + * @lucene.experimental */ -abstract class VectorScorer { - protected final VectorSimilarityFunction similarity; +public interface VectorScorer { /** - * Create a new vector scorer instance. + * Compute the score for the current document ID. 
* - * @param context the reader context - * @param fi the FieldInfo for the field containing document vectors - * @param query the query vector to compute the similarity for + * @return the score for the current document ID + * @throws IOException if an exception occurs during score computation */ - static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query) - throws IOException { - FloatVectorValues values = context.reader().getFloatVectorValues(fi.name); - if (values == null) { - FloatVectorValues.checkField(context.reader(), fi.name); - return null; - } - final VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction(); - return new FloatVectorScorer(values, query, similarity); - } + float score() throws IOException; - static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, byte[] query) - throws IOException { - ByteVectorValues values = context.reader().getByteVectorValues(fi.name); - if (values == null) { - ByteVectorValues.checkField(context.reader(), fi.name); - return null; - } - VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction(); - return new ByteVectorScorer(values, query, similarity); - } - - VectorScorer(VectorSimilarityFunction similarity) { - this.similarity = similarity; - } - - /** Compute the similarity score for the current document. */ - abstract float score() throws IOException; - - abstract boolean advanceExact(int doc) throws IOException; - - private static class ByteVectorScorer extends VectorScorer { - private final byte[] query; - private final ByteVectorValues values; - - protected ByteVectorScorer( - ByteVectorValues values, byte[] query, VectorSimilarityFunction similarity) { - super(similarity); - this.values = values; - this.query = query; - } - - /** - * Advance the instance to the given document ID and return true if there is a value for that - * document. 
- */ - @Override - public boolean advanceExact(int doc) throws IOException { - int vectorDoc = values.docID(); - if (vectorDoc < doc) { - vectorDoc = values.advance(doc); - } - return vectorDoc == doc; - } - - @Override - public float score() throws IOException { - assert values.docID() != -1 : getClass().getSimpleName() + " is not positioned"; - return similarity.compare(query, values.vectorValue()); - } - } - - private static class FloatVectorScorer extends VectorScorer { - private final float[] query; - private final FloatVectorValues values; - - protected FloatVectorScorer( - FloatVectorValues values, float[] query, VectorSimilarityFunction similarity) { - super(similarity); - this.query = query; - this.values = values; - } - - /** - * Advance the instance to the given document ID and return true if there is a value for that - * document. - */ - @Override - public boolean advanceExact(int doc) throws IOException { - int vectorDoc = values.docID(); - if (vectorDoc < doc) { - vectorDoc = values.advance(doc); - } - return vectorDoc == doc; - } - - @Override - public float score() throws IOException { - assert values.docID() != -1 : getClass().getSimpleName() + " is not positioned"; - return similarity.compare(query, values.vectorValue()); - } - } + /** + * @return a {@link DocIdSetIterator} over the documents. 
+ */ + DocIdSetIterator iterator(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java index 639e225d665..4e53d580cda 100644 --- a/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java @@ -33,8 +33,26 @@ abstract class VectorSimilarityValuesSource extends DoubleValuesSource { } @Override - public abstract DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) - throws IOException; + public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { + VectorScorer scorer = getScorer(ctx); + if (scorer == null) { + return DoubleValues.EMPTY; + } + DocIdSetIterator iterator = scorer.iterator(); + return new DoubleValues() { + @Override + public double doubleValue() throws IOException { + return scorer.score(); + } + + @Override + public boolean advanceExact(int doc) throws IOException { + return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc); + } + }; + } + + protected abstract VectorScorer getScorer(LeafReaderContext ctx) throws IOException; @Override public boolean needsScores() { diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java index fa029b0f5ae..c277abfc634 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java @@ -18,6 +18,8 @@ package org.apache.lucene.util.quantization; import java.io.IOException; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; /** * A version of {@link 
ByteVectorValues}, but additionally retrieving score correction offset for @@ -25,6 +27,31 @@ import org.apache.lucene.index.ByteVectorValues; * * @lucene.experimental */ -public abstract class QuantizedByteVectorValues extends ByteVectorValues { +public abstract class QuantizedByteVectorValues extends DocIdSetIterator { public abstract float getScoreCorrectionConstant() throws IOException; + + public abstract byte[] vectorValue() throws IOException; + + /** Return the dimension of the vectors */ + public abstract int dimension(); + + /** + * Return the number of vectors for this field. + * + * @return the number of vectors returned by this iterator + */ + public abstract int size(); + + @Override + public final long cost() { + return size(); + } + + /** + * Return a {@link VectorScorer} for the given query vector. + * + * @param query the query vector + * @return a {@link VectorScorer} instance + */ + public abstract VectorScorer vectorScorer(float[] query) throws IOException; } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java index 6097a49151d..5432879a134 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java @@ -26,7 +26,6 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.document.StringField; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorEncoding; @@ -44,24 +43,22 @@ public class TestVectorScorer extends LuceneTestCase { IndexReader reader = DirectoryReader.open(indexStore)) { assert reader.leaves().size() == 1; LeafReaderContext context = reader.leaves().get(0); - FieldInfo fieldInfo = 
context.reader().getFieldInfos().fieldInfo("field"); final VectorScorer vectorScorer; switch (encoding) { case BYTE: - vectorScorer = VectorScorer.create(context, fieldInfo, new byte[] {1, 2}); + vectorScorer = context.reader().getByteVectorValues("field").scorer(new byte[] {1, 2}); break; case FLOAT32: - vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2}); + vectorScorer = context.reader().getFloatVectorValues("field").scorer(new float[] {1, 2}); break; default: throw new IllegalArgumentException("unexpected vector encoding: " + encoding); } + DocIdSetIterator iterator = vectorScorer.iterator(); int numDocs = 0; - for (int i = 0; i < reader.maxDoc(); i++) { - if (vectorScorer.advanceExact(i)) { - numDocs++; - } + while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + numDocs++; } assertEquals(3, numDocs); } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 4770bdf98ab..7b50a4e1ce4 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -72,6 +72,7 @@ import org.apache.lucene.search.SortField; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopKnnCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -1128,6 +1129,11 @@ abstract class HnswGraphTestCase extends LuceneTestCase { public float[] vectorValue(int ord) { return unitVector2d(ord / (double) size, value); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } /** Returns vectors evenly distributed around the upper unit semicircle. 
*/ @@ -1193,6 +1199,11 @@ abstract class HnswGraphTestCase extends LuceneTestCase { } return bValue; } + + @Override + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); + } } private static float[] unitVector2d(double piRadians) { diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 9d150373070..d1930569f6c 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.Set; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.tests.util.LuceneTestCase; public class TestScalarQuantizer extends LuceneTestCase { @@ -262,5 +263,10 @@ public class TestScalarQuantizer extends LuceneTestCase { curDoc = target - 1; return nextDoc(); } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index 735aa5ac10f..8c6f0f98470 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -20,10 +20,8 @@ import java.io.IOException; import java.util.Arrays; import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; -import 
org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.IndexSearcher; @@ -34,6 +32,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; @@ -92,14 +91,10 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery { return NO_RESULTS; } - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - ParentBlockJoinByteVectorScorer vectorScorer = - new ParentBlockJoinByteVectorScorer( - byteVectorValues, - acceptIterator, - parentBitSet, - query, - fi.getVectorSimilarityFunction()); + VectorScorer scorer = byteVectorValues.scorer(query); + DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer vectorScorer = + new DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer( + acceptIterator, parentBitSet, scorer); final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost())); HitQueue queue = new HitQueue(queueSize, true); TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; @@ -177,59 +172,4 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery { result = 31 * result + Arrays.hashCode(query); return result; } - - private static class ParentBlockJoinByteVectorScorer { - private final byte[] query; - private final ByteVectorValues values; - private final VectorSimilarityFunction similarity; - private final DocIdSetIterator acceptedChildrenIterator; - private final BitSet parentBitSet; - private int currentParent = -1; - private int bestChild = -1; - private float currentScore = Float.NEGATIVE_INFINITY; - - protected ParentBlockJoinByteVectorScorer( 
- ByteVectorValues values, - DocIdSetIterator acceptedChildrenIterator, - BitSet parentBitSet, - byte[] query, - VectorSimilarityFunction similarity) { - this.query = query; - this.values = values; - this.similarity = similarity; - this.acceptedChildrenIterator = acceptedChildrenIterator; - this.parentBitSet = parentBitSet; - } - - public int bestChild() { - return bestChild; - } - - public int nextParent() throws IOException { - int nextChild = acceptedChildrenIterator.docID(); - if (nextChild == -1) { - nextChild = acceptedChildrenIterator.nextDoc(); - } - if (nextChild == DocIdSetIterator.NO_MORE_DOCS) { - currentParent = DocIdSetIterator.NO_MORE_DOCS; - return currentParent; - } - currentScore = Float.NEGATIVE_INFINITY; - currentParent = parentBitSet.nextSetBit(nextChild); - do { - values.advance(nextChild); - float score = similarity.compare(query, values.vectorValue()); - if (score > currentScore) { - bestChild = nextChild; - currentScore = score; - } - } while ((nextChild = acceptedChildrenIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS - && nextChild < currentParent); - return currentParent; - } - - public float score() throws IOException { - return currentScore; - } - } } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index 8a427f467dd..513da619cae 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -19,11 +19,9 @@ package org.apache.lucene.search.join; import java.io.IOException; import java.util.Arrays; import java.util.Objects; -import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; -import 
org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.IndexSearcher; @@ -34,6 +32,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; @@ -92,14 +91,9 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery return NO_RESULTS; } - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - DiversifyingChildrenFloatVectorScorer vectorScorer = - new DiversifyingChildrenFloatVectorScorer( - floatVectorValues, - acceptIterator, - parentBitSet, - query, - fi.getVectorSimilarityFunction()); + DiversifyingChildrenVectorScorer vectorScorer = + new DiversifyingChildrenVectorScorer( + acceptIterator, parentBitSet, floatVectorValues.scorer(query)); final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost())); HitQueue queue = new HitQueue(queueSize, true); TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; @@ -178,26 +172,20 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery return result; } - private static class DiversifyingChildrenFloatVectorScorer { - private final float[] query; - private final FloatVectorValues values; - private final VectorSimilarityFunction similarity; + static class DiversifyingChildrenVectorScorer { + private final VectorScorer vectorScorer; + private final DocIdSetIterator vectorIterator; private final DocIdSetIterator acceptedChildrenIterator; private final BitSet parentBitSet; private int currentParent = -1; private int bestChild = -1; private float currentScore = Float.NEGATIVE_INFINITY; - protected DiversifyingChildrenFloatVectorScorer( - 
FloatVectorValues values, - DocIdSetIterator acceptedChildrenIterator, - BitSet parentBitSet, - float[] query, - VectorSimilarityFunction similarity) { - this.query = query; - this.values = values; - this.similarity = similarity; + protected DiversifyingChildrenVectorScorer( + DocIdSetIterator acceptedChildrenIterator, BitSet parentBitSet, VectorScorer vectorScorer) { this.acceptedChildrenIterator = acceptedChildrenIterator; + this.vectorScorer = vectorScorer; + this.vectorIterator = vectorScorer.iterator(); this.parentBitSet = parentBitSet; } @@ -217,8 +205,8 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery currentScore = Float.NEGATIVE_INFINITY; currentParent = parentBitSet.nextSetBit(nextChild); do { - values.advance(nextChild); - float score = similarity.compare(query, values.vectorValue()); + vectorIterator.advance(nextChild); + float score = vectorScorer.score(); if (score > currentScore) { bestChild = nextChild; currentScore = score; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 4fb5e95247d..75cb1a52e0d 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -48,10 +48,12 @@ import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import 
org.apache.lucene.tests.util.TestUtil; @@ -718,6 +720,116 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe } } + public void testFloatVectorScorerIteration() throws Exception { + IndexWriterConfig iwc = newIndexWriterConfig(); + if (random().nextBoolean()) { + iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT))); + } + String fieldName = "field"; + try (Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, iwc)) { + int numDoc = atLeast(100); + int dimension = atLeast(10); + if (dimension % 2 != 0) { + dimension++; + } + float[][] values = new float[numDoc][]; + for (int i = 0; i < numDoc; i++) { + if (random().nextInt(7) != 3) { + // usually index a vector value for a doc + values[i] = randomNormalizedVector(dimension); + } + add(iw, fieldName, i, values[i], similarityFunction); + if (random().nextInt(10) == 2) { + iw.deleteDocuments(new Term("id", Integer.toString(random().nextInt(i + 1)))); + } + if (random().nextInt(10) == 3) { + iw.commit(); + } + } + float[] vectorToScore = randomNormalizedVector(dimension); + try (IndexReader reader = DirectoryReader.open(iw)) { + for (LeafReaderContext ctx : reader.leaves()) { + FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName); + if (vectorValues == null) { + continue; + } + VectorScorer scorer = vectorValues.scorer(vectorToScore); + assertNotNull(scorer); + DocIdSetIterator iterator = scorer.iterator(); + assertSame(iterator, scorer.iterator()); + assertNotSame(iterator, scorer); + // verify scorer iteration scores are valid & iteration with vectorValues is consistent + while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) { + float score = scorer.score(); + assertTrue(score >= 0f); + assertEquals(iterator.docID(), vectorValues.docID()); + } + // verify that a new scorer can be obtained after iteration + VectorScorer newScorer = vectorValues.scorer(vectorToScore); + assertNotNull(newScorer); 
+ assertNotSame(scorer, newScorer); + assertNotSame(iterator, newScorer.iterator()); + } + } + } + } + + public void testByteVectorScorerIteration() throws Exception { + IndexWriterConfig iwc = newIndexWriterConfig(); + if (random().nextBoolean()) { + iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT))); + } + String fieldName = "field"; + try (Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, iwc)) { + int numDoc = atLeast(100); + int dimension = atLeast(10); + if (dimension % 2 != 0) { + dimension++; + } + byte[][] values = new byte[numDoc][]; + for (int i = 0; i < numDoc; i++) { + if (random().nextInt(7) != 3) { + // usually index a vector value for a doc + values[i] = randomVector8(dimension); + } + add(iw, fieldName, i, values[i], similarityFunction); + if (random().nextInt(10) == 2) { + iw.deleteDocuments(new Term("id", Integer.toString(random().nextInt(i + 1)))); + } + if (random().nextInt(10) == 3) { + iw.commit(); + } + } + byte[] vectorToScore = randomVector8(dimension); + try (IndexReader reader = DirectoryReader.open(iw)) { + for (LeafReaderContext ctx : reader.leaves()) { + ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName); + if (vectorValues == null) { + continue; + } + VectorScorer scorer = vectorValues.scorer(vectorToScore); + assertNotNull(scorer); + DocIdSetIterator iterator = scorer.iterator(); + assertSame(iterator, scorer.iterator()); + assertNotSame(iterator, scorer); + // verify scorer iteration scores are valid & iteration with vectorValues is consistent + while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) { + float score = scorer.score(); + assertTrue(score >= 0f); + assertEquals(iterator.docID(), vectorValues.docID()); + } + // verify that a new scorer can be obtained after iteration + VectorScorer newScorer = vectorValues.scorer(vectorToScore); + assertNotNull(newScorer); + assertNotSame(scorer, newScorer); + assertNotSame(iterator, 
newScorer.iterator()); + } + } + } + } + protected VectorSimilarityFunction randomSimilarity() { return VectorSimilarityFunction.values()[ random().nextInt(VectorSimilarityFunction.values().length)];