First-class random access API for KnnVectorValues (#13779)

Michael Sokolov 2024-09-28 09:14:01 -04:00 committed by GitHub
parent 7b4b0238d7
commit 6053e1e313
83 changed files with 2103 additions and 2359 deletions

View File

@ -18,10 +18,10 @@
package org.apache.lucene.analysis.synonym.word2vec;
import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each
@ -29,7 +29,7 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
*
* @lucene.experimental
*/
public class Word2VecModel implements RandomAccessVectorValues.Floats {
public class Word2VecModel extends FloatVectorValues {
private final int dictionarySize;
private final int vectorDimension;

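Across this patch, classes that used to implement RandomAccessVectorValues.Floats (like Word2VecModel above) now extend FloatVectorValues directly, because random access by ordinal is part of the base class. As a hedged illustration only, not code from this commit, here is roughly what a minimal dense, in-memory subclass looks like; the class and field names are invented, and exactly which overrides are strictly abstract is an assumption based on the methods the hunks in this commit override.

import java.util.List;
import org.apache.lucene.index.FloatVectorValues;

// Hypothetical example, not part of this commit.
class ListBackedFloatVectorValues extends FloatVectorValues {
  private final List<float[]> vectors; // one vector per ordinal 0..size()-1
  private final int dim;

  ListBackedFloatVectorValues(List<float[]> vectors, int dim) {
    this.vectors = vectors;
    this.dim = dim;
  }

  @Override
  public float[] vectorValue(int ord) { // random access by ordinal is the new primitive
    return vectors.get(ord);
  }

  @Override
  public int dimension() {
    return dim;
  }

  @Override
  public int size() {
    return vectors.size();
  }

  @Override
  public ListBackedFloatVectorValues copy() { // no per-use read state, so this is already independent
    return this;
  }

  @Override
  public DocIndexIterator iterator() { // dense case: ordinal == docID
    return createDenseIterator();
  }
}

In practice the diff suggests FloatVectorValues.fromFloats(List<float[]>, dim) now provides exactly this kind of wrapper, so callers rarely need to hand-roll one.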
View File

@ -22,10 +22,10 @@ import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the
@ -49,7 +49,7 @@ public final class Lucene90HnswGraphBuilder {
private final Lucene90NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues.Floats vectorValues;
private final FloatVectorValues vectorValues;
private final SplittableRandom random;
private final Lucene90BoundsChecker bound;
final Lucene90OnHeapHnswGraph hnsw;
@ -58,7 +58,7 @@ public final class Lucene90HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
private final RandomAccessVectorValues.Floats buildVectors;
private final FloatVectorValues buildVectors;
/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
@ -73,7 +73,7 @@ public final class Lucene90HnswGraphBuilder {
* to ensure repeatable construction.
*/
public Lucene90HnswGraphBuilder(
RandomAccessVectorValues.Floats vectors,
FloatVectorValues vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
@ -97,14 +97,14 @@ public final class Lucene90HnswGraphBuilder {
}
/**
* Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two
* copies enables efficient retrieval without extra data copying, while avoiding collision of the
* Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two copies
* enables efficient retrieval without extra data copying, while avoiding collision of the
* returned values.
*
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
* accessor for the vectors
*/
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException {
public Lucene90OnHeapHnswGraph build(FloatVectorValues vectors) throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@ -230,7 +230,7 @@ public final class Lucene90HnswGraphBuilder {
float[] candidate,
float score,
Lucene90NeighborArray neighbors,
RandomAccessVectorValues.Floats vectorValues)
FloatVectorValues vectorValues)
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {

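The builder above keeps one accessor (vectorValues) for graph construction and a second, independent one (buildVectors) for diversity-check comparisons, and build() rejects the very instance passed to the constructor. A hedged usage sketch, assuming the constructor's trailing argument is the random seed implied by the "repeatable construction" javadoc and that copy() yields an independent reader over the same data:

import java.io.IOException;
import org.apache.lucene.backward_codecs.lucene90.Lucene90HnswGraphBuilder;
import org.apache.lucene.backward_codecs.lucene90.Lucene90OnHeapHnswGraph;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;

// Illustrative only; the constructor arguments after beamWidth are assumed, see above.
final class GraphBuildSketch {
  static Lucene90OnHeapHnswGraph buildGraph(
      FloatVectorValues vectors,
      VectorSimilarityFunction similarityFunction,
      int maxConn,
      int beamWidth,
      long seed)
      throws IOException {
    Lucene90HnswGraphBuilder builder =
        new Lucene90HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, seed);
    // build() must be handed an accessor that is independent of the one the builder holds,
    // otherwise it throws IllegalArgumentException (see the hunk above).
    return builder.build(vectors.copy());
  }
}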
View File

@ -20,7 +20,6 @@ package org.apache.lucene.backward_codecs.lucene90;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
@ -34,7 +33,6 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
@ -44,7 +42,6 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Reads vectors from the index segments along with index data structures supporting KNN search.
@ -355,8 +352,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
/** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues.Floats {
static class OffHeapFloatVectorValues extends FloatVectorValues {
final int dimension;
final int[] ordToDoc;
@ -367,9 +363,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
final float[] value;
final VectorSimilarityFunction similarityFunction;
int ord = -1;
int doc = -1;
OffHeapFloatVectorValues(
int dimension,
int[] ordToDoc,
@ -394,42 +387,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return ordToDoc.length;
}
@Override
public float[] vectorValue() throws IOException {
return vectorValue(ord);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() {
if (++ord >= size()) {
doc = NO_MORE_DOCS;
} else {
doc = ordToDoc[ord];
}
return doc;
}
@Override
public int advance(int target) {
assert docID() < target;
ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target);
if (ord < 0) {
ord = -(ord + 1);
}
assert ord <= ordToDoc.length;
if (ord == ordToDoc.length) {
doc = NO_MORE_DOCS;
} else {
doc = ordToDoc[ord];
}
return doc;
}
@Override
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone());
@ -446,21 +403,32 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return value;
}
@Override
public int ordToDoc(int ord) {
return ordToDoc[ord];
}
@Override
public DocIndexIterator iterator() {
return createSparseIterator();
}
@Override
public VectorScorer scorer(float[] target) {
if (size() == 0) {
return null;
}
OffHeapFloatVectorValues values = this.copy();
DocIndexIterator iterator = values.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.similarityFunction.compare(values.vectorValue(), target);
return values.similarityFunction.compare(values.vectorValue(iterator.index()), target);
}
@Override
public DocIdSetIterator iterator() {
return values;
public DocIndexIterator iterator() {
return iterator;
}
};
}

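Every scorer(...) in this patch follows one shape: take an independent copy of the values, grab that copy's DocIndexIterator, and read the current vector with vectorValue(iterator.index()). A hedged sketch of the pattern in isolation, using only signatures visible in these hunks; the wrapper class and method name are illustrative, not Lucene API:

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;

final class ScorerSketch {
  static VectorScorer floatScorer(
      FloatVectorValues values, VectorSimilarityFunction similarity, float[] target)
      throws IOException {
    if (values.size() == 0) {
      return null; // nothing to score, as in the readers above
    }
    FloatVectorValues copy = values.copy(); // independent read state for this scorer
    KnnVectorValues.DocIndexIterator iterator = copy.iterator();
    return new VectorScorer() {
      @Override
      public float score() throws IOException {
        // index() is the ordinal of the document the iterator is currently positioned on
        return similarity.compare(copy.vectorValue(iterator.index()), target);
      }

      @Override
      public DocIdSetIterator iterator() {
        return iterator;
      }
    };
  }
}

The copy is what lets a scorer be pulled from shared vector values and positioned independently of any other consumer.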
View File

@ -23,12 +23,12 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.SplittableRandom;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
@ -74,7 +74,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
float[] query,
int topK,
int numSeed,
RandomAccessVectorValues.Floats vectors,
FloatVectorValues vectors,
VectorSimilarityFunction similarityFunction,
HnswGraph graphValues,
Bits acceptOrds,

View File

@ -46,7 +46,6 @@ import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/**
@ -398,8 +397,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
/** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues.Floats {
static class OffHeapFloatVectorValues extends FloatVectorValues {
private final int dimension;
private final int size;
@ -410,9 +408,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
private final float[] value;
private final VectorSimilarityFunction similarityFunction;
private int ord = -1;
private int doc = -1;
OffHeapFloatVectorValues(
int dimension,
int size,
@ -439,49 +434,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
return size;
}
@Override
public float[] vectorValue() throws IOException {
dataIn.seek((long) ord * byteSize);
dataIn.readFloats(value, 0, value.length);
return value;
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() {
if (++ord >= size) {
doc = NO_MORE_DOCS;
} else {
doc = ordToDocOperator.applyAsInt(ord);
}
return doc;
}
@Override
public int advance(int target) {
assert docID() < target;
if (ordToDoc == null) {
ord = target;
} else {
ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target);
if (ord < 0) {
ord = -(ord + 1);
}
}
if (ord < size) {
doc = ordToDocOperator.applyAsInt(ord);
} else {
doc = NO_MORE_DOCS;
}
return doc;
}
@Override
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(
@ -495,21 +447,32 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
return value;
}
@Override
public int ordToDoc(int ord) {
return ordToDocOperator.applyAsInt(ord);
}
@Override
public DocIndexIterator iterator() {
return createSparseIterator();
}
@Override
public VectorScorer scorer(float[] target) {
if (size == 0) {
return null;
}
OffHeapFloatVectorValues values = this.copy();
DocIndexIterator iterator = values.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.similarityFunction.compare(values.vectorValue(), target);
return values.similarityFunction.compare(values.vectorValue(iterator.index()), target);
}
@Override
public DocIdSetIterator iterator() {
return values;
return iterator;
}
};
}

View File

@ -26,12 +26,10 @@ import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues.Floats {
abstract class OffHeapFloatVectorValues extends FloatVectorValues {
protected final int dimension;
protected final int size;
@ -95,8 +93,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
private int doc = -1;
public DenseOffHeapVectorValues(
int dimension,
int size,
@ -105,35 +101,16 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
super(dimension, size, vectorSimilarityFunction, slice);
}
@Override
public float[] vectorValue() throws IOException {
return vectorValue(doc);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
if (target >= size) {
return doc = NO_MORE_DOCS;
}
return doc = target;
}
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone());
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
@ -142,15 +119,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public VectorScorer scorer(float[] query) throws IOException {
DenseOffHeapVectorValues values = this.copy();
DocIndexIterator iterator = values.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
return values.vectorSimilarityFunction.compare(
values.vectorValue(iterator.index()), query);
}
@Override
public DocIdSetIterator iterator() {
return values;
return iterator;
}
};
}
@ -186,33 +165,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
fieldEntry.size());
}
@Override
public float[] vectorValue() throws IOException {
return vectorValue(disi.index());
}
@Override
public int docID() {
return disi.docID();
}
@Override
public int nextDoc() throws IOException {
return disi.nextDoc();
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
return disi.advance(target);
}
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
fieldEntry, dataIn, vectorSimilarityFunction, slice.clone());
}
@Override
public DocIndexIterator iterator() {
return IndexedDISI.asDocIndexIterator(disi);
}
@Override
public int ordToDoc(int ord) {
return (int) ordToDoc.get(ord);
@ -239,15 +202,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public VectorScorer scorer(float[] query) throws IOException {
SparseOffHeapVectorValues values = this.copy();
DocIndexIterator iterator = values.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
return values.vectorSimilarityFunction.compare(
values.vectorValue(iterator.index()), query);
}
@Override
public DocIdSetIterator iterator() {
return values;
return iterator;
}
};
}
@ -259,8 +224,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
super(dimension, 0, VectorSimilarityFunction.COSINE, null);
}
private int doc = -1;
@Override
public int dimension() {
return super.dimension();
@ -271,26 +234,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
return 0;
}
@Override
public float[] vectorValue() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
return doc = NO_MORE_DOCS;
}
@Override
public OffHeapFloatVectorValues copy() throws IOException {
throw new UnsupportedOperationException();

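These off-heap implementations drop their per-instance cursors (the doc and ord fields plus docID()/nextDoc()/advance()); positioning now lives entirely in the DocIndexIterator, and reads are by ordinal. A hedged sketch of how the old "advance, then read" idiom translates; the helper is illustrative, not part of the patch:

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;

final class AdvanceSketch {
  // Old idiom: values.advance(targetDoc); float[] v = values.vectorValue();
  // New idiom: advance the iterator, then read by the iterator's ordinal.
  static float[] vectorForDoc(FloatVectorValues values, int targetDoc) throws IOException {
    KnnVectorValues.DocIndexIterator it = values.iterator();
    int doc = it.advance(targetDoc); // standard DocIdSetIterator contract
    if (doc != targetDoc) {
      return null; // no vector indexed for targetDoc (or the iterator is exhausted)
    }
    return values.vectorValue(it.index());
  }
}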
View File

@ -28,12 +28,10 @@ import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues.Bytes {
abstract class OffHeapByteVectorValues extends ByteVectorValues {
protected final int dimension;
protected final int size;
@ -108,8 +106,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
private int doc = -1;
public DenseOffHeapVectorValues(
int dimension,
int size,
@ -119,36 +115,17 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
super(dimension, size, slice, vectorSimilarityFunction, byteSize);
}
@Override
public byte[] vectorValue() throws IOException {
return vectorValue(doc);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
if (target >= size) {
return doc = NO_MORE_DOCS;
}
return doc = target;
}
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(
dimension, size, slice.clone(), vectorSimilarityFunction, byteSize);
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
@ -157,15 +134,16 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public VectorScorer scorer(byte[] query) throws IOException {
DenseOffHeapVectorValues copy = this.copy();
DocIndexIterator iterator = copy.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
return vectorSimilarityFunction.compare(copy.vectorValue(), query);
return vectorSimilarityFunction.compare(copy.vectorValue(iterator.index()), query);
}
@Override
public DocIdSetIterator iterator() {
return copy;
return iterator;
}
};
}
@ -202,27 +180,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
fieldEntry.size());
}
@Override
public byte[] vectorValue() throws IOException {
return vectorValue(disi.index());
}
@Override
public int docID() {
return disi.docID();
}
@Override
public int nextDoc() throws IOException {
return disi.nextDoc();
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
return disi.advance(target);
}
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
@ -234,6 +191,11 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
return (int) ordToDoc.get(ord);
}
@Override
public DocIndexIterator iterator() {
return fromDISI(disi);
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
if (acceptDocs == null) {
@ -255,15 +217,16 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public VectorScorer scorer(byte[] query) throws IOException {
SparseOffHeapVectorValues copy = this.copy();
IndexedDISI disi = copy.disi;
return new VectorScorer() {
@Override
public float score() throws IOException {
return vectorSimilarityFunction.compare(copy.vectorValue(), query);
return vectorSimilarityFunction.compare(copy.vectorValue(disi.index()), query);
}
@Override
public DocIdSetIterator iterator() {
return copy;
return disi;
}
};
}
@ -275,8 +238,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0);
}
private int doc = -1;
@Override
public int dimension() {
return super.dimension();
@ -287,26 +248,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
return 0;
}
@Override
public byte[] vectorValue() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
return doc = NO_MORE_DOCS;
}
@Override
public OffHeapByteVectorValues copy() throws IOException {
throw new UnsupportedOperationException();

View File

@ -26,12 +26,10 @@ import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues.Floats {
abstract class OffHeapFloatVectorValues extends FloatVectorValues {
protected final int dimension;
protected final int size;
@ -104,8 +102,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
private int doc = -1;
public DenseOffHeapVectorValues(
int dimension,
int size,
@ -115,36 +111,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
super(dimension, size, slice, vectorSimilarityFunction, byteSize);
}
@Override
public float[] vectorValue() throws IOException {
return vectorValue(doc);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
if (target >= size) {
return doc = NO_MORE_DOCS;
}
return doc = target;
}
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(
dimension, size, slice.clone(), vectorSimilarityFunction, byteSize);
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
@ -153,15 +130,18 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public VectorScorer scorer(float[] query) throws IOException {
DenseOffHeapVectorValues values = this.copy();
DocIndexIterator iterator = values.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
return values.vectorSimilarityFunction.compare(
values.vectorValue(iterator.index()), query);
}
@Override
public DocIdSetIterator iterator() {
return values;
return iterator;
}
};
}
@ -198,33 +178,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
fieldEntry.size());
}
@Override
public float[] vectorValue() throws IOException {
return vectorValue(disi.index());
}
@Override
public int docID() {
return disi.docID();
}
@Override
public int nextDoc() throws IOException {
return disi.nextDoc();
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
return disi.advance(target);
}
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize);
}
@Override
public DocIndexIterator iterator() {
return IndexedDISI.asDocIndexIterator(disi);
}
@Override
public int ordToDoc(int ord) {
return (int) ordToDoc.get(ord);
@ -251,15 +215,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public VectorScorer scorer(float[] query) throws IOException {
SparseOffHeapVectorValues values = this.copy();
DocIndexIterator iterator = values.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
return values.vectorSimilarityFunction.compare(
values.vectorValue(iterator.index()), query);
}
@Override
public DocIdSetIterator iterator() {
return values;
return iterator;
}
};
}
@ -271,8 +237,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0);
}
private int doc = -1;
@Override
public int dimension() {
return super.dimension();
@ -283,26 +247,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
return 0;
}
@Override
public float[] vectorValue() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
return doc = NO_MORE_DOCS;
}
@Override
public OffHeapFloatVectorValues copy() throws IOException {
throw new UnsupportedOperationException();

View File

@ -29,13 +29,13 @@ import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Writes vector values and knn graphs to index segments.
@ -188,12 +188,13 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
int count = 0;
ByteBuffer binaryVector =
ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) {
KnnVectorValues.DocIndexIterator iter = vectors.iterator();
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
// write vector
float[] vectorValue = vectors.vectorValue();
float[] vectorValue = vectors.vectorValue(iter.index());
binaryVector.asFloatBuffer().put(vectorValue);
output.writeBytes(binaryVector.array(), binaryVector.limit());
docIds[count] = docV;
docIds[count++] = docV;
}
if (docIds.length > count) {
@ -234,7 +235,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
private void writeGraph(
IndexOutput graphData,
RandomAccessVectorValues.Floats vectorValues,
FloatVectorValues vectorValues,
VectorSimilarityFunction similarityFunction,
long graphDataOffset,
long[] offsets,

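All of the writers switch from iterating the values object itself (nextDoc()/vectorValue()) to driving an explicit DocIndexIterator, pairing the doc id from nextDoc() with the ordinal from index() for the random-access read. A hedged, stripped-down version of that loop, closely following the hunks above; the wrapper class is illustrative:

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexOutput;

final class WriteLoopSketch {
  // One iterator supplies both the doc id (nextDoc) and the ordinal (index).
  static void writeVectors(IndexOutput output, FloatVectorValues vectors) throws IOException {
    ByteBuffer buffer =
        ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
    KnnVectorValues.DocIndexIterator iter = vectors.iterator();
    for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) {
      float[] value = vectors.vectorValue(iter.index());
      buffer.asFloatBuffer().put(value);
      output.writeBytes(buffer.array(), buffer.limit());
    }
  }
}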
View File

@ -12,7 +12,7 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* limIndexedDISIitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene90;

View File

@ -25,6 +25,7 @@ import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
@ -32,7 +33,6 @@ import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/**
@ -57,7 +57,7 @@ public final class Lucene91HnswGraphBuilder {
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues.Floats vectorValues;
private final FloatVectorValues vectorValues;
private final SplittableRandom random;
private final Lucene91BoundsChecker bound;
private final HnswGraphSearcher graphSearcher;
@ -68,7 +68,7 @@ public final class Lucene91HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
private RandomAccessVectorValues.Floats buildVectors;
private FloatVectorValues buildVectors;
/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
@ -83,7 +83,7 @@ public final class Lucene91HnswGraphBuilder {
* to ensure repeatable construction.
*/
public Lucene91HnswGraphBuilder(
RandomAccessVectorValues.Floats vectors,
FloatVectorValues vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
@ -113,14 +113,14 @@ public final class Lucene91HnswGraphBuilder {
}
/**
* Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two
* copies enables efficient retrieval without extra data copying, while avoiding collision of the
* Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two copies
* enables efficient retrieval without extra data copying, while avoiding collision of the
* returned values.
*
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
* accessor for the vectors
*/
public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException {
public Lucene91OnHeapHnswGraph build(FloatVectorValues vectors) throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@ -254,7 +254,7 @@ public final class Lucene91HnswGraphBuilder {
float[] candidate,
float score,
Lucene91NeighborArray neighbors,
RandomAccessVectorValues.Floats vectorValues)
FloatVectorValues vectorValues)
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {

View File

@ -17,8 +17,6 @@
package org.apache.lucene.backward_codecs.lucene91;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@ -30,6 +28,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
@ -37,7 +36,6 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Writes vector values and knn graphs to index segments.
@ -183,9 +181,10 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
ByteBuffer binaryVector =
ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
KnnVectorValues.DocIndexIterator iter = vectors.iterator();
for (int docV = iter.nextDoc(); docV != DocIdSetIterator.NO_MORE_DOCS; docV = iter.nextDoc()) {
// write vector
float[] vectorValue = vectors.vectorValue();
float[] vectorValue = vectors.vectorValue(iter.index());
binaryVector.asFloatBuffer().put(vectorValue);
output.writeBytes(binaryVector.array(), binaryVector.limit());
docsWithField.add(docV);
@ -243,7 +242,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
}
private Lucene91OnHeapHnswGraph writeGraph(
RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction)
FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
// build graph

View File

@ -18,7 +18,6 @@
package org.apache.lucene.backward_codecs.lucene92;
import static org.apache.lucene.backward_codecs.lucene92.Lucene92RWHnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.nio.ByteBuffer;
@ -33,6 +32,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
@ -43,7 +43,6 @@ import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
@ -190,9 +189,12 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
ByteBuffer binaryVector =
ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
KnnVectorValues.DocIndexIterator iterator = vectors.iterator();
for (int docV = iterator.nextDoc();
docV != DocIdSetIterator.NO_MORE_DOCS;
docV = iterator.nextDoc()) {
// write vector
float[] vectorValue = vectors.vectorValue();
float[] vectorValue = vectors.vectorValue(iterator.index());
binaryVector.asFloatBuffer().put(vectorValue);
output.writeBytes(binaryVector.array(), binaryVector.limit());
docsWithField.add(docV);
@ -277,7 +279,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
}
private OnHeapHnswGraph writeGraph(
RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction)
FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
// build graph

View File

@ -36,6 +36,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
@ -52,7 +53,6 @@ import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
@ -216,9 +216,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
for (int docID = iterator.nextDoc(); docID != NO_MORE_DOCS; docID = iterator.nextDoc()) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
@ -556,9 +554,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
final DirectMonotonicWriter ordToDocWriter =
DirectMonotonicWriter.getInstance(meta, vectorData, count, DIRECT_MONOTONIC_BLOCK_SHIFT);
DocIdSetIterator iterator = docsWithField.iterator();
for (int doc = iterator.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = iterator.nextDoc()) {
for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) {
ordToDocWriter.add(doc);
}
ordToDocWriter.finish();
@ -590,11 +586,10 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private static DocsWithFieldSet writeByteVectorData(
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = byteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator();
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
// write vector
byte[] binaryValue = byteVectorValues.vectorValue();
byte[] binaryValue = byteVectorValues.vectorValue(iter.index());
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
output.writeBytes(binaryValue, binaryValue.length);
docsWithField.add(docV);
@ -608,14 +603,13 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private static DocsWithFieldSet writeVectorData(
IndexOutput output, FloatVectorValues floatVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator();
ByteBuffer binaryVector =
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
.order(ByteOrder.LITTLE_ENDIAN);
for (int docV = floatVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = floatVectorValues.nextDoc()) {
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
// write vector
float[] vectorValue = floatVectorValues.vectorValue();
float[] vectorValue = floatVectorValues.vectorValue(iter.index());
binaryVector.asFloatBuffer().put(vectorValue);
output.writeBytes(binaryVector.array(), binaryVector.limit());
docsWithField.add(docV);
@ -672,11 +666,11 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
case BYTE ->
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromBytes((List<byte[]>) vectors, dim));
ByteVectorValues.fromBytes((List<byte[]>) vectors, dim));
case FLOAT32 ->
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromFloats((List<float[]>) vectors, dim));
FloatVectorValues.fromFloats((List<float[]>) vectors, dim));
};
hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);

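The static fromBytes/fromFloats helpers move from RandomAccessVectorValues onto ByteVectorValues and FloatVectorValues themselves. A hedged sketch of wrapping buffered vectors for scoring at flush time, as the hunk above does; the method name and variables are illustrative, while DefaultFlatVectorScorer, getRandomVectorScorerSupplier, and FloatVectorValues.fromFloats are taken from the diff:

import java.io.IOException;
import java.util.List;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

final class FlushScorerSketch {
  static RandomVectorScorerSupplier supplierFor(
      List<float[]> buffered, int dim, VectorSimilarityFunction similarity) throws IOException {
    DefaultFlatVectorScorer scorer = new DefaultFlatVectorScorer();
    // fromFloats exposes the in-memory list through the new random-access API.
    return scorer.getRandomVectorScorerSupplier(
        similarity, FloatVectorValues.fromFloats(buffered, dim));
  }
}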
View File

@ -39,6 +39,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
@ -56,7 +57,6 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
@ -221,9 +221,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
for (int docID = iterator.nextDoc(); docID != NO_MORE_DOCS; docID = iterator.nextDoc()) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
@ -482,18 +480,18 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
}
}
DocIdSetIterator mergedVectorIterator = null;
KnnVectorValues mergedVectorValues = null;
switch (fieldInfo.getVectorEncoding()) {
case BYTE ->
mergedVectorIterator =
mergedVectorValues =
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
case FLOAT32 ->
mergedVectorIterator =
mergedVectorValues =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}
graph =
merger.merge(
mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality());
mergedVectorValues, segmentWriteState.infoStream, docsWithField.cardinality());
vectorIndexNodeOffsets = writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@ -636,14 +634,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
private static DocsWithFieldSet writeByteVectorData(
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = byteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator();
for (int docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) {
// write vector
byte[] binaryValue = byteVectorValues.vectorValue();
byte[] binaryValue = byteVectorValues.vectorValue(iter.index());
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
output.writeBytes(binaryValue, binaryValue.length);
docsWithField.add(docV);
docsWithField.add(docId);
}
return docsWithField;
}
@ -657,11 +654,10 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
ByteBuffer buffer =
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
.order(ByteOrder.LITTLE_ENDIAN);
for (int docV = floatVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = floatVectorValues.nextDoc()) {
KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator();
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
// write vector
float[] value = floatVectorValues.vectorValue();
float[] value = floatVectorValues.vectorValue(iter.index());
buffer.asFloatBuffer().put(value);
output.writeBytes(buffer.array(), buffer.limit());
docsWithField.add(docV);
@ -718,11 +714,11 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
case BYTE ->
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromBytes((List<byte[]>) vectors, dim));
ByteVectorValues.fromBytes((List<byte[]>) vectors, dim));
case FLOAT32 ->
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromFloats((List<float[]>) vectors, dim));
FloatVectorValues.fromFloats((List<float[]>) vectors, dim));
};
hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);

View File

@ -52,6 +52,7 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.LogByteSizeMergePolicy;
import org.apache.lucene.index.MultiBits;
@ -477,10 +478,14 @@ public class TestBasicBackwardsCompatibility extends BackwardsCompatibilityTestB
FloatVectorValues values = ctx.reader().getFloatVectorValues(KNN_VECTOR_FIELD);
if (values != null) {
assertEquals(KNN_VECTOR_FIELD_TYPE.vectorDimension(), values.dimension());
for (int doc = values.nextDoc(); doc != NO_MORE_DOCS; doc = values.nextDoc()) {
KnnVectorValues.DocIndexIterator it = values.iterator();
for (int doc = it.nextDoc(); doc != NO_MORE_DOCS; doc = it.nextDoc()) {
float[] expectedVector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * cnt};
assertArrayEquals(
"vectors do not match for doc=" + cnt, expectedVector, values.vectorValue(), 0);
"vectors do not match for doc=" + cnt,
expectedVector,
values.vectorValue(it.index()),
0);
cnt++;
}
}

View File

@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
@ -32,7 +33,6 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.openjdk.jmh.annotations.*;
@ -55,7 +55,7 @@ public class VectorScorerBenchmark {
Directory dir;
IndexInput in;
RandomAccessVectorValues vectorValues;
KnnVectorValues vectorValues;
byte[] vec1, vec2;
RandomVectorScorer scorer;
@ -95,7 +95,7 @@ public class VectorScorerBenchmark {
return scorer.score(1);
}
static RandomAccessVectorValues vectorValues(
static KnnVectorValues vectorValues(
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
dims, size, in.slice("test", 0, in.length()), dims, new ThrowingFlatVectorScorer(), sim);
@ -105,23 +105,19 @@ public class VectorScorerBenchmark {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) {
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) {
throw new UnsupportedOperationException();
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target) {
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) {
throw new UnsupportedOperationException();
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target) {
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) {
throw new UnsupportedOperationException();
}
}

View File

@ -19,10 +19,11 @@ package org.apache.lucene.codecs.bitvectors;
import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
@ -30,45 +31,39 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
public class FlatBitVectorsScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
throws IOException {
assert vectorValues instanceof RandomAccessVectorValues.Bytes;
if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) {
assert vectorValues instanceof ByteVectorValues;
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
return new BitRandomVectorScorerSupplier(byteVectorValues);
}
throw new IllegalArgumentException(
"vectorValues must be an instance of RandomAccessVectorValues.Bytes");
throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues");
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
throws IOException {
throw new IllegalArgumentException("bit vectors do not support float[] targets");
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
throws IOException {
assert vectorValues instanceof RandomAccessVectorValues.Bytes;
if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) {
assert vectorValues instanceof ByteVectorValues;
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
return new BitRandomVectorScorer(byteVectorValues, target);
}
throw new IllegalArgumentException(
"vectorValues must be an instance of RandomAccessVectorValues.Bytes");
throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues");
}
static class BitRandomVectorScorer implements RandomVectorScorer {
private final RandomAccessVectorValues.Bytes vectorValues;
private final ByteVectorValues vectorValues;
private final int bitDimensions;
private final byte[] query;
BitRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) {
BitRandomVectorScorer(ByteVectorValues vectorValues, byte[] query) {
this.query = query;
this.bitDimensions = vectorValues.dimension() * Byte.SIZE;
this.vectorValues = vectorValues;
@ -97,12 +92,11 @@ public class FlatBitVectorsScorer implements FlatVectorsScorer {
}
static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
protected final RandomAccessVectorValues.Bytes vectorValues;
protected final RandomAccessVectorValues.Bytes vectorValues1;
protected final RandomAccessVectorValues.Bytes vectorValues2;
protected final ByteVectorValues vectorValues;
protected final ByteVectorValues vectorValues1;
protected final ByteVectorValues vectorValues2;
public BitRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues)
throws IOException {
public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException {
this.vectorValues = vectorValues;
this.vectorValues1 = vectorValues.copy();
this.vectorValues2 = vectorValues.copy();

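With RandomAccessVectorValues.Bytes/Floats gone, FlatVectorsScorer entry points accept KnnVectorValues and narrow to the concrete encoding themselves, as FlatBitVectorsScorer does above. A hedged sketch of just that narrowing step; the helper class and method are illustrative:

import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;

final class NarrowingSketch {
  // Callers pass whatever KnnVectorValues they have; byte-only scorers reject anything else.
  static ByteVectorValues requireBytes(KnnVectorValues vectorValues) {
    if (vectorValues instanceof ByteVectorValues byteVectorValues) {
      return byteVectorValues;
    }
    throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues");
  }
}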
View File

@ -192,8 +192,8 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
FieldInfo info = readState.fieldInfos.fieldInfo(field);
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
int doc;
while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
for (int ord = 0; ord < values.size(); ord++) {
int doc = values.ordToDoc(ord);
if (acceptDocs != null && acceptDocs.get(doc) == false) {
continue;
}
@ -202,7 +202,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
break;
}
float[] vector = values.vectorValue();
float[] vector = values.vectorValue(ord);
float score = vectorSimilarity.compare(vector, target);
knnCollector.collect(doc, score);
knnCollector.incVisitedCount(1);
@ -223,8 +223,8 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
FieldInfo info = readState.fieldInfos.fieldInfo(field);
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
int doc;
while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
for (int ord = 0; ord < values.size(); ord++) {
int doc = values.ordToDoc(ord);
if (acceptDocs != null && acceptDocs.get(doc) == false) {
continue;
}
@ -233,7 +233,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
break;
}
byte[] vector = values.vectorValue();
byte[] vector = values.vectorValue(ord);
float score = vectorSimilarity.compare(vector, target);
knnCollector.collect(doc, score);
knnCollector.incVisitedCount(1);
@ -327,35 +327,18 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
@Override
public float[] vectorValue() {
return values[curOrd];
public float[] vectorValue(int ord) {
return values[ord];
}
@Override
public int docID() {
if (curOrd == -1) {
return -1;
} else if (curOrd >= entry.size()) {
// when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID
// immediately afterward should also return NO_MORE_DOCS
// this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case
return NO_MORE_DOCS;
}
return entry.ordToDoc[curOrd];
public int ordToDoc(int ord) {
return entry.ordToDoc[ord];
}
@Override
public int nextDoc() throws IOException {
if (++curOrd < entry.size()) {
return docID();
}
return NO_MORE_DOCS;
}
@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
public DocIndexIterator iterator() {
return createSparseIterator();
}
@Override
@ -365,17 +348,19 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
SimpleTextFloatVectorValues simpleTextFloatVectorValues =
new SimpleTextFloatVectorValues(this);
DocIndexIterator iterator = simpleTextFloatVectorValues.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
int ord = iterator.index();
return entry
.similarityFunction()
.compare(simpleTextFloatVectorValues.vectorValue(), target);
.compare(simpleTextFloatVectorValues.vectorValue(ord), target);
}
@Override
public DocIdSetIterator iterator() {
return simpleTextFloatVectorValues;
return iterator;
}
};
}
@ -397,6 +382,11 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
value[i] = Float.parseFloat(floatStrings[i]);
}
}
@Override
public SimpleTextFloatVectorValues copy() {
return this;
}
}
private static class SimpleTextByteVectorValues extends ByteVectorValues {
@ -439,36 +429,14 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
@Override
public byte[] vectorValue() {
binaryValue.bytes = values[curOrd];
public byte[] vectorValue(int ord) {
binaryValue.bytes = values[ord];
return binaryValue.bytes;
}
@Override
public int docID() {
if (curOrd == -1) {
return -1;
} else if (curOrd >= entry.size()) {
// when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID
// immediately afterward should also return NO_MORE_DOCS
// this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case
return NO_MORE_DOCS;
}
return entry.ordToDoc[curOrd];
}
@Override
public int nextDoc() throws IOException {
if (++curOrd < entry.size()) {
return docID();
}
return NO_MORE_DOCS;
}
@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
public DocIndexIterator iterator() {
return createSparseIterator();
}
@Override
@ -478,16 +446,19 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
SimpleTextByteVectorValues simpleTextByteVectorValues = new SimpleTextByteVectorValues(this);
return new VectorScorer() {
DocIndexIterator it = simpleTextByteVectorValues.iterator();
@Override
public float score() throws IOException {
int ord = it.index();
return entry
.similarityFunction()
.compare(simpleTextByteVectorValues.vectorValue(), target);
.compare(simpleTextByteVectorValues.vectorValue(ord), target);
}
@Override
public DocIdSetIterator iterator() {
return simpleTextByteVectorValues;
return it;
}
};
}
@ -509,6 +480,11 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
value[i] = (byte) Float.parseFloat(floatStrings[i]);
}
}
@Override
public SimpleTextByteVectorValues copy() {
return this;
}
}
private int readInt(IndexInput in, BytesRef field) throws IOException {

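The SimpleText reader's exhaustive search now walks ordinals instead of driving the values as a DocIdSetIterator: every ordinal in [0, size()) maps to a doc via ordToDoc(ord), and the vector is read with vectorValue(ord). A hedged sketch of that loop on its own; the wrapper class is illustrative, and the early-termination guard is assumed to be the condition behind the break in the hunk above:

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Bits;

final class ExhaustiveSearchSketch {
  static void searchExhaustively(
      FloatVectorValues values,
      VectorSimilarityFunction similarity,
      float[] target,
      KnnCollector knnCollector,
      Bits acceptDocs)
      throws IOException {
    for (int ord = 0; ord < values.size(); ord++) {
      int doc = values.ordToDoc(ord); // ordinal -> docID mapping
      if (acceptDocs != null && acceptDocs.get(doc) == false) {
        continue; // filtered out by deletions or the query filter
      }
      if (knnCollector.earlyTerminated()) {
        break;
      }
      float score = similarity.compare(values.vectorValue(ord), target);
      knnCollector.collect(doc, score);
      knnCollector.incVisitedCount(1);
    }
  }
}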
View File

@ -28,6 +28,7 @@ import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
@ -77,19 +78,18 @@ public class SimpleTextKnnVectorsWriter extends BufferingKnnVectorsWriter {
throws IOException {
long vectorDataOffset = vectorData.getFilePointer();
List<Integer> docIds = new ArrayList<>();
for (int docV = floatVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = floatVectorValues.nextDoc()) {
writeFloatVectorValue(floatVectorValues);
docIds.add(docV);
KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator();
for (int docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) {
writeFloatVectorValue(floatVectorValues, iter.index());
docIds.add(docId);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds);
}
private void writeFloatVectorValue(FloatVectorValues vectors) throws IOException {
private void writeFloatVectorValue(FloatVectorValues vectors, int ord) throws IOException {
// write vector value
float[] value = vectors.vectorValue();
float[] value = vectors.vectorValue(ord);
assert value.length == vectors.dimension();
write(vectorData, Arrays.toString(value));
newline(vectorData);
@ -100,19 +100,18 @@ public class SimpleTextKnnVectorsWriter extends BufferingKnnVectorsWriter {
throws IOException {
long vectorDataOffset = vectorData.getFilePointer();
List<Integer> docIds = new ArrayList<>();
for (int docV = byteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
writeByteVectorValue(byteVectorValues);
KnnVectorValues.DocIndexIterator it = byteVectorValues.iterator();
for (int docV = it.nextDoc(); docV != NO_MORE_DOCS; docV = it.nextDoc()) {
writeByteVectorValue(byteVectorValues, it.index());
docIds.add(docV);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds);
}
private void writeByteVectorValue(ByteVectorValues vectors) throws IOException {
private void writeByteVectorValue(ByteVectorValues vectors, int ord) throws IOException {
// write vector value
byte[] value = vectors.vectorValue();
byte[] value = vectors.vectorValue(ord);
assert value.length == vectors.dimension();
write(vectorData, Arrays.toString(value));
newline(vectorData);
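
Editor's note: the writer hunks above show the central idiom of this patch: consumers obtain a DocIndexIterator and pass iter.index() (the vector ordinal) to vectorValue, instead of reading a value positioned implicitly by nextDoc(). A minimal sketch of that idiom for float vectors; the consume() callback is a hypothetical stand-in for the writer logic:

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;

final class IterationExample {
  static void visitAll(FloatVectorValues values) throws IOException {
    KnnVectorValues.DocIndexIterator iter = values.iterator();
    for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) {
      float[] vector = values.vectorValue(iter.index()); // ordinal-based random access
      consume(doc, vector);
    }
  }

  private static void consume(int doc, float[] vector) {
    // placeholder for per-document handling (e.g. writing the vector)
  }
}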

View File

@ -20,14 +20,16 @@ package org.apache.lucene.codecs;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.index.SortingCodecReader;
import org.apache.lucene.index.SortingCodecReader.SortingValuesIterator;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.RamUsageEstimator;
@ -80,24 +82,26 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
case FLOAT32:
BufferedFloatVectorValues bufferedFloatVectorValues =
new BufferedFloatVectorValues(
fieldData.docsWithField,
(List<float[]>) fieldData.vectors,
fieldData.fieldInfo.getVectorDimension());
fieldData.fieldInfo.getVectorDimension(),
fieldData.docsWithField);
FloatVectorValues floatVectorValues =
sortMap != null
? new SortingFloatVectorValues(bufferedFloatVectorValues, sortMap)
? new SortingFloatVectorValues(
bufferedFloatVectorValues, fieldData.docsWithField, sortMap)
: bufferedFloatVectorValues;
writeField(fieldData.fieldInfo, floatVectorValues, maxDoc);
break;
case BYTE:
BufferedByteVectorValues bufferedByteVectorValues =
new BufferedByteVectorValues(
fieldData.docsWithField,
(List<byte[]>) fieldData.vectors,
fieldData.fieldInfo.getVectorDimension());
fieldData.fieldInfo.getVectorDimension(),
fieldData.docsWithField);
ByteVectorValues byteVectorValues =
sortMap != null
? new SortingByteVectorValues(bufferedByteVectorValues, sortMap)
? new SortingByteVectorValues(
bufferedByteVectorValues, fieldData.docsWithField, sortMap)
: bufferedByteVectorValues;
writeField(fieldData.fieldInfo, byteVectorValues, maxDoc);
break;
@ -107,125 +111,77 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
private static class SortingFloatVectorValues extends FloatVectorValues {
private final BufferedFloatVectorValues randomAccess;
private final int[] docIdOffsets;
private int docId = -1;
private final BufferedFloatVectorValues delegate;
private final Supplier<SortingValuesIterator> iteratorSupplier;
SortingFloatVectorValues(BufferedFloatVectorValues delegate, Sorter.DocMap sortMap)
SortingFloatVectorValues(
BufferedFloatVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap)
throws IOException {
this.randomAccess = delegate.copy();
this.docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
int docID;
while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
this.delegate = delegate.copy();
iteratorSupplier = SortingCodecReader.iteratorSupplier(delegate, sortMap);
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
while (docId < docIdOffsets.length - 1) {
++docId;
if (docIdOffsets[docId] != 0) {
return docId;
}
}
docId = NO_MORE_DOCS;
return docId;
}
@Override
public float[] vectorValue() throws IOException {
return randomAccess.vectorValue(docIdOffsets[docId] - 1);
public float[] vectorValue(int ord) throws IOException {
return delegate.vectorValue(ord);
}
@Override
public int dimension() {
return randomAccess.dimension();
return delegate.dimension();
}
@Override
public int size() {
return randomAccess.size();
return delegate.size();
}
@Override
public int advance(int target) throws IOException {
public SortingFloatVectorValues copy() {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
public DocIndexIterator iterator() {
return iteratorSupplier.get();
}
}
/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
/** Sorting ByteVectorValues that iterate over documents in the order of the provided sortMap */
private static class SortingByteVectorValues extends ByteVectorValues {
private final BufferedByteVectorValues randomAccess;
private final int[] docIdOffsets;
private int docId = -1;
private final BufferedByteVectorValues delegate;
private final Supplier<SortingValuesIterator> iteratorSupplier;
SortingByteVectorValues(BufferedByteVectorValues delegate, Sorter.DocMap sortMap)
SortingByteVectorValues(
BufferedByteVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap)
throws IOException {
this.randomAccess = delegate.copy();
this.docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
int docID;
while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
this.delegate = delegate;
iteratorSupplier = SortingCodecReader.iteratorSupplier(delegate, sortMap);
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
while (docId < docIdOffsets.length - 1) {
++docId;
if (docIdOffsets[docId] != 0) {
return docId;
}
}
docId = NO_MORE_DOCS;
return docId;
}
@Override
public byte[] vectorValue() throws IOException {
return randomAccess.vectorValue(docIdOffsets[docId] - 1);
public byte[] vectorValue(int ord) throws IOException {
return delegate.vectorValue(ord);
}
@Override
public int dimension() {
return randomAccess.dimension();
return delegate.dimension();
}
@Override
public int size() {
return randomAccess.size();
return delegate.size();
}
@Override
public int advance(int target) throws IOException {
public SortingByteVectorValues copy() {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(byte[] target) {
throw new UnsupportedOperationException();
public DocIndexIterator iterator() {
return iteratorSupplier.get();
}
}
@ -296,7 +252,9 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
@Override
public final long ramBytesUsed() {
if (vectors.size() == 0) return 0;
if (vectors.isEmpty()) {
return 0;
}
return docsWithField.ramBytesUsed()
+ vectors.size()
* (long)
@ -307,25 +265,18 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
}
private static class BufferedFloatVectorValues extends FloatVectorValues {
final DocsWithFieldSet docsWithField;
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
final List<float[]> vectors;
final int dimension;
private final DocIdSet docsWithField;
private final DocIndexIterator iterator;
DocIdSetIterator docsWithFieldIter;
int ord = -1;
BufferedFloatVectorValues(
DocsWithFieldSet docsWithField, List<float[]> vectors, int dimension) {
this.docsWithField = docsWithField;
BufferedFloatVectorValues(List<float[]> vectors, int dimension, DocIdSet docsWithField)
throws IOException {
this.vectors = vectors;
this.dimension = dimension;
docsWithFieldIter = docsWithField.iterator();
}
public BufferedFloatVectorValues copy() {
return new BufferedFloatVectorValues(docsWithField, vectors, dimension);
this.docsWithField = docsWithField;
this.iterator = fromDISI(docsWithField.iterator());
}
@Override
@ -339,58 +290,39 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
}
@Override
public float[] vectorValue() {
return vectors.get(ord);
public int ordToDoc(int ord) {
return ord;
}
float[] vectorValue(int targetOrd) {
@Override
public float[] vectorValue(int targetOrd) {
return vectors.get(targetOrd);
}
@Override
public int docID() {
return docsWithFieldIter.docID();
public DocIndexIterator iterator() {
return iterator;
}
@Override
public int nextDoc() throws IOException {
int docID = docsWithFieldIter.nextDoc();
if (docID != NO_MORE_DOCS) {
++ord;
}
return docID;
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
public BufferedFloatVectorValues copy() throws IOException {
return new BufferedFloatVectorValues(vectors, dimension, docsWithField);
}
}
private static class BufferedByteVectorValues extends ByteVectorValues {
final DocsWithFieldSet docsWithField;
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
final List<byte[]> vectors;
final int dimension;
private final DocIdSet docsWithField;
private final DocIndexIterator iterator;
DocIdSetIterator docsWithFieldIter;
int ord = -1;
BufferedByteVectorValues(DocsWithFieldSet docsWithField, List<byte[]> vectors, int dimension) {
this.docsWithField = docsWithField;
BufferedByteVectorValues(List<byte[]> vectors, int dimension, DocIdSet docsWithField)
throws IOException {
this.vectors = vectors;
this.dimension = dimension;
docsWithFieldIter = docsWithField.iterator();
}
public BufferedByteVectorValues copy() {
return new BufferedByteVectorValues(docsWithField, vectors, dimension);
this.docsWithField = docsWithField;
iterator = fromDISI(docsWithField.iterator());
}
@Override
@ -404,36 +336,18 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
}
@Override
public byte[] vectorValue() {
return vectors.get(ord);
}
byte[] vectorValue(int targetOrd) {
public byte[] vectorValue(int targetOrd) {
return vectors.get(targetOrd);
}
@Override
public int docID() {
return docsWithFieldIter.docID();
public DocIndexIterator iterator() {
return iterator;
}
@Override
public int nextDoc() throws IOException {
int docID = docsWithFieldIter.nextDoc();
if (docID != NO_MORE_DOCS) {
++ord;
}
return docID;
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(byte[] target) {
throw new UnsupportedOperationException();
public BufferedByteVectorValues copy() throws IOException {
return new BufferedByteVectorValues(vectors, dimension, docsWithField);
}
}
}

View File

@ -30,6 +30,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
@ -55,28 +56,26 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
@SuppressWarnings("unchecked")
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
switch (fieldInfo.getVectorEncoding()) {
case BYTE:
case BYTE -> {
KnnFieldVectorsWriter<byte[]> byteWriter =
(KnnFieldVectorsWriter<byte[]>) addField(fieldInfo);
ByteVectorValues mergedBytes =
MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
for (int doc = mergedBytes.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = mergedBytes.nextDoc()) {
byteWriter.addValue(doc, mergedBytes.vectorValue());
KnnVectorValues.DocIndexIterator iter = mergedBytes.iterator();
for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) {
byteWriter.addValue(doc, mergedBytes.vectorValue(iter.index()));
}
break;
case FLOAT32:
}
case FLOAT32 -> {
KnnFieldVectorsWriter<float[]> floatWriter =
(KnnFieldVectorsWriter<float[]>) addField(fieldInfo);
FloatVectorValues mergedFloats =
MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
for (int doc = mergedFloats.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = mergedFloats.nextDoc()) {
floatWriter.addValue(doc, mergedFloats.vectorValue());
KnnVectorValues.DocIndexIterator iter = mergedFloats.iterator();
for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) {
floatWriter.addValue(doc, mergedFloats.vectorValue(iter.index()));
}
}
break;
}
}
@ -117,32 +116,44 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
private static class FloatVectorValuesSub extends DocIDMerger.Sub {
final FloatVectorValues values;
final KnnVectorValues.DocIndexIterator iterator;
FloatVectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) {
super(docMap);
this.values = values;
assert values.docID() == -1;
this.iterator = values.iterator();
assert iterator.docID() == -1;
}
@Override
public int nextDoc() throws IOException {
return values.nextDoc();
return iterator.nextDoc();
}
public int index() {
return iterator.index();
}
}
private static class ByteVectorValuesSub extends DocIDMerger.Sub {
final ByteVectorValues values;
final KnnVectorValues.DocIndexIterator iterator;
ByteVectorValuesSub(MergeState.DocMap docMap, ByteVectorValues values) {
super(docMap);
this.values = values;
assert values.docID() == -1;
iterator = values.iterator();
assert iterator.docID() == -1;
}
@Override
public int nextDoc() throws IOException {
return values.nextDoc();
return iterator.nextDoc();
}
int index() {
return iterator.index();
}
}
@ -287,7 +298,8 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
private final List<FloatVectorValuesSub> subs;
private final DocIDMerger<FloatVectorValuesSub> docIdMerger;
private final int size;
private int docId;
private int docId = -1;
private int lastOrd = -1;
FloatVectorValuesSub current;
private MergedFloat32VectorValues(List<FloatVectorValuesSub> subs, MergeState mergeState)
@ -299,33 +311,57 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
totalSize += sub.values.size();
}
size = totalSize;
docId = -1;
}
@Override
public DocIndexIterator iterator() {
return new DocIndexIterator() {
private int index = -1;
@Override
public int docID() {
return docId;
}
@Override
public int index() {
return index;
}
@Override
public int nextDoc() throws IOException {
current = docIdMerger.next();
if (current == null) {
docId = NO_MORE_DOCS;
index = NO_MORE_DOCS;
} else {
docId = current.mappedDocID;
++index;
}
return docId;
}
@Override
public float[] vectorValue() throws IOException {
return current.values.vectorValue();
public int advance(int target) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
public long cost() {
return size;
}
};
}
@Override
public float[] vectorValue(int ord) throws IOException {
if (ord != lastOrd + 1) {
throw new IllegalStateException(
"only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd);
} else {
lastOrd = ord;
}
return current.values.vectorValue(current.index());
}
@Override
@ -338,10 +374,20 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
return subs.get(0).values.dimension();
}
@Override
public int ordToDoc(int ord) {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
@Override
public FloatVectorValues copy() {
throw new UnsupportedOperationException();
}
}
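
Editor's note: the lastOrd bookkeeping above means merged values are not truly random access: vectorValue(ord) only accepts lastOrd + 1, i.e. ordinals requested in lock-step with the merged iterator. A sketch of what that contract permits, assuming mergedFloats came from MergedVectorValues.mergeFloatVectorValues:

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;

final class MergedAccessExample {
  static void forwardOnly(FloatVectorValues mergedFloats) throws IOException {
    KnnVectorValues.DocIndexIterator it = mergedFloats.iterator();
    while (it.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
      // Allowed: ordinal requested in lock-step with the iterator (ord == lastOrd + 1).
      float[] vector = mergedFloats.vectorValue(it.index());
    }
    // Revisiting or skipping an ordinal throws IllegalStateException
    // ("only supports forward iteration"), so merged values cannot stand in
    // where genuine random access is required.
  }
}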
static class MergedByteVectorValues extends ByteVectorValues {
@ -349,7 +395,8 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
private final DocIDMerger<ByteVectorValuesSub> docIdMerger;
private final int size;
private int docId;
private int lastOrd = -1;
private int docId = -1;
ByteVectorValuesSub current;
private MergedByteVectorValues(List<ByteVectorValuesSub> subs, MergeState mergeState)
@ -361,35 +408,59 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
totalSize += sub.values.size();
}
size = totalSize;
docId = -1;
}
@Override
public byte[] vectorValue() throws IOException {
return current.values.vectorValue();
public byte[] vectorValue(int ord) throws IOException {
if (ord != lastOrd + 1) {
throw new IllegalStateException(
"only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd);
} else {
lastOrd = ord;
}
return current.values.vectorValue(current.index());
}
@Override
public DocIndexIterator iterator() {
return new DocIndexIterator() {
private int index = -1;
@Override
public int docID() {
return docId;
}
@Override
public int index() {
return index;
}
@Override
public int nextDoc() throws IOException {
current = docIdMerger.next();
if (current == null) {
docId = NO_MORE_DOCS;
index = NO_MORE_DOCS;
} else {
docId = current.mappedDocID;
++index;
}
return docId;
}
@Override
public int advance(int target) {
public int advance(int target) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public long cost() {
return size;
}
};
}
@Override
public int size() {
return size;
@ -400,10 +471,20 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
return subs.get(0).values.dimension();
}
@Override
public int ordToDoc(int ord) {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(byte[] target) {
throw new UnsupportedOperationException();
}
@Override
public ByteVectorValues copy() {
throw new UnsupportedOperationException();
}
}
}
}

View File

@ -18,8 +18,10 @@
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
@ -34,24 +36,26 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
throws IOException {
if (vectorValues instanceof RandomAccessVectorValues.Floats floatVectorValues) {
return new FloatScoringSupplier(floatVectorValues, similarityFunction);
} else if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) {
return new ByteScoringSupplier(byteVectorValues, similarityFunction);
switch (vectorValues.getEncoding()) {
case FLOAT32 -> {
return new FloatScoringSupplier((FloatVectorValues) vectorValues, similarityFunction);
}
case BYTE -> {
return new ByteScoringSupplier((ByteVectorValues) vectorValues, similarityFunction);
}
}
throw new IllegalArgumentException(
"vectorValues must be an instance of RandomAccessVectorValues.Floats or RandomAccessVectorValues.Bytes");
"vectorValues must be an instance of FloatVectorValues or ByteVectorValues, got a "
+ vectorValues.getClass().getName());
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
throws IOException {
assert vectorValues instanceof RandomAccessVectorValues.Floats;
assert vectorValues instanceof FloatVectorValues;
if (target.length != vectorValues.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
@ -59,17 +63,14 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
+ " differs from field dimension: "
+ vectorValues.dimension());
}
return new FloatVectorScorer(
(RandomAccessVectorValues.Floats) vectorValues, target, similarityFunction);
return new FloatVectorScorer((FloatVectorValues) vectorValues, target, similarityFunction);
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
throws IOException {
assert vectorValues instanceof RandomAccessVectorValues.Bytes;
assert vectorValues instanceof ByteVectorValues;
if (target.length != vectorValues.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
@ -77,8 +78,7 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
+ " differs from field dimension: "
+ vectorValues.dimension());
}
return new ByteVectorScorer(
(RandomAccessVectorValues.Bytes) vectorValues, target, similarityFunction);
return new ByteVectorScorer((ByteVectorValues) vectorValues, target, similarityFunction);
}
@Override
@ -88,14 +88,13 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
/** RandomVectorScorerSupplier for bytes vector */
private static final class ByteScoringSupplier implements RandomVectorScorerSupplier {
private final RandomAccessVectorValues.Bytes vectors;
private final RandomAccessVectorValues.Bytes vectors1;
private final RandomAccessVectorValues.Bytes vectors2;
private final ByteVectorValues vectors;
private final ByteVectorValues vectors1;
private final ByteVectorValues vectors2;
private final VectorSimilarityFunction similarityFunction;
private ByteScoringSupplier(
RandomAccessVectorValues.Bytes vectors, VectorSimilarityFunction similarityFunction)
throws IOException {
ByteVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
@ -125,14 +124,13 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
/** RandomVectorScorerSupplier for Float vector */
private static final class FloatScoringSupplier implements RandomVectorScorerSupplier {
private final RandomAccessVectorValues.Floats vectors;
private final RandomAccessVectorValues.Floats vectors1;
private final RandomAccessVectorValues.Floats vectors2;
private final FloatVectorValues vectors;
private final FloatVectorValues vectors1;
private final FloatVectorValues vectors2;
private final VectorSimilarityFunction similarityFunction;
private FloatScoringSupplier(
RandomAccessVectorValues.Floats vectors, VectorSimilarityFunction similarityFunction)
throws IOException {
FloatVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
@ -162,14 +160,12 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
/** A {@link RandomVectorScorer} for float vectors. */
private static class FloatVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
private final RandomAccessVectorValues.Floats values;
private final FloatVectorValues values;
private final float[] query;
private final VectorSimilarityFunction similarityFunction;
public FloatVectorScorer(
RandomAccessVectorValues.Floats values,
float[] query,
VectorSimilarityFunction similarityFunction) {
FloatVectorValues values, float[] query, VectorSimilarityFunction similarityFunction) {
super(values);
this.values = values;
this.query = query;
@ -184,14 +180,12 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
/** A {@link RandomVectorScorer} for byte vectors. */
private static class ByteVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
private final RandomAccessVectorValues.Bytes values;
private final ByteVectorValues values;
private final byte[] query;
private final VectorSimilarityFunction similarityFunction;
public ByteVectorScorer(
RandomAccessVectorValues.Bytes values,
byte[] query,
VectorSimilarityFunction similarityFunction) {
ByteVectorValues values, byte[] query, VectorSimilarityFunction similarityFunction) {
super(values);
this.values = values;
this.query = query;

View File

@ -18,8 +18,8 @@
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
@ -40,7 +40,19 @@ public interface FlatVectorsScorer {
* @throws IOException if an I/O error occurs
*/
RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException;
/**
* Returns a {@link RandomVectorScorer} for the given set of vectors and target vector.
*
* @param similarityFunction the similarity function to use
* @param vectorValues the vector values to score
* @param target the target vector
* @return a {@link RandomVectorScorer} for the given field and target vector.
* @throws IOException if an I/O error occurs when reading from the index.
*/
RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
throws IOException;
/**
@ -53,23 +65,6 @@ public interface FlatVectorsScorer {
* @throws IOException if an I/O error occurs when reading from the index.
*/
RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target)
throws IOException;
/**
* Returns a {@link RandomVectorScorer} for the given set of vectors and target vector.
*
* @param similarityFunction the similarity function to use
* @param vectorValues the vector values to score
* @param target the target vector
* @return a {@link RandomVectorScorer} for the given field and target vector.
* @throws IOException if an I/O error occurs when reading from the index.
*/
RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
throws IOException;
}
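
Editor's note: the interface now accepts KnnVectorValues directly instead of RandomAccessVectorValues. A hedged sketch of obtaining an ordinal-addressed scorer from a FlatVectorsScorer implementation such as DefaultFlatVectorScorer; the surrounding method is hypothetical and assumes the values are non-empty:

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

final class FlatScorerExample {
  static float scoreFirstVector(
      FlatVectorsScorer flatScorer, FloatVectorValues vectors, float[] query) throws IOException {
    RandomVectorScorer scorer =
        flatScorer.getRandomVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, vectors, query);
    return scorer.score(0); // score the query against vector ordinal 0 (assumes size() > 0)
  }
}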

View File

@ -18,13 +18,13 @@
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.apache.lucene.util.quantization.ScalarQuantizer;
@ -60,9 +60,9 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
throws IOException {
if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) {
if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) {
return new ScalarQuantizedRandomVectorScorerSupplier(
similarityFunction,
quantizedByteVectorValues.getScalarQuantizer(),
@ -74,11 +74,9 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
throws IOException {
if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) {
if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) {
ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer();
byte[] targetBytes = new byte[target.length];
float offsetCorrection =
@ -104,9 +102,7 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
throws IOException {
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}
@ -124,14 +120,14 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
public static class ScalarQuantizedRandomVectorScorerSupplier
implements RandomVectorScorerSupplier {
private final RandomAccessQuantizedByteVectorValues values;
private final QuantizedByteVectorValues values;
private final ScalarQuantizedVectorSimilarity similarity;
private final VectorSimilarityFunction vectorSimilarityFunction;
public ScalarQuantizedRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction,
ScalarQuantizer scalarQuantizer,
RandomAccessQuantizedByteVectorValues values) {
QuantizedByteVectorValues values) {
this.similarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction,
@ -144,7 +140,7 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
private ScalarQuantizedRandomVectorScorerSupplier(
ScalarQuantizedVectorSimilarity similarity,
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessQuantizedByteVectorValues values) {
QuantizedByteVectorValues values) {
this.similarity = similarity;
this.values = values;
this.vectorSimilarityFunction = vectorSimilarityFunction;
@ -152,7 +148,7 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();
final QuantizedByteVectorValues vectorsCopy = values.copy();
final byte[] queryVector = values.vectorValue(ord);
final float queryOffset = values.getScoreCorrectionConstant(ord);
return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) {

View File

@ -18,6 +18,7 @@ package org.apache.lucene.codecs.lucene90;
import java.io.DataInput;
import java.io.IOException;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
@ -439,6 +440,40 @@ public final class IndexedDISI extends DocIdSetIterator {
// ALL variables
int gap;
/**
* Returns an iterator that delegates to the IndexedDISI. Advancing this iterator will advance the
* underlying IndexedDISI, and vice-versa.
*/
public static KnnVectorValues.DocIndexIterator asDocIndexIterator(IndexedDISI disi) {
// can we replace with fromDISI?
return new KnnVectorValues.DocIndexIterator() {
@Override
public int docID() {
return disi.docID();
}
@Override
public int index() {
return disi.index();
}
@Override
public int nextDoc() throws IOException {
return disi.nextDoc();
}
@Override
public int advance(int target) throws IOException {
return disi.advance(target);
}
@Override
public long cost() {
return disi.cost();
}
};
}
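
Editor's note: a small sketch of how this adapter is meant to be consumed, given an already-constructed IndexedDISI (as the sparse off-heap vector readers in this patch do):

import java.io.IOException;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;

final class DisiAdapterExample {
  static void walk(IndexedDISI disi) throws IOException {
    KnnVectorValues.DocIndexIterator it = IndexedDISI.asDocIndexIterator(disi);
    while (it.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
      int doc = it.docID(); // same position as disi.docID()
      int ord = it.index(); // dense ordinal of the current vector, from disi.index()
      // use (doc, ord) ...
    }
  }
}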
@Override
public int docID() {
return doc;

View File

@ -14,23 +14,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.quantization;
package org.apache.lucene.codecs.lucene95;
import java.io.IOException;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.store.IndexInput;
/**
* Random access values for <code>byte[]</code>, but also includes accessing the score correction
* constant for the current vector in the buffer.
*
* @lucene.experimental
* Implementors can return the IndexInput from which their values are read. For use by vector
* quantizers.
*/
public interface RandomAccessQuantizedByteVectorValues extends RandomAccessVectorValues.Bytes {
public interface HasIndexSlice {
ScalarQuantizer getScalarQuantizer();
float getScoreCorrectionConstant(int vectorOrd) throws IOException;
@Override
RandomAccessQuantizedByteVectorValues copy() throws IOException;
/** Returns an IndexInput from which to read this instance's values. */
IndexInput getSlice();
}
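
Editor's note: a short sketch of the intended use, following the getSlice() != null checks elsewhere in this patch: code that needs raw stored bytes (e.g. a quantizer) can test whether a vector source exposes its backing IndexInput. The helper class and method are hypothetical:

import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.store.IndexInput;

final class SliceAccessExample {
  // Returns the backing slice if the values expose one, else null.
  static IndexInput sliceOrNull(ByteVectorValues values) {
    if (values instanceof HasIndexSlice withSlice && withSlice.getSlice() != null) {
      return withSlice.getSlice();
    }
    return null;
  }
}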

View File

@ -29,13 +29,11 @@ import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
public abstract class OffHeapByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues.Bytes {
public abstract class OffHeapByteVectorValues extends ByteVectorValues implements HasIndexSlice {
protected final int dimension;
protected final int size;
@ -132,9 +130,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
* vector.
*/
public static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
private int doc = -1;
public DenseOffHeapVectorValues(
int dimension,
int size,
@ -145,36 +140,17 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction);
}
@Override
public byte[] vectorValue() throws IOException {
return vectorValue(doc);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
if (target >= size) {
return doc = NO_MORE_DOCS;
}
return doc = target;
}
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(
dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
@ -183,17 +159,18 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public VectorScorer scorer(byte[] query) throws IOException {
DenseOffHeapVectorValues copy = copy();
DocIndexIterator iterator = copy.iterator();
RandomVectorScorer scorer =
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
return new VectorScorer() {
@Override
public float score() throws IOException {
return scorer.score(copy.doc);
return scorer.score(iterator.docID());
}
@Override
public DocIdSetIterator iterator() {
return copy;
return iterator;
}
};
}
@ -238,27 +215,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
configuration.size);
}
@Override
public byte[] vectorValue() throws IOException {
return vectorValue(disi.index());
}
@Override
public int docID() {
return disi.docID();
}
@Override
public int nextDoc() throws IOException {
return disi.nextDoc();
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
return disi.advance(target);
}
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
@ -276,6 +232,11 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
return (int) ordToDoc.get(ord);
}
@Override
public DocIndexIterator iterator() {
return IndexedDISI.asDocIndexIterator(disi);
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
if (acceptDocs == null) {
@ -307,7 +268,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public DocIdSetIterator iterator() {
return copy;
return copy.disi;
}
};
}
@ -322,8 +283,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
super(dimension, 0, null, 0, flatVectorsScorer, vectorSimilarityFunction);
}
private int doc = -1;
@Override
public int dimension() {
return super.dimension();
@ -335,23 +294,13 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
}
@Override
public byte[] vectorValue() throws IOException {
public byte[] vectorValue(int ord) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
return doc = NO_MORE_DOCS;
public DocIndexIterator iterator() {
return createDenseIterator();
}
@Override
@ -359,11 +308,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
throw new UnsupportedOperationException();
}
@Override
public byte[] vectorValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int ordToDoc(int ord) {
throw new UnsupportedOperationException();

View File

@ -28,13 +28,11 @@ import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
public abstract class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues.Floats {
public abstract class OffHeapFloatVectorValues extends FloatVectorValues implements HasIndexSlice {
protected final int dimension;
protected final int size;
@ -128,8 +126,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
*/
public static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
private int doc = -1;
public DenseOffHeapVectorValues(
int dimension,
int size,
@ -140,55 +136,42 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction);
}
@Override
public float[] vectorValue() throws IOException {
return vectorValue(doc);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
if (target >= size) {
return doc = NO_MORE_DOCS;
}
return doc = target;
}
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(
dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
}
@Override
public int ordToDoc(int ord) {
return ord;
}
@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
DenseOffHeapVectorValues copy = copy();
DocIndexIterator iterator = copy.iterator();
RandomVectorScorer randomVectorScorer =
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
return new VectorScorer() {
@Override
public float score() throws IOException {
return randomVectorScorer.score(copy.doc);
return randomVectorScorer.score(iterator.docID());
}
@Override
public DocIdSetIterator iterator() {
return copy;
return iterator;
}
};
}
@ -227,27 +210,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
configuration.size);
}
@Override
public float[] vectorValue() throws IOException {
return vectorValue(disi.index());
}
@Override
public int docID() {
return disi.docID();
}
@Override
public int nextDoc() throws IOException {
return disi.nextDoc();
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
return disi.advance(target);
}
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(
@ -283,20 +245,26 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
};
}
@Override
public DocIndexIterator iterator() {
return IndexedDISI.asDocIndexIterator(disi);
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
SparseOffHeapVectorValues copy = copy();
DocIndexIterator iterator = copy.iterator();
RandomVectorScorer randomVectorScorer =
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
return new VectorScorer() {
@Override
public float score() throws IOException {
return randomVectorScorer.score(copy.disi.index());
return randomVectorScorer.score(iterator.index());
}
@Override
public DocIdSetIterator iterator() {
return copy;
return iterator;
}
};
}
@ -311,8 +279,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction);
}
private int doc = -1;
@Override
public int dimension() {
return super.dimension();
@ -323,26 +289,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
return 0;
}
@Override
public float[] vectorValue() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) {
return doc = NO_MORE_DOCS;
}
@Override
public EmptyOffHeapVectorValues copy() {
throw new UnsupportedOperationException();
@ -354,8 +300,8 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
}
@Override
public int ordToDoc(int ord) {
throw new UnsupportedOperationException();
public DocIndexIterator iterator() {
return createDenseIterator();
}
@Override

View File

@ -39,6 +39,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
@ -361,11 +362,10 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
private static DocsWithFieldSet writeByteVectorData(
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = byteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator();
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
// write vector
byte[] binaryValue = byteVectorValues.vectorValue();
byte[] binaryValue = byteVectorValues.vectorValue(iter.index());
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
output.writeBytes(binaryValue, binaryValue.length);
docsWithField.add(docV);
@ -382,11 +382,10 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
ByteBuffer buffer =
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
.order(ByteOrder.LITTLE_ENDIAN);
for (int docV = floatVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = floatVectorValues.nextDoc()) {
KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator();
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
// write vector
float[] value = floatVectorValues.vectorValue();
float[] value = floatVectorValues.vectorValue(iter.index());
buffer.asFloatBuffer().put(value);
output.writeBytes(buffer.array(), buffer.limit());
docsWithField.add(docV);

View File

@ -32,14 +32,16 @@ import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
@ -54,7 +56,6 @@ import org.apache.lucene.util.hnsw.HnswGraphMerger;
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
@ -359,18 +360,18 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
}
}
DocIdSetIterator mergedVectorIterator = null;
KnnVectorValues mergedVectorValues = null;
switch (fieldInfo.getVectorEncoding()) {
case BYTE ->
mergedVectorIterator =
mergedVectorValues =
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
case FLOAT32 ->
mergedVectorIterator =
mergedVectorValues =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}
graph =
merger.merge(
mergedVectorIterator,
mergedVectorValues,
segmentWriteState.infoStream,
scorerSupplier.totalVectorCount());
vectorIndexNodeOffsets = writeGraph(graph);
@ -582,13 +583,13 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
case BYTE ->
scorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromBytes(
ByteVectorValues.fromBytes(
(List<byte[]>) flatFieldVectorsWriter.getVectors(),
fieldInfo.getVectorDimension()));
case FLOAT32 ->
scorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromFloats(
FloatVectorValues.fromFloats(
(List<float[]>) flatFieldVectorsWriter.getVectors(),
fieldInfo.getVectorDimension()));
};
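
Editor's note: FloatVectorValues.fromFloats (and the matching ByteVectorValues.fromBytes) replaces RandomAccessVectorValues.fromFloats as the way to wrap in-memory vectors. A hedged sketch of building a scorer supplier over a plain List<float[]>; the similarity function and parameters are placeholders:

import java.io.IOException;
import java.util.List;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

final class InMemoryScoringExample {
  static RandomVectorScorer scorerForOrd(
      FlatVectorsScorer flatScorer, List<float[]> rawVectors, int dimension, int ord)
      throws IOException {
    FloatVectorValues values = FloatVectorValues.fromFloats(rawVectors, dimension);
    RandomVectorScorerSupplier supplier =
        flatScorer.getRandomVectorScorerSupplier(VectorSimilarityFunction.EUCLIDEAN, values);
    return supplier.scorer(ord); // scorer centered on the vector at ordinal `ord`
  }
}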

View File

@ -21,12 +21,12 @@ import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantize
import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
/**
@ -45,9 +45,9 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
throws IOException {
if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) {
if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) {
return new ScalarQuantizedRandomVectorScorerSupplier(
quantizedByteVectorValues, similarityFunction);
}
@ -57,11 +57,9 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
throws IOException {
if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) {
if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) {
ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer();
byte[] targetBytes = new byte[target.length];
float offsetCorrection =
@ -79,9 +77,7 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target)
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
throws IOException {
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}
@ -96,7 +92,7 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
float offsetCorrection,
VectorSimilarityFunction sim,
float constMultiplier,
RandomAccessQuantizedByteVectorValues values) {
QuantizedByteVectorValues values) {
return switch (sim) {
case EUCLIDEAN -> new Euclidean(values, constMultiplier, targetBytes);
case COSINE, DOT_PRODUCT ->
@ -120,7 +116,7 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
byte[] targetBytes,
float offsetCorrection,
float constMultiplier,
RandomAccessQuantizedByteVectorValues values,
QuantizedByteVectorValues values,
FloatToFloatFunction scoreAdjustmentFunction) {
if (values.getScalarQuantizer().getBits() <= 4) {
if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) {
@ -137,10 +133,9 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
private static class Euclidean extends RandomVectorScorer.AbstractRandomVectorScorer {
private final float constMultiplier;
private final byte[] targetBytes;
private final RandomAccessQuantizedByteVectorValues values;
private final QuantizedByteVectorValues values;
private Euclidean(
RandomAccessQuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) {
private Euclidean(QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) {
super(values);
this.values = values;
this.constMultiplier = constMultiplier;
@ -159,13 +154,13 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
/** Calculates dot product on quantized vectors, applying the appropriate corrections */
private static class DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer {
private final float constMultiplier;
private final RandomAccessQuantizedByteVectorValues values;
private final QuantizedByteVectorValues values;
private final byte[] targetBytes;
private final float offsetCorrection;
private final FloatToFloatFunction scoreAdjustmentFunction;
public DotProduct(
RandomAccessQuantizedByteVectorValues values,
QuantizedByteVectorValues values,
float constMultiplier,
byte[] targetBytes,
float offsetCorrection,
@ -193,14 +188,14 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
private static class CompressedInt4DotProduct
extends RandomVectorScorer.AbstractRandomVectorScorer {
private final float constMultiplier;
private final RandomAccessQuantizedByteVectorValues values;
private final QuantizedByteVectorValues values;
private final byte[] compressedVector;
private final byte[] targetBytes;
private final float offsetCorrection;
private final FloatToFloatFunction scoreAdjustmentFunction;
private CompressedInt4DotProduct(
RandomAccessQuantizedByteVectorValues values,
QuantizedByteVectorValues values,
float constMultiplier,
byte[] targetBytes,
float offsetCorrection,
@ -231,13 +226,13 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
private static class Int4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer {
private final float constMultiplier;
private final RandomAccessQuantizedByteVectorValues values;
private final QuantizedByteVectorValues values;
private final byte[] targetBytes;
private final float offsetCorrection;
private final FloatToFloatFunction scoreAdjustmentFunction;
public Int4DotProduct(
RandomAccessQuantizedByteVectorValues values,
QuantizedByteVectorValues values,
float constMultiplier,
byte[] targetBytes,
float offsetCorrection,
@ -271,13 +266,12 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
implements RandomVectorScorerSupplier {
private final VectorSimilarityFunction vectorSimilarityFunction;
private final RandomAccessQuantizedByteVectorValues values;
private final RandomAccessQuantizedByteVectorValues values1;
private final RandomAccessQuantizedByteVectorValues values2;
private final QuantizedByteVectorValues values;
private final QuantizedByteVectorValues values1;
private final QuantizedByteVectorValues values2;
public ScalarQuantizedRandomVectorScorerSupplier(
RandomAccessQuantizedByteVectorValues values,
VectorSimilarityFunction vectorSimilarityFunction)
QuantizedByteVectorValues values, VectorSimilarityFunction vectorSimilarityFunction)
throws IOException {
this.values = values;
this.values1 = values.copy();

View File

@ -402,10 +402,10 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
private static final class QuantizedVectorValues extends FloatVectorValues {
private final FloatVectorValues rawVectorValues;
private final OffHeapQuantizedByteVectorValues quantizedVectorValues;
private final QuantizedByteVectorValues quantizedVectorValues;
QuantizedVectorValues(
FloatVectorValues rawVectorValues, OffHeapQuantizedByteVectorValues quantizedVectorValues) {
FloatVectorValues rawVectorValues, QuantizedByteVectorValues quantizedVectorValues) {
this.rawVectorValues = rawVectorValues;
this.quantizedVectorValues = quantizedVectorValues;
}
@ -421,34 +421,28 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
}
@Override
public float[] vectorValue() throws IOException {
return rawVectorValues.vectorValue();
public float[] vectorValue(int ord) throws IOException {
return rawVectorValues.vectorValue(ord);
}
@Override
public int docID() {
return rawVectorValues.docID();
public int ordToDoc(int ord) {
return rawVectorValues.ordToDoc(ord);
}
@Override
public int nextDoc() throws IOException {
int rawDocId = rawVectorValues.nextDoc();
int quantizedDocId = quantizedVectorValues.nextDoc();
assert rawDocId == quantizedDocId;
return quantizedDocId;
}
@Override
public int advance(int target) throws IOException {
int rawDocId = rawVectorValues.advance(target);
int quantizedDocId = quantizedVectorValues.advance(target);
assert rawDocId == quantizedDocId;
return quantizedDocId;
public QuantizedVectorValues copy() throws IOException {
return new QuantizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy());
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
return quantizedVectorValues.scorer(query);
}
@Override
public DocIndexIterator iterator() {
return rawVectorValues.iterator();
}
}
}

View File

@ -19,9 +19,7 @@ package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues;
import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.*;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
@ -45,6 +43,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
@ -653,12 +652,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|| bits <= 4
|| shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
int numVectors = 0;
FloatVectorValues vectorValues =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
DocIdSetIterator iter =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)
.iterator();
// iterate vectorValues and increment numVectors
for (int doc = vectorValues.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = vectorValues.nextDoc()) {
for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) {
numVectors++;
}
return buildScalarQuantizer(
@ -730,11 +728,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
? OffHeapQuantizedByteVectorValues.compressedArray(
quantizedByteVectorValues.dimension(), bits)
: null;
for (int docV = quantizedByteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = quantizedByteVectorValues.nextDoc()) {
KnnVectorValues.DocIndexIterator iter = quantizedByteVectorValues.iterator();
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
// write vector
byte[] binaryValue = quantizedByteVectorValues.vectorValue();
byte[] binaryValue = quantizedByteVectorValues.vectorValue(iter.index());
assert binaryValue.length == quantizedByteVectorValues.dimension()
: "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length;
if (compressedVector != null) {
@ -743,7 +740,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
} else {
output.writeBytes(binaryValue, binaryValue.length);
}
output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant()));
output.writeInt(
Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant(iter.index())));
docsWithField.add(docV);
}
return docsWithField;
@ -855,7 +853,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
static class FloatVectorWrapper extends FloatVectorValues {
private final List<float[]> vectorList;
protected int curDoc = -1;
FloatVectorWrapper(List<float[]> vectorList) {
this.vectorList = vectorList;
@ -872,51 +869,42 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
@Override
public float[] vectorValue() throws IOException {
if (curDoc == -1 || curDoc >= vectorList.size()) {
throw new IOException("Current doc not set or too many iterations");
}
return vectorList.get(curDoc);
public FloatVectorValues copy() throws IOException {
return this;
}
@Override
public int docID() {
if (curDoc >= vectorList.size()) {
return NO_MORE_DOCS;
public float[] vectorValue(int ord) throws IOException {
if (ord < 0 || ord >= vectorList.size()) {
throw new IOException("vector ord " + ord + " out of bounds");
}
return curDoc;
return vectorList.get(ord);
}
@Override
public int nextDoc() throws IOException {
curDoc++;
return docID();
}
@Override
public int advance(int target) throws IOException {
curDoc = target;
return docID();
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
public DocIndexIterator iterator() {
return createDenseIterator();
}
}
static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
private final QuantizedByteVectorValues values;
private final KnnVectorValues.DocIndexIterator iterator;
QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) {
super(docMap);
this.values = values;
assert values.docID() == -1;
iterator = values.iterator();
assert iterator.docID() == -1;
}
@Override
public int nextDoc() throws IOException {
return values.nextDoc();
return iterator.nextDoc();
}
public int index() {
return iterator.index();
}
}
@ -973,7 +961,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private final DocIDMerger<QuantizedByteVectorValueSub> docIdMerger;
private final int size;
private int docId;
private QuantizedByteVectorValueSub current;
private MergedQuantizedVectorValues(
@ -985,33 +972,16 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
totalSize += sub.values.size();
}
size = totalSize;
docId = -1;
}
@Override
public byte[] vectorValue() throws IOException {
return current.values.vectorValue();
public byte[] vectorValue(int ord) throws IOException {
return current.values.vectorValue(current.index());
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
current = docIdMerger.next();
if (current == null) {
docId = NO_MORE_DOCS;
} else {
docId = current.mappedDocID;
}
return docId;
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
public DocIndexIterator iterator() {
return new CompositeIterator();
}
@Override
@ -1025,20 +995,59 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
@Override
public float getScoreCorrectionConstant() throws IOException {
return current.values.getScoreCorrectionConstant();
public float getScoreCorrectionConstant(int ord) throws IOException {
return current.values.getScoreCorrectionConstant(current.index());
}
private class CompositeIterator extends DocIndexIterator {
private int docId;
private int ord;
public CompositeIterator() {
docId = -1;
ord = -1;
}
@Override
public VectorScorer scorer(float[] target) throws IOException {
public int index() {
return ord;
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
current = docIdMerger.next();
if (current == null) {
docId = NO_MORE_DOCS;
ord = NO_MORE_DOCS;
} else {
docId = current.mappedDocID;
++ord;
}
return docId;
}
@Override
public int advance(int target) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public long cost() {
return size;
}
}
}
static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
private final FloatVectorValues values;
private final ScalarQuantizer quantizer;
private final byte[] quantizedVector;
private int lastOrd = -1;
private float offsetValue = 0f;
private final VectorSimilarityFunction vectorSimilarityFunction;
@ -1054,7 +1063,14 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
@Override
public float getScoreCorrectionConstant() {
public float getScoreCorrectionConstant(int ord) {
if (ord != lastOrd) {
throw new IllegalStateException(
"attempt to retrieve score correction for different ord "
+ ord
+ " than the quantization was done for: "
+ lastOrd);
}
return offsetValue;
}
@ -1069,41 +1085,31 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
@Override
public byte[] vectorValue() throws IOException {
public byte[] vectorValue(int ord) throws IOException {
if (ord != lastOrd) {
offsetValue = quantize(ord);
lastOrd = ord;
}
return quantizedVector;
}
@Override
public int docID() {
return values.docID();
}
@Override
public int nextDoc() throws IOException {
int doc = values.nextDoc();
if (doc != NO_MORE_DOCS) {
quantize();
}
return doc;
}
@Override
public int advance(int target) throws IOException {
int doc = values.advance(target);
if (doc != NO_MORE_DOCS) {
quantize();
}
return doc;
}
@Override
public VectorScorer scorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}
private void quantize() throws IOException {
offsetValue =
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
private float quantize(int ord) throws IOException {
return quantizer.quantize(values.vectorValue(ord), quantizedVector, vectorSimilarityFunction);
}
@Override
public int ordToDoc(int ord) {
return values.ordToDoc(ord);
}
@Override
public DocIndexIterator iterator() {
return values.iterator();
}
}
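The lastOrd check above is a cache-by-ordinal pattern: the quantization buffer is recomputed only when a different ordinal is requested. Below is a minimal standalone sketch of the same pattern (hypothetical class name, not part of this patch), assuming only the FloatVectorValues API introduced here.
import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
final class CachingTransformVectorValues extends FloatVectorValues {
  private final FloatVectorValues in;
  private final float[] buffer;
  private int lastOrd = -1;
  CachingTransformVectorValues(FloatVectorValues in) {
    this.in = in;
    this.buffer = new float[in.dimension()];
  }
  @Override
  public float[] vectorValue(int ord) throws IOException {
    if (ord != lastOrd) {
      // placeholder transform; the writer above quantizes or normalizes at this point
      float[] raw = in.vectorValue(ord);
      System.arraycopy(raw, 0, buffer, 0, buffer.length);
      lastOrd = ord;
    }
    return buffer;
  }
  @Override
  public int dimension() {
    return in.dimension();
  }
  @Override
  public int size() {
    return in.size();
  }
  @Override
  public int ordToDoc(int ord) {
    return in.ordToDoc(ord);
  }
  @Override
  public DocIndexIterator iterator() {
    return in.iterator();
  }
  @Override
  public FloatVectorValues copy() throws IOException {
    return new CachingTransformVectorValues(in.copy());
  }
}
As in the classes above, the assumption is that callers read the same ordinal several times before moving on, so the cache pays off.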
@ -1160,9 +1166,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
@Override
public float getScoreCorrectionConstant() throws IOException {
public float getScoreCorrectionConstant(int ord) throws IOException {
return scalarQuantizer.recalculateCorrectiveOffset(
in.vectorValue(), oldScalarQuantizer, vectorSimilarityFunction);
in.vectorValue(ord), oldScalarQuantizer, vectorSimilarityFunction);
}
@Override
@ -1176,35 +1182,24 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
@Override
public byte[] vectorValue() throws IOException {
return in.vectorValue();
public byte[] vectorValue(int ord) throws IOException {
return in.vectorValue(ord);
}
@Override
public int docID() {
return in.docID();
public int ordToDoc(int ord) {
return in.ordToDoc(ord);
}
@Override
public int nextDoc() throws IOException {
return in.nextDoc();
}
@Override
public int advance(int target) throws IOException {
return in.advance(target);
}
@Override
public VectorScorer scorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
public DocIndexIterator iterator() {
return in.iterator();
}
}
static final class NormalizedFloatVectorValues extends FloatVectorValues {
private final FloatVectorValues values;
private final float[] normalizedVector;
int curDoc = -1;
public NormalizedFloatVectorValues(FloatVectorValues values) {
this.values = values;
@ -1222,38 +1217,25 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
@Override
public float[] vectorValue() throws IOException {
public int ordToDoc(int ord) {
return values.ordToDoc(ord);
}
@Override
public float[] vectorValue(int ord) throws IOException {
System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length);
VectorUtil.l2normalize(normalizedVector);
return normalizedVector;
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
throw new UnsupportedOperationException();
public DocIndexIterator iterator() {
return values.iterator();
}
@Override
public int docID() {
return values.docID();
}
@Override
public int nextDoc() throws IOException {
curDoc = values.nextDoc();
if (curDoc != NO_MORE_DOCS) {
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
VectorUtil.l2normalize(normalizedVector);
}
return curDoc;
}
@Override
public int advance(int target) throws IOException {
curDoc = values.advance(target);
if (curDoc != NO_MORE_DOCS) {
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
VectorUtil.l2normalize(normalizedVector);
}
return curDoc;
public NormalizedFloatVectorValues copy() throws IOException {
return new NormalizedFloatVectorValues(values.copy());
}
}
}

View File

@ -30,15 +30,13 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
/**
* Read the quantized vector values and their score correction values from the index input. This
* supports both iterated and random access.
*/
public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues
implements RandomAccessQuantizedByteVectorValues {
public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {
protected final int dimension;
protected final int size;
@ -141,11 +139,6 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
return binaryValue;
}
@Override
public float getScoreCorrectionConstant() {
return scoreCorrectionConstant[0];
}
@Override
public float getScoreCorrectionConstant(int targetOrd) throws IOException {
if (lastOrd == targetOrd) {
@ -213,8 +206,6 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
*/
public static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
private int doc = -1;
public DenseOffHeapVectorValues(
int dimension,
int size,
@ -226,30 +217,6 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice);
}
@Override
public byte[] vectorValue() throws IOException {
return vectorValue(doc);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
if (target >= size) {
return doc = NO_MORE_DOCS;
}
return doc = target;
}
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(
@ -270,20 +237,26 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
@Override
public VectorScorer scorer(float[] target) throws IOException {
DenseOffHeapVectorValues copy = copy();
DocIndexIterator iterator = copy.iterator();
RandomVectorScorer vectorScorer =
vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
return new VectorScorer() {
@Override
public float score() throws IOException {
return vectorScorer.score(copy.doc);
return vectorScorer.score(iterator.index());
}
@Override
public DocIdSetIterator iterator() {
return copy;
return iterator;
}
};
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
}
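The scorer above binds a private copy of the values and that copy's iterator, so scoring never disturbs the caller's own iteration state. A usage sketch from the caller's side (hypothetical helper class, assuming only the VectorScorer contract shown in this diff):
import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
final class ScoreAllDocs {
  static void scoreAll(FloatVectorValues values, float[] query) throws IOException {
    // may throw UnsupportedOperationException for implementations without scoring support
    VectorScorer scorer = values.scorer(query);
    if (scorer == null) {
      return; // no vectors for this field
    }
    DocIdSetIterator disi = scorer.iterator(); // independent of values.iterator()
    for (int doc = disi.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = disi.nextDoc()) {
      System.out.println(doc + " -> " + scorer.score()); // score of the doc the iterator is on
    }
  }
}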
private static class SparseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
@ -312,24 +285,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
}
@Override
public byte[] vectorValue() throws IOException {
return vectorValue(disi.index());
}
@Override
public int docID() {
return disi.docID();
}
@Override
public int nextDoc() throws IOException {
return disi.nextDoc();
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
return disi.advance(target);
public DocIndexIterator iterator() {
return IndexedDISI.asDocIndexIterator(disi);
}
@Override
@ -372,17 +329,18 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
@Override
public VectorScorer scorer(float[] target) throws IOException {
SparseOffHeapVectorValues copy = copy();
DocIndexIterator iterator = copy.iterator();
RandomVectorScorer vectorScorer =
vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
return new VectorScorer() {
@Override
public float score() throws IOException {
return vectorScorer.score(copy.disi.index());
return vectorScorer.score(iterator.index());
}
@Override
public DocIdSetIterator iterator() {
return copy;
return iterator;
}
};
}
@ -404,8 +362,6 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
null);
}
private int doc = -1;
@Override
public int dimension() {
return super.dimension();
@ -417,23 +373,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
}
@Override
public byte[] vectorValue() {
throw new UnsupportedOperationException();
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) {
return doc = NO_MORE_DOCS;
public DocIndexIterator iterator() {
return createDenseIterator();
}
@Override

View File

@ -17,8 +17,8 @@
package org.apache.lucene.index;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
/**
@ -27,34 +27,21 @@ import org.apache.lucene.search.VectorScorer;
*
* @lucene.experimental
*/
public abstract class ByteVectorValues extends DocIdSetIterator {
public abstract class ByteVectorValues extends KnnVectorValues {
/** Sole constructor */
protected ByteVectorValues() {}
/** Return the dimension of the vectors */
public abstract int dimension();
/**
* Return the number of vectors for this field.
*
* @return the number of vectors returned by this iterator
*/
public abstract int size();
@Override
public final long cost() {
return size();
}
/**
* Return the vector value for the current document ID. It is illegal to call this method when the
* iterator is not positioned: before advancing, or after failing to advance. The returned array
* may be shared across calls, re-used, and modified as the iterator advances.
* Return the vector value for the given vector ordinal, which must be in [0, size() - 1];
* otherwise an IndexOutOfBoundsException is thrown. The returned array may be shared across calls.
*
* @return the vector value
*/
public abstract byte[] vectorValue() throws IOException;
public abstract byte[] vectorValue(int ord) throws IOException;
@Override
public abstract ByteVectorValues copy() throws IOException;
/**
* Checks the Vector Encoding of a field
@ -78,12 +65,53 @@ public abstract class ByteVectorValues extends DocIdSetIterator {
}
/**
* Return a {@link VectorScorer} for the given query vector. The iterator for the scorer is not
* the same instance as the iterator for this {@link ByteVectorValues}. It is a copy, and
* iteration over the scorer will not affect the iteration of this {@link ByteVectorValues}.
* Return a {@link VectorScorer} for the given query vector.
*
* @param query the query vector
* @return a {@link VectorScorer} instance or null
*/
public abstract VectorScorer scorer(byte[] query) throws IOException;
public VectorScorer scorer(byte[] query) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public VectorEncoding getEncoding() {
return VectorEncoding.BYTE;
}
/**
* Creates a {@link ByteVectorValues} from a list of byte arrays.
*
* @param vectors the list of byte arrays
* @param dim the dimension of the vectors
* @return a {@link ByteVectorValues} instance
*/
public static ByteVectorValues fromBytes(List<byte[]> vectors, int dim) {
return new ByteVectorValues() {
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public byte[] vectorValue(int targetOrd) {
return vectors.get(targetOrd);
}
@Override
public ByteVectorValues copy() {
return this;
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
};
}
}
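A short usage sketch of the new random-access API (hypothetical example class, not part of the patch): vectors can be read directly by ordinal, and the DocIndexIterator ties docids back to ordinals when iteration is still needed.
import java.io.IOException;
import java.util.List;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
public class ByteVectorValuesExample {
  public static void main(String[] args) throws IOException {
    ByteVectorValues values =
        ByteVectorValues.fromBytes(List.of(new byte[] {1, 2}, new byte[] {3, 4}), 2);
    // Random access by ordinal; no iterator positioning is required.
    byte[] first = values.vectorValue(0);
    assert first.length == values.dimension();
    // Iteration goes through a DocIndexIterator exposing both the docid and the ordinal.
    KnnVectorValues.DocIndexIterator it = values.iterator();
    for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
      byte[] vector = values.vectorValue(it.index());
      System.out.println("doc=" + doc + " ord=" + it.index() + " dim=" + vector.length);
    }
  }
}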

View File

@ -2760,16 +2760,16 @@ public final class CheckIndex implements Closeable {
CheckIndex.Status.VectorValuesStatus status,
CodecReader codecReader)
throws IOException {
int docCount = 0;
int count = 0;
int everyNdoc = Math.max(values.size() / 64, 1);
while (values.nextDoc() != NO_MORE_DOCS) {
while (count < values.size()) {
// search the first maxNumSearches vectors to exercise the graph
if (values.docID() % everyNdoc == 0) {
if (values.ordToDoc(count) % everyNdoc == 0) {
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) {
codecReader
.getVectorReader()
.search(fieldInfo.name, values.vectorValue(), collector, null);
.search(fieldInfo.name, values.vectorValue(count), collector, null);
TopDocs docs = collector.topDocs();
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
@ -2777,7 +2777,7 @@ public final class CheckIndex implements Closeable {
}
}
}
int valueLength = values.vectorValue().length;
int valueLength = values.vectorValue(count).length;
if (valueLength != fieldInfo.getVectorDimension()) {
throw new CheckIndexException(
"Field \""
@ -2787,19 +2787,19 @@ public final class CheckIndex implements Closeable {
+ " not matching the field's dimension="
+ fieldInfo.getVectorDimension());
}
++docCount;
++count;
}
if (docCount != values.size()) {
if (count != values.size()) {
throw new CheckIndexException(
"Field \""
+ fieldInfo.name
+ "\" has size="
+ values.size()
+ " but when iterated, returns "
+ docCount
+ count
+ " docs with values");
}
status.totalVectorValues += docCount;
status.totalVectorValues += count;
}
private static void checkByteVectorValues(
@ -2808,21 +2808,23 @@ public final class CheckIndex implements Closeable {
CheckIndex.Status.VectorValuesStatus status,
CodecReader codecReader)
throws IOException {
int docCount = 0;
int count = 0;
int everyNdoc = Math.max(values.size() / 64, 1);
boolean supportsSearch = vectorsReaderSupportsSearch(codecReader, fieldInfo.name);
while (values.nextDoc() != NO_MORE_DOCS) {
while (count < values.size()) {
// search the first maxNumSearches vectors to exercise the graph
if (supportsSearch && values.docID() % everyNdoc == 0) {
if (supportsSearch && values.ordToDoc(count) % everyNdoc == 0) {
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null);
codecReader
.getVectorReader()
.search(fieldInfo.name, values.vectorValue(count), collector, null);
TopDocs docs = collector.topDocs();
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
}
}
int valueLength = values.vectorValue().length;
int valueLength = values.vectorValue(count).length;
if (valueLength != fieldInfo.getVectorDimension()) {
throw new CheckIndexException(
"Field \""
@ -2832,19 +2834,19 @@ public final class CheckIndex implements Closeable {
+ " not matching the field's dimension="
+ fieldInfo.getVectorDimension());
}
++docCount;
++count;
}
if (docCount != values.size()) {
if (count != values.size()) {
throw new CheckIndexException(
"Field \""
+ fieldInfo.name
+ "\" has size="
+ values.size()
+ " but when iterated, returns "
+ docCount
+ count
+ " docs with values");
}
status.totalVectorValues += docCount;
status.totalVectorValues += count;
}
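The checks above now scan by ordinal rather than by docid. A minimal sketch of that scan pattern (hypothetical helper, assumed caller code): ordinals are dense in [0, size()), so a counting loop replaces docid iteration and ordToDoc recovers the docid only when it is needed.
import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
final class OrdinalScan {
  static int countVectors(FloatVectorValues values) throws IOException {
    int count = 0;
    for (int ord = 0; ord < values.size(); ord++) {
      float[] vector = values.vectorValue(ord); // random access by ordinal
      int doc = values.ordToDoc(ord); // docid owning this vector
      assert vector.length == values.dimension() && doc >= 0;
      count++;
    }
    return count;
  }
}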
/**

View File

@ -429,37 +429,10 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
}
private class ExitableFloatVectorValues extends FloatVectorValues {
private int docToCheck;
private final FloatVectorValues vectorValues;
public ExitableFloatVectorValues(FloatVectorValues vectorValues) {
this.vectorValues = vectorValues;
docToCheck = 0;
}
@Override
public int advance(int target) throws IOException {
final int advance = vectorValues.advance(target);
if (advance >= docToCheck) {
checkAndThrow();
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
}
return advance;
}
@Override
public int docID() {
return vectorValues.docID();
}
@Override
public int nextDoc() throws IOException {
final int nextDoc = vectorValues.nextDoc();
if (nextDoc >= docToCheck) {
checkAndThrow();
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
}
return nextDoc;
}
@Override
@ -468,8 +441,13 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
}
@Override
public float[] vectorValue() throws IOException {
return vectorValues.vectorValue();
public float[] vectorValue(int ord) throws IOException {
return vectorValues.vectorValue(ord);
}
@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
}
@Override
@ -477,61 +455,27 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
return vectorValues.size();
}
@Override
public DocIndexIterator iterator() {
return createExitableIterator(vectorValues.iterator(), queryTimeout);
}
@Override
public VectorScorer scorer(float[] target) throws IOException {
return vectorValues.scorer(target);
}
/**
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
* if {@link Thread#interrupted()} returns true.
*/
private void checkAndThrow() {
if (queryTimeout.shouldExit()) {
throw new ExitingReaderException(
"The request took too long to iterate over vector values. Timeout: "
+ queryTimeout.toString()
+ ", FloatVectorValues="
+ in);
} else if (Thread.interrupted()) {
throw new ExitingReaderException(
"Interrupted while iterating over vector values. FloatVectorValues=" + in);
}
@Override
public FloatVectorValues copy() {
throw new UnsupportedOperationException();
}
}
private class ExitableByteVectorValues extends ByteVectorValues {
private int docToCheck;
private final ByteVectorValues vectorValues;
public ExitableByteVectorValues(ByteVectorValues vectorValues) {
this.vectorValues = vectorValues;
docToCheck = 0;
}
@Override
public int advance(int target) throws IOException {
final int advance = vectorValues.advance(target);
if (advance >= docToCheck) {
checkAndThrow();
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
}
return advance;
}
@Override
public int docID() {
return vectorValues.docID();
}
@Override
public int nextDoc() throws IOException {
final int nextDoc = vectorValues.nextDoc();
if (nextDoc >= docToCheck) {
checkAndThrow();
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
}
return nextDoc;
}
@Override
@ -545,8 +489,18 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
}
@Override
public byte[] vectorValue() throws IOException {
return vectorValues.vectorValue();
public byte[] vectorValue(int ord) throws IOException {
return vectorValues.vectorValue(ord);
}
@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
}
@Override
public DocIndexIterator iterator() {
return createExitableIterator(vectorValues.iterator(), queryTimeout);
}
@Override
@ -554,23 +508,66 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
return vectorValues.scorer(target);
}
/**
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
* if {@link Thread#interrupted()} returns true.
*/
@Override
public ByteVectorValues copy() {
throw new UnsupportedOperationException();
}
}
}
private static KnnVectorValues.DocIndexIterator createExitableIterator(
KnnVectorValues.DocIndexIterator delegate, QueryTimeout queryTimeout) {
return new KnnVectorValues.DocIndexIterator() {
private int nextCheck;
@Override
public int index() {
return delegate.index();
}
@Override
public int docID() {
return delegate.docID();
}
@Override
public int nextDoc() throws IOException {
int doc = delegate.nextDoc();
if (doc >= nextCheck) {
checkAndThrow();
nextCheck = doc + ExitableFilterAtomicReader.DOCS_BETWEEN_TIMEOUT_CHECK;
}
return doc;
}
@Override
public long cost() {
return delegate.cost();
}
@Override
public int advance(int target) throws IOException {
int doc = delegate.advance(target);
if (doc >= nextCheck) {
checkAndThrow();
nextCheck = doc + ExitableFilterAtomicReader.DOCS_BETWEEN_TIMEOUT_CHECK;
}
return doc;
}
private void checkAndThrow() {
if (queryTimeout.shouldExit()) {
throw new ExitingReaderException(
"The request took too long to iterate over vector values. Timeout: "
"The request took too long to iterate over knn vector values. Timeout: "
+ queryTimeout.toString()
+ ", ByteVectorValues="
+ in);
+ ", KnnVectorValues="
+ delegate);
} else if (Thread.interrupted()) {
throw new ExitingReaderException(
"Interrupted while iterating over vector values. ByteVectorValues=" + in);
}
"Interrupted while iterating over knn vector values. KnnVectorValues=" + delegate);
}
}
};
}
/** Wrapper class for another PointValues implementation that is used by ExitableFields. */
@ -683,7 +680,7 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
if (queryTimeout.shouldExit()) {
throw new ExitingReaderException(
"The request took too long to intersect point values. Timeout: "
+ queryTimeout.toString()
+ queryTimeout
+ ", PointValues="
+ pointValues);
} else if (Thread.interrupted()) {
@ -815,7 +812,7 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
/** Wrapper class for another Terms implementation that is used by ExitableFields. */
public static class ExitableTerms extends FilterTerms {
private QueryTimeout queryTimeout;
private final QueryTimeout queryTimeout;
/** Constructor * */
public ExitableTerms(Terms terms, QueryTimeout queryTimeout) {

View File

@ -17,8 +17,8 @@
package org.apache.lucene.index;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
/**
@ -27,34 +27,21 @@ import org.apache.lucene.search.VectorScorer;
*
* @lucene.experimental
*/
public abstract class FloatVectorValues extends DocIdSetIterator {
public abstract class FloatVectorValues extends KnnVectorValues {
/** Sole constructor */
protected FloatVectorValues() {}
/** Return the dimension of the vectors */
public abstract int dimension();
/**
* Return the number of vectors for this field.
*
* @return the number of vectors returned by this iterator
*/
public abstract int size();
@Override
public final long cost() {
return size();
}
/**
* Return the vector value for the current document ID. It is illegal to call this method when the
* iterator is not positioned: before advancing, or after failing to advance. The returned array
* may be shared across calls, re-used, and modified as the iterator advances.
* Return the vector value for the given vector ordinal, which must be in [0, size() - 1];
* otherwise an IndexOutOfBoundsException is thrown. The returned array may be shared across calls.
*
* @return the vector value
*/
public abstract float[] vectorValue() throws IOException;
public abstract float[] vectorValue(int ord) throws IOException;
@Override
public abstract FloatVectorValues copy() throws IOException;
/**
* Checks the Vector Encoding of a field
@ -79,12 +66,53 @@ public abstract class FloatVectorValues extends DocIdSetIterator {
/**
* Return a {@link VectorScorer} for the given query vector and the current {@link
* FloatVectorValues}. The iterator for the scorer is not the same instance as the iterator for
* this {@link FloatVectorValues}. It is a copy, and iteration over the scorer will not affect the
* iteration of this {@link FloatVectorValues}.
* FloatVectorValues}.
*
* @param query the query vector
* @param target the query vector
* @return a {@link VectorScorer} instance or null
*/
public abstract VectorScorer scorer(float[] query) throws IOException;
public VectorScorer scorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public VectorEncoding getEncoding() {
return VectorEncoding.FLOAT32;
}
/**
* Creates a {@link FloatVectorValues} from a list of float arrays.
*
* @param vectors the list of float arrays
* @param dim the dimension of the vectors
* @return a {@link FloatVectorValues} instance
*/
public static FloatVectorValues fromFloats(List<float[]> vectors, int dim) {
return new FloatVectorValues() {
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public float[] vectorValue(int targetOrd) {
return vectors.get(targetOrd);
}
@Override
public FloatVectorValues copy() {
return this;
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
};
}
}
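For callers migrating from the old cursor-style API, here is a hedged sketch of the new access pattern (assumed caller code, hypothetical class name): the iterator supplies docids and ordinals, and the ordinal is what vectorValue now takes.
import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
final class FloatVectorAccess {
  static void consume(FloatVectorValues values) throws IOException {
    KnnVectorValues.DocIndexIterator iter = values.iterator();
    for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) {
      float[] vector = values.vectorValue(iter.index());
      // For the implementations in this patch, doc == values.ordToDoc(iter.index()).
      System.out.println("doc=" + doc + " dim=" + vector.length);
    }
  }
}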

View File

@ -0,0 +1,229 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;
/**
* This class abstracts addressing of document vector values indexed as {@link KnnFloatVectorField}
* or {@link KnnByteVectorField}.
*
* @lucene.experimental
*/
public abstract class KnnVectorValues {
/** Return the dimension of the vectors */
public abstract int dimension();
/**
* Return the number of vectors for this field.
*
* @return the number of vectors returned by this iterator
*/
public abstract int size();
/**
* Return the docid of the document indexed with the given vector ordinal. This default
* implementation returns the argument and is appropriate for dense values implementations where
* every doc has a single value.
*/
public int ordToDoc(int ord) {
return ord;
}
/**
* Creates a new copy of this {@link KnnVectorValues}. This is helpful when you need to access
* different values at once, so that the shared vector buffer returned by one instance is not
* overwritten by reads from the other.
*/
public abstract KnnVectorValues copy() throws IOException;
/** Returns the vector byte length, defaults to dimension multiplied by float byte size */
public int getVectorByteLength() {
return dimension() * getEncoding().byteSize;
}
/** The vector encoding of these values. */
public abstract VectorEncoding getEncoding();
/** Returns a Bits accepting docs accepted by the argument and having a vector value */
public Bits getAcceptOrds(Bits acceptDocs) {
// FIXME: change default to return acceptDocs and provide this impl
// somewhere more specialized (in every non-dense impl).
if (acceptDocs == null) {
return null;
}
return new Bits() {
@Override
public boolean get(int index) {
return acceptDocs.get(ordToDoc(index));
}
@Override
public int length() {
return size();
}
};
}
/** Create an iterator for this instance. */
public DocIndexIterator iterator() {
throw new UnsupportedOperationException();
}
/**
* A DocIdSetIterator that also provides an index() method tracking a distinct ordinal for a
* vector associated with each doc.
*/
public abstract static class DocIndexIterator extends DocIdSetIterator {
/** return the value index (aka "ordinal" or "ord") corresponding to the current doc */
public abstract int index();
}
/**
* Creates an iterator for instances where every doc has a value, and the value ordinals are equal
* to the docids.
*/
protected DocIndexIterator createDenseIterator() {
return new DocIndexIterator() {
int doc = -1;
@Override
public int docID() {
return doc;
}
@Override
public int index() {
return doc;
}
@Override
public int nextDoc() throws IOException {
if (doc >= size() - 1) {
return doc = NO_MORE_DOCS;
} else {
return ++doc;
}
}
@Override
public int advance(int target) {
if (target >= size()) {
return doc = NO_MORE_DOCS;
}
return doc = target;
}
@Override
public long cost() {
return size();
}
};
}
/**
* Creates an iterator from a DocIdSetIterator indicating which docs have values, and for which
* ordinals increase monotonically with docid.
*/
protected static DocIndexIterator fromDISI(DocIdSetIterator docsWithField) {
return new DocIndexIterator() {
int ord = -1;
@Override
public int docID() {
return docsWithField.docID();
}
@Override
public int index() {
return ord;
}
@Override
public int nextDoc() throws IOException {
if (docID() == NO_MORE_DOCS) {
return NO_MORE_DOCS;
}
ord++;
return docsWithField.nextDoc();
}
@Override
public int advance(int target) throws IOException {
return docsWithField.advance(target);
}
@Override
public long cost() {
return docsWithField.cost();
}
};
}
/**
* Creates an iterator from this instance's ordinal-to-docid mapping which must be monotonic
* (docid increases when ordinal does).
*/
protected DocIndexIterator createSparseIterator() {
return new DocIndexIterator() {
private int ord = -1;
@Override
public int docID() {
if (ord == -1) {
return -1;
}
if (ord == NO_MORE_DOCS) {
return NO_MORE_DOCS;
}
return ordToDoc(ord);
}
@Override
public int index() {
return ord;
}
@Override
public int nextDoc() throws IOException {
if (ord >= size() - 1) {
ord = NO_MORE_DOCS;
} else {
++ord;
}
return docID();
}
@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
}
@Override
public long cost() {
return size();
}
};
}
}
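A minimal sparse implementation sketch built only on the hooks defined above (all names hypothetical): the ord-to-doc mapping is held in an array and createSparseIterator() drives iteration, so the subclass only implements storage.
import java.io.IOException;
import java.util.List;
import org.apache.lucene.index.FloatVectorValues;
final class SparseListFloatVectorValues extends FloatVectorValues {
  private final List<float[]> vectors; // one vector per ordinal
  private final int[] ordToDoc; // monotonic: docid increases with ordinal
  private final int dim;
  SparseListFloatVectorValues(List<float[]> vectors, int[] ordToDoc, int dim) {
    this.vectors = vectors;
    this.ordToDoc = ordToDoc;
    this.dim = dim;
  }
  @Override
  public int dimension() {
    return dim;
  }
  @Override
  public int size() {
    return vectors.size();
  }
  @Override
  public float[] vectorValue(int ord) throws IOException {
    return vectors.get(ord);
  }
  @Override
  public int ordToDoc(int ord) {
    return ordToDoc[ord];
  }
  @Override
  public FloatVectorValues copy() {
    return this; // immutable in-memory values; sharing is safe here
  }
  @Override
  public DocIndexIterator iterator() {
    return createSparseIterator(); // walks ordinals, reporting ordToDoc(ord) as the docid
  }
}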

View File

@ -34,9 +34,7 @@ import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.index.MultiDocValues.MultiSortedDocValues;
import org.apache.lucene.index.MultiDocValues.MultiSortedSetDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
@ -303,38 +301,21 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
}
}
private record DocValuesSub<T extends DocIdSetIterator>(T sub, int docStart, int docEnd) {}
private record DocValuesSub<T extends KnnVectorValues>(T sub, int docStart, int ordStart) {}
private static class MergedDocIdSetIterator<T extends DocIdSetIterator> extends DocIdSetIterator {
private static class MergedDocIterator<T extends KnnVectorValues>
extends KnnVectorValues.DocIndexIterator {
final Iterator<DocValuesSub<T>> it;
final long cost;
DocValuesSub<T> current;
int currentIndex = 0;
KnnVectorValues.DocIndexIterator currentIterator;
int ord = -1;
int doc = -1;
MergedDocIdSetIterator(List<DocValuesSub<T>> subs) {
long cost = 0;
for (DocValuesSub<T> sub : subs) {
if (sub.sub != null) {
cost += sub.sub.cost();
}
}
this.cost = cost;
MergedDocIterator(List<DocValuesSub<T>> subs) {
this.it = subs.iterator();
current = it.next();
}
private boolean advanceSub(int target) {
while (current.sub == null || current.docEnd <= target) {
if (it.hasNext() == false) {
doc = NO_MORE_DOCS;
return false;
}
current = it.next();
currentIndex++;
}
return true;
currentIterator = currentIterator();
}
@Override
@ -342,41 +323,47 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
return doc;
}
@Override
public int index() {
return ord;
}
@Override
public int nextDoc() throws IOException {
while (true) {
if (current.sub != null) {
int next = current.sub.nextDoc();
int next = currentIterator.nextDoc();
if (next != NO_MORE_DOCS) {
++ord;
return doc = current.docStart + next;
}
}
if (it.hasNext() == false) {
ord = NO_MORE_DOCS;
return doc = NO_MORE_DOCS;
}
current = it.next();
currentIndex++;
currentIterator = currentIterator();
ord = current.ordStart - 1;
}
}
@Override
public int advance(int target) throws IOException {
while (true) {
if (advanceSub(target) == false) {
return DocIdSetIterator.NO_MORE_DOCS;
}
int next = current.sub.advance(target - current.docStart);
if (next == DocIdSetIterator.NO_MORE_DOCS) {
target = current.docEnd;
private KnnVectorValues.DocIndexIterator currentIterator() {
if (current.sub != null) {
return current.sub.iterator();
} else {
return doc = current.docStart + next;
}
return null;
}
}
@Override
public long cost() {
return cost;
throw new UnsupportedOperationException();
}
@Override
public int advance(int target) throws IOException {
throw new UnsupportedOperationException();
}
}
@ -848,55 +835,75 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
int size = 0;
for (CodecReader reader : codecReaders) {
FloatVectorValues values = reader.getFloatVectorValues(field);
subs.add(new DocValuesSub<>(values, docStarts[i], size));
if (values != null) {
if (dimension == -1) {
dimension = values.dimension();
}
size += values.size();
}
subs.add(new DocValuesSub<>(values, docStarts[i], docStarts[i + 1]));
i++;
}
final int finalDimension = dimension;
final int finalSize = size;
MergedDocIdSetIterator<FloatVectorValues> mergedIterator = new MergedDocIdSetIterator<>(subs);
return new FloatVectorValues() {
return new MergedFloatVectorValues(dimension, size, subs);
}
class MergedFloatVectorValues extends FloatVectorValues {
final int dimension;
final int size;
final DocValuesSub<?>[] subs;
final MergedDocIterator<FloatVectorValues> iter;
final int[] starts;
int lastSubIndex;
MergedFloatVectorValues(int dimension, int size, List<DocValuesSub<FloatVectorValues>> subs) {
this.dimension = dimension;
this.size = size;
this.subs = subs.toArray(new DocValuesSub<?>[0]);
iter = new MergedDocIterator<>(subs);
// [0, start(1), ..., size] - we want the extra element
// to avoid checking for out-of-array bounds
starts = new int[subs.size() + 1];
for (int i = 0; i < subs.size(); i++) {
starts[i] = subs.get(i).ordStart;
}
starts[starts.length - 1] = size;
}
@Override
public MergedDocIterator<FloatVectorValues> iterator() {
return iter;
}
@Override
public int dimension() {
return finalDimension;
return dimension;
}
@Override
public int size() {
return finalSize;
return size;
}
@SuppressWarnings("unchecked")
@Override
public FloatVectorValues copy() throws IOException {
List<DocValuesSub<FloatVectorValues>> subsCopy = new ArrayList<>();
for (Object sub : subs) {
subsCopy.add((DocValuesSub<FloatVectorValues>) sub);
}
return new MergedFloatVectorValues(dimension, size, subsCopy);
}
@Override
public float[] vectorValue() throws IOException {
return mergedIterator.current.sub.vectorValue();
public float[] vectorValue(int ord) throws IOException {
assert ord >= 0 && ord < size;
// We need to implement fully random-access API here in order to support callers like
// SortingCodecReader that rely on it.
lastSubIndex = findSub(ord, lastSubIndex, starts);
return ((FloatVectorValues) subs[lastSubIndex].sub)
.vectorValue(ord - subs[lastSubIndex].ordStart);
}
@Override
public int docID() {
return mergedIterator.docID();
}
@Override
public int nextDoc() throws IOException {
return mergedIterator.nextDoc();
}
@Override
public int advance(int target) throws IOException {
return mergedIterator.advance(target);
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
};
}
@Override
@ -907,55 +914,96 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
int size = 0;
for (CodecReader reader : codecReaders) {
ByteVectorValues values = reader.getByteVectorValues(field);
subs.add(new DocValuesSub<>(values, docStarts[i], size));
if (values != null) {
if (dimension == -1) {
dimension = values.dimension();
}
size += values.size();
}
subs.add(new DocValuesSub<>(values, docStarts[i], docStarts[i + 1]));
i++;
}
final int finalDimension = dimension;
final int finalSize = size;
MergedDocIdSetIterator<ByteVectorValues> mergedIterator = new MergedDocIdSetIterator<>(subs);
return new ByteVectorValues() {
return new MergedByteVectorValues(dimension, size, subs);
}
class MergedByteVectorValues extends ByteVectorValues {
final int dimension;
final int size;
final DocValuesSub<?>[] subs;
final MergedDocIterator<ByteVectorValues> iter;
final int[] starts;
int lastSubIndex;
MergedByteVectorValues(int dimension, int size, List<DocValuesSub<ByteVectorValues>> subs) {
this.dimension = dimension;
this.size = size;
this.subs = subs.toArray(new DocValuesSub<?>[0]);
iter = new MergedDocIterator<>(subs);
// [0, start(1), ..., size] - we want the extra element
// to avoid checking for out-of-array bounds
starts = new int[subs.size() + 1];
for (int i = 0; i < subs.size(); i++) {
starts[i] = subs.get(i).ordStart;
}
starts[starts.length - 1] = size;
}
@Override
public MergedDocIterator<ByteVectorValues> iterator() {
return iter;
}
@Override
public int dimension() {
return finalDimension;
return dimension;
}
@Override
public int size() {
return finalSize;
return size;
}
@Override
public byte[] vectorValue() throws IOException {
return mergedIterator.current.sub.vectorValue();
public byte[] vectorValue(int ord) throws IOException {
assert ord >= 0 && ord < size;
// We need to implement fully random-access API here in order to support callers like
// SortingCodecReader that rely on it. We maintain lastSubIndex since we expect some
// repetition.
lastSubIndex = findSub(ord, lastSubIndex, starts);
return ((ByteVectorValues) subs[lastSubIndex].sub)
.vectorValue(ord - subs[lastSubIndex].ordStart);
}
@SuppressWarnings("unchecked")
@Override
public int docID() {
return mergedIterator.docID();
public ByteVectorValues copy() throws IOException {
List<DocValuesSub<ByteVectorValues>> newSubs = new ArrayList<>();
for (Object sub : subs) {
newSubs.add((DocValuesSub<ByteVectorValues>) sub);
}
return new MergedByteVectorValues(dimension, size, newSubs);
}
}
@Override
public int nextDoc() throws IOException {
return mergedIterator.nextDoc();
private static int findSub(int ord, int lastSubIndex, int[] starts) {
if (ord >= starts[lastSubIndex]) {
if (ord >= starts[lastSubIndex + 1]) {
return binarySearchStarts(starts, ord, lastSubIndex + 1, starts.length);
}
} else {
return binarySearchStarts(starts, ord, 0, lastSubIndex);
}
return lastSubIndex;
}
@Override
public int advance(int target) throws IOException {
return mergedIterator.advance(target);
private static int binarySearchStarts(int[] starts, int ord, int from, int to) {
int pos = Arrays.binarySearch(starts, from, to, ord);
if (pos < 0) {
// subtract one since binarySearch returns an *insertion point*
return -2 - pos;
} else {
return pos;
}
@Override
public VectorScorer scorer(byte[] target) {
throw new UnsupportedOperationException();
}
};
}
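A small standalone sketch of the ordinal routing above (hypothetical class name): given the starts array [0, start(1), ..., size], a global ordinal belongs to the sub whose start is the greatest value less than or equal to ord, and the local ordinal is ord minus that start.
import java.util.Arrays;
final class OrdinalRouting {
  static int findSub(int[] starts, int ord) {
    // exclude the trailing "size" sentinel from the search window
    int pos = Arrays.binarySearch(starts, 0, starts.length - 1, ord);
    // on a miss, binarySearch returns -(insertion point) - 1; the containing sub
    // is the one just before the insertion point
    return pos >= 0 ? pos : -2 - pos;
  }
  public static void main(String[] args) {
    int[] starts = {0, 3, 7, 10}; // three subs holding 3, 4 and 3 vectors
    int ord = 8;
    int sub = findSub(starts, ord); // -> 2
    int localOrd = ord - starts[sub]; // -> 1
    System.out.println("sub=" + sub + " localOrd=" + localOrd);
  }
}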
@Override

View File

@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.FieldsProducer;
import org.apache.lucene.codecs.KnnVectorsReader;
@ -32,10 +33,11 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOSupplier;
@ -206,121 +208,175 @@ public final class SortingCodecReader extends FilterCodecReader {
}
}
/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
private static class SortingFloatVectorValues extends FloatVectorValues {
final int size;
final int dimension;
final FixedBitSet docsWithField;
final float[][] vectors;
/**
* Factory for SortingValuesIterator. This enables us to create new iterators as needed without
* recomputing the sorting mappings.
*/
static class SortingIteratorSupplier implements Supplier<SortingValuesIterator> {
private final FixedBitSet docBits;
private final int[] docToOrd;
private final int size;
private int docId = -1;
SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.size = delegate.size();
this.dimension = delegate.dimension();
docsWithField = new FixedBitSet(sortMap.size());
vectors = new float[sortMap.size()][];
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
int newDocID = sortMap.oldToNew(doc);
docsWithField.set(newDocID);
vectors[newDocID] = delegate.vectorValue().clone();
SortingIteratorSupplier(FixedBitSet docBits, int[] docToOrd, int size) {
this.docBits = docBits;
this.docToOrd = docToOrd;
this.size = size;
}
@Override
public SortingValuesIterator get() {
return new SortingValuesIterator(docBits, docToOrd, size);
}
public int size() {
return size;
}
}
/**
* Creates a factory for SortingValuesIterator. It computes the (new docId to old ordinal) mapping
* once and caches the result, enabling new iterators to be created cheaply.
*
* @param values the values over which to iterate
* @param docMap the mapping from "old" docIds to "new" (sorted) docIds.
*/
public static SortingIteratorSupplier iteratorSupplier(
KnnVectorValues values, Sorter.DocMap docMap) throws IOException {
final int[] docToOrd = new int[docMap.size()];
final FixedBitSet docBits = new FixedBitSet(docMap.size());
int count = 0;
// Note: docToOrd will contain zero for docids that have no vector. This is OK though
// because the iterator cannot be positioned on such docs
KnnVectorValues.DocIndexIterator iter = values.iterator();
for (int doc = iter.nextDoc(); doc != NO_MORE_DOCS; doc = iter.nextDoc()) {
int newDocId = docMap.oldToNew(doc);
if (newDocId != -1) {
docToOrd[newDocId] = iter.index();
docBits.set(newDocId);
++count;
}
}
return new SortingIteratorSupplier(docBits, docToOrd, count);
}
/**
* Iterator over KnnVectorValues that applies a mapping to differently-sorted docs. Consequently,
* index() may skip around rather than increasing monotonically as iteration proceeds.
*/
public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterator {
private final FixedBitSet docBits;
private final DocIdSetIterator docsWithValues;
private final int[] docToOrd;
int doc = -1;
SortingValuesIterator(FixedBitSet docBits, int[] docToOrd, int size) {
this.docBits = docBits;
this.docToOrd = docToOrd;
docsWithValues = new BitSetIterator(docBits, size);
}
@Override
public int docID() {
return docId;
return doc;
}
@Override
public int index() {
assert docBits.get(doc);
return docToOrd[doc];
}
@Override
public int nextDoc() throws IOException {
return advance(docId + 1);
if (doc != NO_MORE_DOCS) {
doc = docsWithValues.nextDoc();
}
return doc;
}
@Override
public float[] vectorValue() throws IOException {
return vectors[docId];
public long cost() {
return docBits.cardinality();
}
@Override
public int dimension() {
return dimension;
}
@Override
public int size() {
return size;
}
@Override
public int advance(int target) throws IOException {
if (target >= docsWithField.length()) {
return NO_MORE_DOCS;
}
return docId = docsWithField.nextSetBit(target);
}
@Override
public VectorScorer scorer(float[] target) {
public int advance(int target) {
throw new UnsupportedOperationException();
}
}
private static class SortingByteVectorValues extends ByteVectorValues {
final int size;
final int dimension;
final FixedBitSet docsWithField;
final byte[][] vectors;
/** Sorting FloatVectorValues that maps ordinals using the provided sortMap */
private static class SortingFloatVectorValues extends FloatVectorValues {
final FloatVectorValues delegate;
final SortingIteratorSupplier iteratorSupplier;
private int docId = -1;
SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.size = delegate.size();
this.dimension = delegate.dimension();
docsWithField = new FixedBitSet(sortMap.size());
vectors = new byte[sortMap.size()][];
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
int newDocID = sortMap.oldToNew(doc);
docsWithField.set(newDocID);
vectors[newDocID] = delegate.vectorValue().clone();
}
SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.delegate = delegate;
// SortingValuesIterator consumes the iterator and records the docs and ord mapping
iteratorSupplier = iteratorSupplier(delegate, sortMap);
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
return advance(docId + 1);
}
@Override
public byte[] vectorValue() throws IOException {
return vectors[docId];
public float[] vectorValue(int ord) throws IOException {
// ords are interpreted in the delegate's ord-space.
return delegate.vectorValue(ord);
}
@Override
public int dimension() {
return dimension;
return delegate.dimension();
}
@Override
public int size() {
return size;
return iteratorSupplier.size();
}
@Override
public int advance(int target) throws IOException {
if (target >= docsWithField.length()) {
return NO_MORE_DOCS;
}
return docId = docsWithField.nextSetBit(target);
public FloatVectorValues copy() {
throw new UnsupportedOperationException();
}
@Override
public VectorScorer scorer(byte[] target) {
public DocIndexIterator iterator() {
return iteratorSupplier.get();
}
}
private static class SortingByteVectorValues extends ByteVectorValues {
final ByteVectorValues delegate;
final SortingIteratorSupplier iteratorSupplier;
SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.delegate = delegate;
// SortingValuesIterator consumes the iterator and records the docs and ord mapping
iteratorSupplier = iteratorSupplier(delegate, sortMap);
}
@Override
public byte[] vectorValue(int ord) throws IOException {
return delegate.vectorValue(ord);
}
@Override
public DocIndexIterator iterator() {
return iteratorSupplier.get();
}
@Override
public int dimension() {
return delegate.dimension();
}
@Override
public int size() {
return iteratorSupplier.size();
}
@Override
public ByteVectorValues copy() {
throw new UnsupportedOperationException();
}
}
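A small sketch of the sorted-iteration idea above (hypothetical helper, assuming the docBits/docToOrd structures computed by iteratorSupplier): docs are visited in the new, sorted order while index() reports ordinals in the unsorted delegate, which is why those ordinals may appear out of order.
import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.FixedBitSet;
final class SortedOrderWalk {
  static void walk(FixedBitSet docBits, int[] docToOrd) throws IOException {
    BitSetIterator docs = new BitSetIterator(docBits, docBits.cardinality());
    for (int doc = docs.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = docs.nextDoc()) {
      int ord = docToOrd[doc]; // ordinal in the pre-sort delegate
      System.out.println("newDoc=" + doc + " -> delegateOrd=" + ord);
    }
  }
}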

View File

@ -181,8 +181,8 @@ public class FieldExistsQuery extends Query {
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
iterator =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> context.reader().getFloatVectorValues(field);
case BYTE -> context.reader().getByteVectorValues(field);
case FLOAT32 -> context.reader().getFloatVectorValues(field).iterator();
case BYTE -> context.reader().getByteVectorValues(field).iterator();
};
} else if (fieldInfo.getDocValuesType()
!= DocValuesType.NONE) { // the field indexes doc values

View File

@ -19,7 +19,7 @@ package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;
@ -46,7 +46,7 @@ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
}
@Override
protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd)
throws IOException {
if (initReader == null) {
return new HnswConcurrentMergeBuilder(
@ -61,7 +61,7 @@ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
BitSet initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes);
return new HnswConcurrentMergeBuilder(
taskExecutor,

View File

@ -18,8 +18,8 @@ package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.InfoStream;
@ -45,12 +45,12 @@ public interface HnswGraphMerger {
/**
* Merge and produce the on heap graph
*
* @param mergedVectorIterator iterator over the vectors in the merged segment
* @param mergedVectorValues view of the vectors in the merged segment
* @param infoStream optional info stream to set to builder
* @param maxOrd max number of vectors that will be added to the graph
* @return merged graph
* @throws IOException during merge
*/
OnHeapHnswGraph merge(DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd)
OnHeapHnswGraph merge(KnnVectorValues mergedVectorValues, InfoStream infoStream, int maxOrd)
throws IOException;
}

View File

@ -25,9 +25,9 @@ import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
@ -108,12 +108,12 @@ public class IncrementalHnswGraphMerger implements HnswGraphMerger {
* Builds a new HnswGraphBuilder using the biggest graph from the merge state as a starting point.
* If no valid readers were added to the merge state, a new graph is created.
*
* @param mergedVectorIterator iterator over the vectors in the merged segment
* @param mergedVectorValues vector values in the merged segment
* @param maxOrd max num of vectors that will be merged into the graph
* @return HnswGraphBuilder
* @throws IOException If an error occurs while reading from the merge state
*/
protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd)
throws IOException {
if (initReader == null) {
return HnswGraphBuilder.create(
@ -123,7 +123,7 @@ public class IncrementalHnswGraphMerger implements HnswGraphMerger {
HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
BitSet initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes);
return InitializedHnswGraphBuilder.fromGraph(
scorerSupplier,
M,
@ -137,8 +137,8 @@ public class IncrementalHnswGraphMerger implements HnswGraphMerger {
@Override
public OnHeapHnswGraph merge(
DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd) throws IOException {
HnswBuilder builder = createBuilder(mergedVectorIterator, maxOrd);
KnnVectorValues mergedVectorValues, InfoStream infoStream, int maxOrd) throws IOException {
HnswBuilder builder = createBuilder(mergedVectorValues, maxOrd);
builder.setInfoStream(infoStream);
return builder.build(maxOrd);
}
@ -147,46 +147,45 @@ public class IncrementalHnswGraphMerger implements HnswGraphMerger {
* Creates a new mapping from the old ordinals of the initializer graph to their ordinals in the
* newly merged segment.
*
* @param mergedVectorIterator iterator over the vectors in the merged segment
* @param mergedVectorValues vector values in the merged segment
* @param initializedNodes track what nodes have been initialized
* @return the mapping from old ordinals to new ordinals
* @throws IOException If an error occurs while reading from the merge state
*/
protected final int[] getNewOrdMapping(
DocIdSetIterator mergedVectorIterator, BitSet initializedNodes) throws IOException {
DocIdSetIterator initializerIterator = null;
KnnVectorValues mergedVectorValues, BitSet initializedNodes) throws IOException {
KnnVectorValues.DocIndexIterator initializerIterator = null;
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name);
case FLOAT32 -> initializerIterator = initReader.getFloatVectorValues(fieldInfo.name);
case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name).iterator();
case FLOAT32 ->
initializerIterator = initReader.getFloatVectorValues(fieldInfo.name).iterator();
}
IntIntHashMap newIdToOldOrdinal = new IntIntHashMap(initGraphSize);
int oldOrd = 0;
int maxNewDocID = -1;
for (int oldId = initializerIterator.nextDoc();
oldId != NO_MORE_DOCS;
oldId = initializerIterator.nextDoc()) {
int newId = initDocMap.get(oldId);
for (int docId = initializerIterator.nextDoc();
docId != NO_MORE_DOCS;
docId = initializerIterator.nextDoc()) {
int newId = initDocMap.get(docId);
maxNewDocID = Math.max(newId, maxNewDocID);
newIdToOldOrdinal.put(newId, oldOrd);
oldOrd++;
newIdToOldOrdinal.put(newId, initializerIterator.index());
}
if (maxNewDocID == -1) {
return new int[0];
}
final int[] oldToNewOrdinalMap = new int[initGraphSize];
int newOrd = 0;
KnnVectorValues.DocIndexIterator mergedVectorIterator = mergedVectorValues.iterator();
for (int newDocId = mergedVectorIterator.nextDoc();
newDocId <= maxNewDocID;
newDocId = mergedVectorIterator.nextDoc()) {
int hashDocIndex = newIdToOldOrdinal.indexOf(newDocId);
if (newIdToOldOrdinal.indexExists(hashDocIndex)) {
int newOrd = mergedVectorIterator.index();
initializedNodes.set(newOrd);
oldToNewOrdinalMap[newIdToOldOrdinal.indexGet(hashDocIndex)] = newOrd;
}
newOrd++;
}
return oldToNewOrdinalMap;
}
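The rewritten getNewOrdMapping keys the hash map on iterator.index() instead of counting ordinals by hand. A simplified, hedged sketch of the same two-pass idea, with a hypothetical oldToNewDoc remapper standing in for the MergeState doc map:

import java.io.IOException;
import java.util.function.IntUnaryOperator;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.apache.lucene.search.DocIdSetIterator;

final class OrdMappingSketch {
  // Maps each old ordinal in oldValues to its ordinal in newValues, mirroring the two-pass
  // structure above; oldToNewDoc remaps old doc ids into the merged segment.
  static int[] mapOrdinals(
      KnnVectorValues oldValues, KnnVectorValues newValues, IntUnaryOperator oldToNewDoc)
      throws IOException {
    IntIntHashMap newDocToOldOrd = new IntIntHashMap(oldValues.size());
    KnnVectorValues.DocIndexIterator oldIt = oldValues.iterator();
    for (int doc = oldIt.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = oldIt.nextDoc()) {
      newDocToOldOrd.put(oldToNewDoc.applyAsInt(doc), oldIt.index());
    }
    int[] oldToNewOrd = new int[oldValues.size()];
    KnnVectorValues.DocIndexIterator newIt = newValues.iterator();
    for (int doc = newIt.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = newIt.nextDoc()) {
      int slot = newDocToOldOrd.indexOf(doc);
      if (newDocToOldOrd.indexExists(slot)) {
        oldToNewOrd[newDocToOldOrd.indexGet(slot)] = newIt.index();
      }
    }
    return oldToNewOrd;
  }
}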

View File

@ -1,175 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
/**
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
* implementations of KNN search.
*
* @lucene.experimental
*/
public interface RandomAccessVectorValues {
/** Return the number of vector values */
int size();
/** Return the dimension of the returned vector values */
int dimension();
/**
* Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to
* access different values at once, to avoid overwriting the underlying vector returned.
*/
RandomAccessVectorValues copy() throws IOException;
/**
* Returns a slice of the underlying {@link IndexInput} that contains the vector values if
* available
*/
default IndexInput getSlice() {
return null;
}
/** Returns the byte length of the vector values. */
int getVectorByteLength();
/**
* Translates vector ordinal to the correct document ID. By default, this is an identity function.
*
* @param ord the vector ordinal
* @return the document Id for that vector ordinal
*/
default int ordToDoc(int ord) {
return ord;
}
/**
* Returns the {@link Bits} representing live documents. By default, this is an identity function.
*
* @param acceptDocs the accept docs
* @return the accept docs
*/
default Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
/** Float vector values. */
interface Floats extends RandomAccessVectorValues {
@Override
RandomAccessVectorValues.Floats copy() throws IOException;
/**
* Return the vector value indexed at the given ordinal.
*
* @param targetOrd a valid ordinal, &ge; 0 and &lt; {@link #size()}.
*/
float[] vectorValue(int targetOrd) throws IOException;
/** Returns the vector byte length, defaults to dimension multiplied by float byte size */
@Override
default int getVectorByteLength() {
return dimension() * Float.BYTES;
}
}
/** Byte vector values. */
interface Bytes extends RandomAccessVectorValues {
@Override
RandomAccessVectorValues.Bytes copy() throws IOException;
/**
* Return the vector value indexed at the given ordinal.
*
* @param targetOrd a valid ordinal, &ge; 0 and &lt; {@link #size()}.
*/
byte[] vectorValue(int targetOrd) throws IOException;
/** Returns the vector byte length, defaults to dimension multiplied by byte size */
@Override
default int getVectorByteLength() {
return dimension() * Byte.BYTES;
}
}
/**
* Creates a {@link RandomAccessVectorValues.Floats} from a list of float arrays.
*
* @param vectors the list of float arrays
* @param dim the dimension of the vectors
* @return a {@link RandomAccessVectorValues.Floats} instance
*/
static RandomAccessVectorValues.Floats fromFloats(List<float[]> vectors, int dim) {
return new RandomAccessVectorValues.Floats() {
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public float[] vectorValue(int targetOrd) {
return vectors.get(targetOrd);
}
@Override
public RandomAccessVectorValues.Floats copy() {
return this;
}
};
}
/**
* Creates a {@link RandomAccessVectorValues.Bytes} from a list of byte arrays.
*
* @param vectors the list of byte arrays
* @param dim the dimension of the vectors
* @return a {@link RandomAccessVectorValues.Bytes} instance
*/
static RandomAccessVectorValues.Bytes fromBytes(List<byte[]> vectors, int dim) {
return new RandomAccessVectorValues.Bytes() {
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public byte[] vectorValue(int targetOrd) {
return vectors.get(targetOrd);
}
@Override
public RandomAccessVectorValues.Bytes copy() {
return this;
}
};
}
}
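With RandomAccessVectorValues removed, ordinal-based random access lives directly on FloatVectorValues and ByteVectorValues. A small hedged sketch of consuming that API; the defensive clone reflects the usual assumption that implementations may reuse a scratch array:

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;

final class RandomAccessSketch {
  // vectorValue(ord) and ordToDoc(ord) now live directly on FloatVectorValues, so consumers that
  // previously took RandomAccessVectorValues.Floats can take the values class itself.
  static float[][] copyAll(FloatVectorValues values) throws IOException {
    float[][] copies = new float[values.size()][];
    for (int ord = 0; ord < values.size(); ord++) {
      // implementations may reuse a scratch array, so clone before holding on to the vector
      copies[ord] = values.vectorValue(ord).clone();
    }
    return copies;
  }
}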

View File

@ -18,6 +18,7 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.util.Bits;
/** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */
@ -57,14 +58,14 @@ public interface RandomVectorScorer {
/** Creates a default scorer for random access vectors. */
abstract class AbstractRandomVectorScorer implements RandomVectorScorer {
private final RandomAccessVectorValues values;
private final KnnVectorValues values;
/**
* Creates a new scorer for the given vector values.
*
* @param values the vector values
*/
public AbstractRandomVectorScorer(RandomAccessVectorValues values) {
public AbstractRandomVectorScorer(KnnVectorValues values) {
this.values = values;
}

View File

@ -17,9 +17,10 @@
package org.apache.lucene.util.quantization;
import java.io.IOException;
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
/**
* A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for
@ -27,31 +28,31 @@ import org.apache.lucene.search.VectorScorer;
*
* @lucene.experimental
*/
public abstract class QuantizedByteVectorValues extends DocIdSetIterator {
public abstract float getScoreCorrectionConstant() throws IOException;
public abstract class QuantizedByteVectorValues extends ByteVectorValues implements HasIndexSlice {
public abstract byte[] vectorValue() throws IOException;
/** Return the dimension of the vectors */
public abstract int dimension();
/**
* Return the number of vectors for this field.
*
* @return the number of vectors returned by this iterator
*/
public abstract int size();
@Override
public final long cost() {
return size();
public ScalarQuantizer getScalarQuantizer() {
throw new UnsupportedOperationException();
}
public abstract float getScoreCorrectionConstant(int ord) throws IOException;
/**
* Return a {@link VectorScorer} for the given query vector.
*
* @param query the query vector
* @return a {@link VectorScorer} instance or null
*/
public abstract VectorScorer scorer(float[] query) throws IOException;
public VectorScorer scorer(float[] query) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public QuantizedByteVectorValues copy() throws IOException {
return this;
}
@Override
public IndexInput getSlice() {
return null;
}
}
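QuantizedByteVectorValues now hands out both the quantized vector and its score correction by ordinal. A minimal sketch of reading them, assuming only the methods shown in this diff (size(), vectorValue(int), getScoreCorrectionConstant(int)):

import java.io.IOException;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;

final class QuantizedDumpSketch {
  // Reads each quantized vector and its score correction constant by dense ordinal.
  static void dump(QuantizedByteVectorValues values) throws IOException {
    for (int ord = 0; ord < values.size(); ord++) {
      byte[] quantized = values.vectorValue(ord);
      float correction = values.getScoreCorrectionConstant(ord);
      System.out.println("ord=" + ord + " bytes=" + quantized.length + " correction=" + correction);
    }
  }
}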

View File

@ -25,6 +25,7 @@ import java.util.List;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
@ -269,11 +270,12 @@ public class ScalarQuantizer {
if (totalVectorCount == 0) {
return new ScalarQuantizer(0f, 0f, bits);
}
KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
if (confidenceInterval == 1f) {
float min = Float.POSITIVE_INFINITY;
float max = Float.NEGATIVE_INFINITY;
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
for (float v : floatVectorValues.vectorValue()) {
while (iterator.nextDoc() != NO_MORE_DOCS) {
for (float v : floatVectorValues.vectorValue(iterator.index())) {
min = Math.min(min, v);
max = Math.max(max, v);
}
@ -289,8 +291,8 @@ public class ScalarQuantizer {
if (totalVectorCount <= quantizationSampleSize) {
int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
int i = 0;
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
float[] vectorValue = floatVectorValues.vectorValue();
while (iterator.nextDoc() != NO_MORE_DOCS) {
float[] vectorValue = floatVectorValues.vectorValue(iterator.index());
System.arraycopy(
vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length);
i++;
@ -311,11 +313,11 @@ public class ScalarQuantizer {
for (int i : vectorsToTake) {
while (index <= i) {
// We cannot use `advance(docId)` as MergedVectorValues does not support it
floatVectorValues.nextDoc();
iterator.nextDoc();
index++;
}
assert floatVectorValues.docID() != NO_MORE_DOCS;
float[] vectorValue = floatVectorValues.vectorValue();
assert iterator.docID() != NO_MORE_DOCS;
float[] vectorValue = floatVectorValues.vectorValue(iterator.index());
System.arraycopy(
vectorValue, 0, quantileGatheringScratch, idx * vectorValue.length, vectorValue.length);
idx++;
@ -353,11 +355,16 @@ public class ScalarQuantizer {
/ (floatVectorValues.dimension() + 1),
1 - 1f / (floatVectorValues.dimension() + 1)
};
KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
if (totalVectorCount <= sampleSize) {
int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
int i = 0;
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, i);
while (iterator.nextDoc() != NO_MORE_DOCS) {
gatherSample(
floatVectorValues.vectorValue(iterator.index()),
quantileGatheringScratch,
sampledDocs,
i);
i++;
if (i == scratchSize) {
extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
@ -374,11 +381,15 @@ public class ScalarQuantizer {
for (int i : vectorsToTake) {
while (index <= i) {
// We cannot use `advance(docId)` as MergedVectorValues does not support it
floatVectorValues.nextDoc();
iterator.nextDoc();
index++;
}
assert floatVectorValues.docID() != NO_MORE_DOCS;
gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, idx);
assert iterator.docID() != NO_MORE_DOCS;
gatherSample(
floatVectorValues.vectorValue(iterator.index()),
quantileGatheringScratch,
sampledDocs,
idx);
idx++;
if (idx == SCRATCH_SIZE) {
extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
@ -437,12 +448,7 @@ public class ScalarQuantizer {
}
private static void gatherSample(
FloatVectorValues floatVectorValues,
float[] quantileGatheringScratch,
List<float[]> sampledDocs,
int i)
throws IOException {
float[] vectorValue = floatVectorValues.vectorValue();
float[] vectorValue, float[] quantileGatheringScratch, List<float[]> sampledDocs, int i) {
float[] copy = new float[vectorValue.length];
System.arraycopy(vectorValue, 0, copy, 0, vectorValue.length);
sampledDocs.add(copy);
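The quantile-gathering loops above all follow the same shape: obtain a DocIndexIterator, step with nextDoc(), and dereference the current ordinal via index(). A reduced sketch of that pattern computing a global min/max:

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;

final class MinMaxSketch {
  // Same doc-at-a-time pass the sampling code performs, reduced to a global min/max.
  static float[] minMax(FloatVectorValues values) throws IOException {
    float min = Float.POSITIVE_INFINITY;
    float max = Float.NEGATIVE_INFINITY;
    KnnVectorValues.DocIndexIterator it = values.iterator();
    while (it.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
      for (float v : values.vectorValue(it.index())) {
        min = Math.min(min, v);
        max = Math.max(max, v);
      }
    }
    return new float[] {min, max};
  }
}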

View File

@ -19,11 +19,11 @@ package org.apache.lucene.internal.vectorization;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
abstract sealed class Lucene99MemorySegmentByteVectorScorer
@ -39,10 +39,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
* returned.
*/
public static Optional<Lucene99MemorySegmentByteVectorScorer> create(
VectorSimilarityFunction type,
IndexInput input,
RandomAccessVectorValues values,
byte[] queryVector) {
VectorSimilarityFunction type, IndexInput input, KnnVectorValues values, byte[] queryVector) {
input = FilterIndexInput.unwrapOnlyTest(input);
if (!(input instanceof MemorySegmentAccessInput msInput)) {
return Optional.empty();
@ -58,7 +55,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
}
Lucene99MemorySegmentByteVectorScorer(
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] queryVector) {
MemorySegmentAccessInput input, KnnVectorValues values, byte[] queryVector) {
super(values);
this.input = input;
this.vectorByteSize = values.getVectorByteLength();
@ -92,7 +89,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
}
static final class CosineScorer extends Lucene99MemorySegmentByteVectorScorer {
CosineScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
CosineScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
super(input, values, query);
}
@ -105,8 +102,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
}
static final class DotProductScorer extends Lucene99MemorySegmentByteVectorScorer {
DotProductScorer(
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
DotProductScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
super(input, values, query);
}
@ -120,7 +116,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
}
static final class EuclideanScorer extends Lucene99MemorySegmentByteVectorScorer {
EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
EuclideanScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
super(input, values, query);
}
@ -133,8 +129,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
}
static final class MaxInnerProductScorer extends Lucene99MemorySegmentByteVectorScorer {
MaxInnerProductScorer(
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
MaxInnerProductScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
super(input, values, query);
}

View File

@ -19,11 +19,11 @@ package org.apache.lucene.internal.vectorization;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
@ -33,7 +33,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
final int vectorByteSize;
final int maxOrd;
final MemorySegmentAccessInput input;
final RandomAccessVectorValues values; // to support ordToDoc/getAcceptOrds
final KnnVectorValues values; // to support ordToDoc/getAcceptOrds
byte[] scratch1, scratch2;
/**
@ -41,7 +41,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
* optional is returned.
*/
static Optional<RandomVectorScorerSupplier> create(
VectorSimilarityFunction type, IndexInput input, RandomAccessVectorValues values) {
VectorSimilarityFunction type, IndexInput input, KnnVectorValues values) {
input = FilterIndexInput.unwrapOnlyTest(input);
if (!(input instanceof MemorySegmentAccessInput msInput)) {
return Optional.empty();
@ -56,7 +56,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
}
Lucene99MemorySegmentByteVectorScorerSupplier(
MemorySegmentAccessInput input, RandomAccessVectorValues values) {
MemorySegmentAccessInput input, KnnVectorValues values) {
this.input = input;
this.values = values;
this.vectorByteSize = values.getVectorByteLength();
@ -103,7 +103,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
static final class CosineSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
CosineSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
CosineSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
super(input, values);
}
@ -128,7 +128,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
static final class DotProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
DotProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
DotProductSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
super(input, values);
}
@ -155,7 +155,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
static final class EuclideanSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
EuclideanSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
EuclideanSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
super(input, values);
}
@ -181,7 +181,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
static final class MaxInnerProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
MaxInnerProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
MaxInnerProductSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
super(input, values);
}

View File

@ -19,11 +19,12 @@ package org.apache.lucene.internal.vectorization;
import java.io.IOException;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer {
@ -38,15 +39,15 @@ public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues)
throws IOException {
VectorSimilarityFunction similarityType, KnnVectorValues vectorValues) throws IOException {
// a quantized values here is a wrapping or delegation issue
assert !(vectorValues instanceof RandomAccessQuantizedByteVectorValues);
assert !(vectorValues instanceof QuantizedByteVectorValues);
// currently only supports binary vectors
if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) {
if (vectorValues instanceof HasIndexSlice byteVectorValues
&& byteVectorValues.getSlice() != null) {
var scorer =
Lucene99MemorySegmentByteVectorScorerSupplier.create(
similarityType, vectorValues.getSlice(), vectorValues);
similarityType, byteVectorValues.getSlice(), vectorValues);
if (scorer.isPresent()) {
return scorer.get();
}
@ -56,9 +57,7 @@ public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityType,
RandomAccessVectorValues vectorValues,
float[] target)
VectorSimilarityFunction similarityType, KnnVectorValues vectorValues, float[] target)
throws IOException {
// currently only supports binary vectors, so always delegate
return delegate.getRandomVectorScorer(similarityType, vectorValues, target);
@ -66,17 +65,16 @@ public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityType,
RandomAccessVectorValues vectorValues,
byte[] queryVector)
VectorSimilarityFunction similarityType, KnnVectorValues vectorValues, byte[] queryVector)
throws IOException {
checkDimensions(queryVector.length, vectorValues.dimension());
// a quantized values here is a wrapping or delegation issue
assert !(vectorValues instanceof RandomAccessQuantizedByteVectorValues);
if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) {
assert !(vectorValues instanceof QuantizedByteVectorValues);
if (vectorValues instanceof HasIndexSlice byteVectorValues
&& byteVectorValues.getSlice() != null) {
var scorer =
Lucene99MemorySegmentByteVectorScorer.create(
similarityType, vectorValues.getSlice(), vectorValues, queryVector);
similarityType, byteVectorValues.getSlice(), vectorValues, queryVector);
if (scorer.isPresent()) {
return scorer.get();
}
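The scorer now probes for an off-heap backing input through HasIndexSlice rather than a dedicated random-access type hierarchy. A hedged sketch of that check in isolation:

import java.util.Optional;
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.store.IndexInput;

final class SliceCheckSketch {
  // Off-heap vector values advertise their backing IndexInput through HasIndexSlice;
  // heap-resident implementations either do not implement it or return null.
  static Optional<IndexInput> backingSlice(KnnVectorValues values) {
    if (values instanceof HasIndexSlice sliced && sliced.getSlice() != null) {
      return Optional.of(sliced.getSlice());
    }
    return Optional.empty();
  }
}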

View File

@ -35,6 +35,8 @@ import java.util.concurrent.atomic.AtomicInteger;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
@ -42,7 +44,6 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.hamcrest.Matcher;
import org.hamcrest.MatcherAssert;
@ -174,13 +175,13 @@ public class TestFlatVectorScorer extends LuceneTestCase {
}
}
RandomAccessVectorValues byteVectorValues(
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
ByteVectorValues byteVectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim)
throws IOException {
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
dims, size, in.slice("byteValues", 0, in.length()), dims, flatVectorsScorer, sim);
}
RandomAccessVectorValues floatVectorValues(
FloatVectorValues floatVectorValues(
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
dims,

View File

@ -17,7 +17,6 @@
package org.apache.lucene.codecs.lucene99;
import static java.lang.String.format;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.oneOf;
@ -312,14 +311,13 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
assertNotNull(hnswReader.getQuantizationState("f"));
QuantizedByteVectorValues quantizedByteVectorValues =
hnswReader.getQuantizedVectorValues("f");
int docId = -1;
while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) {
byte[] vector = quantizedByteVectorValues.vectorValue();
float offset = quantizedByteVectorValues.getScoreCorrectionConstant();
for (int ord = 0; ord < quantizedByteVectorValues.size(); ord++) {
byte[] vector = quantizedByteVectorValues.vectorValue(ord);
float offset = quantizedByteVectorValues.getScoreCorrectionConstant(ord);
for (int i = 0; i < dim; i++) {
assertEquals(vector[i], expectedVectors[docId][i]);
assertEquals(vector[i], expectedVectors[ord][i]);
}
assertEquals(offset, expectedCorrections[docId], 0.00001f);
assertEquals(offset, expectedCorrections[ord], 0.00001f);
}
} else {
fail("reader is not Lucene99HnswVectorsReader");

View File

@ -46,7 +46,7 @@ import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
@ -100,8 +100,8 @@ public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
Lucene99ScalarQuantizedVectorScorer scorer =
new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
RandomAccessQuantizedByteVectorValues values =
new RandomAccessQuantizedByteVectorValues() {
QuantizedByteVectorValues values =
new QuantizedByteVectorValues() {
@Override
public int dimension() {
return 32;
@ -128,7 +128,7 @@ public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
}
@Override
public RandomAccessQuantizedByteVectorValues copy() throws IOException {
public QuantizedByteVectorValues copy() throws IOException {
return this;
}

View File

@ -37,6 +37,7 @@ import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -173,9 +174,10 @@ public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsForm
QuantizedByteVectorValues quantizedByteVectorValues =
quantizedReader.getQuantizedVectorValues("f");
int docId = -1;
while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) {
byte[] vector = quantizedByteVectorValues.vectorValue();
float offset = quantizedByteVectorValues.getScoreCorrectionConstant();
KnnVectorValues.DocIndexIterator iter = quantizedByteVectorValues.iterator();
for (docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) {
byte[] vector = quantizedByteVectorValues.vectorValue(iter.index());
float offset = quantizedByteVectorValues.getScoreCorrectionConstant(iter.index());
for (int i = 0; i < dim; i++) {
assertEquals(vector[i], expectedVectors[docId][i]);
}

View File

@ -18,6 +18,7 @@ package org.apache.lucene.document;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import org.apache.lucene.codecs.Codec;
@ -27,6 +28,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
@ -713,17 +715,21 @@ public class TestField extends LuceneTestCase {
try (IndexReader r = DirectoryReader.open(w)) {
ByteVectorValues binary = r.leaves().get(0).reader().getByteVectorValues("binary");
assertEquals(1, binary.size());
assertNotEquals(NO_MORE_DOCS, binary.nextDoc());
assertNotNull(binary.vectorValue());
assertArrayEquals(b, binary.vectorValue());
assertEquals(NO_MORE_DOCS, binary.nextDoc());
KnnVectorValues.DocIndexIterator iterator = binary.iterator();
assertNotEquals(NO_MORE_DOCS, iterator.nextDoc());
assertNotNull(binary.vectorValue(0));
assertArrayEquals(b, binary.vectorValue(0));
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
expectThrows(IOException.class, () -> binary.vectorValue(1));
FloatVectorValues floatValues = r.leaves().get(0).reader().getFloatVectorValues("float");
assertEquals(1, floatValues.size());
assertNotEquals(NO_MORE_DOCS, floatValues.nextDoc());
assertEquals(vector.length, floatValues.vectorValue().length);
assertEquals(vector[0], floatValues.vectorValue()[0], 0);
assertEquals(NO_MORE_DOCS, floatValues.nextDoc());
KnnVectorValues.DocIndexIterator iterator1 = floatValues.iterator();
assertNotEquals(NO_MORE_DOCS, iterator1.nextDoc());
assertEquals(vector.length, floatValues.vectorValue(0).length);
assertEquals(vector[0], floatValues.vectorValue(0)[0], 0);
assertEquals(NO_MORE_DOCS, iterator1.nextDoc());
expectThrows(IOException.class, () -> floatValues.vectorValue(1));
}
}
}

View File

@ -459,8 +459,8 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
expectThrows(
ExitingReaderException.class,
() -> {
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
scanAndRetrieve(leaf, iter);
KnnVectorValues values = leaf.getFloatVectorValues("vector");
scanAndRetrieve(leaf, values);
});
expectThrows(
@ -473,8 +473,8 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
leaf.getLiveDocs(),
Integer.MAX_VALUE));
} else {
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
scanAndRetrieve(leaf, iter);
KnnVectorValues values = leaf.getFloatVectorValues("vector");
scanAndRetrieve(leaf, values);
leaf.searchNearestVectors(
"vector",
@ -534,8 +534,8 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
expectThrows(
ExitingReaderException.class,
() -> {
DocIdSetIterator iter = leaf.getByteVectorValues("vector");
scanAndRetrieve(leaf, iter);
KnnVectorValues values = leaf.getByteVectorValues("vector");
scanAndRetrieve(leaf, values);
});
expectThrows(
@ -549,8 +549,8 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
Integer.MAX_VALUE));
} else {
DocIdSetIterator iter = leaf.getByteVectorValues("vector");
scanAndRetrieve(leaf, iter);
KnnVectorValues values = leaf.getByteVectorValues("vector");
scanAndRetrieve(leaf, values);
leaf.searchNearestVectors(
"vector",
@ -564,20 +564,24 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
directory.close();
}
private static void scanAndRetrieve(LeafReader leaf, DocIdSetIterator iter) throws IOException {
private static void scanAndRetrieve(LeafReader leaf, KnnVectorValues values) throws IOException {
KnnVectorValues.DocIndexIterator iter = values.iterator();
for (iter.nextDoc();
iter.docID() != DocIdSetIterator.NO_MORE_DOCS && iter.docID() < leaf.maxDoc(); ) {
final int nextDocId = iter.docID() + 1;
int docId = iter.docID();
if (docId >= leaf.maxDoc()) {
break;
}
final int nextDocId = docId + 1;
if (random().nextBoolean() && nextDocId < leaf.maxDoc()) {
iter.advance(nextDocId);
} else {
iter.nextDoc();
}
if (random().nextBoolean()
&& iter.docID() != DocIdSetIterator.NO_MORE_DOCS
&& iter instanceof FloatVectorValues) {
((FloatVectorValues) iter).vectorValue();
&& values instanceof FloatVectorValues) {
((FloatVectorValues) values).vectorValue(iter.index());
}
}
}
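The reworked scanAndRetrieve mixes advance() and nextDoc() on the DocIndexIterator and reads vectors through the owning values object. A standalone sketch of the same walk, with the Random and maxDoc supplied as parameters rather than test fixtures:

import java.io.IOException;
import java.util.Random;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;

final class ScanSketch {
  // Walks the iterator with a mix of advance() and nextDoc(), occasionally dereferencing the
  // current ordinal, much like the helper above.
  static void scan(FloatVectorValues values, int maxDoc, Random random) throws IOException {
    KnnVectorValues.DocIndexIterator it = values.iterator();
    it.nextDoc();
    while (it.docID() != DocIdSetIterator.NO_MORE_DOCS && it.docID() < maxDoc) {
      if (random.nextBoolean()) {
        values.vectorValue(it.index()); // random access at the iterator's current ordinal
      }
      int next = it.docID() + 1;
      if (random.nextBoolean() && next < maxDoc) {
        it.advance(next);
      } else {
        it.nextDoc();
      }
    }
  }
}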

View File

@ -413,11 +413,13 @@ public class TestKnnGraph extends LuceneTestCase {
// stored vector values are the same as original
int nextDocWithVectors = 0;
StoredFields storedFields = reader.storedFields();
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
for (int i = 0; i < reader.maxDoc(); i++) {
nextDocWithVectors = vectorValues.advance(i);
nextDocWithVectors = iterator.advance(i);
while (i < nextDocWithVectors && i < reader.maxDoc()) {
int id = Integer.parseInt(storedFields.document(i).get("id"));
assertNull("document " + id + " has no vector, but was expected to", values[id]);
assertNull(
"document " + id + ", expected to have no vector, does have one", values[id]);
++i;
}
if (nextDocWithVectors == NO_MORE_DOCS) {
@ -425,7 +427,7 @@ public class TestKnnGraph extends LuceneTestCase {
}
int id = Integer.parseInt(storedFields.document(i).get("id"));
// documents with KnnGraphValues have the expected vectors
float[] scratch = vectorValues.vectorValue();
float[] scratch = vectorValues.vectorValue(iterator.index());
assertArrayEquals(
"vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
values[id],
@ -435,9 +437,9 @@ public class TestKnnGraph extends LuceneTestCase {
}
// if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc()
if (nextDocWithVectors != NO_MORE_DOCS) {
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
} else {
assertEquals(NO_MORE_DOCS, vectorValues.docID());
assertEquals(NO_MORE_DOCS, iterator.docID());
}
// assert graph values:

View File

@ -242,6 +242,7 @@ public class TestSortingCodecReader extends LuceneTestCase {
NumericDocValues ids = leaf.getNumericDocValues("id");
long prevValue = -1;
boolean usingAltIds = false;
KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator();
for (int i = 0; i < actualNumDocs; i++) {
int idNext = ids.nextDoc();
if (idNext == DocIdSetIterator.NO_MORE_DOCS) {
@ -262,7 +263,7 @@ public class TestSortingCodecReader extends LuceneTestCase {
assertTrue(sorted_numeric_dv.advanceExact(idNext));
assertTrue(sorted_set_dv.advanceExact(idNext));
assertTrue(binary_sorted_dv.advanceExact(idNext));
assertEquals(idNext, vectorValues.advance(idNext));
assertEquals(idNext, valuesIterator.advance(idNext));
assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue());
assertEquals(
new BytesRef(ids.longValue() + ""),
@ -274,7 +275,7 @@ public class TestSortingCodecReader extends LuceneTestCase {
assertEquals(1, sorted_numeric_dv.docValueCount());
assertEquals(ids.longValue(), sorted_numeric_dv.nextValue());
float[] vectorValue = vectorValues.vectorValue();
float[] vectorValue = vectorValues.vectorValue(valuesIterator.index());
assertEquals(1, vectorValue.length);
assertEquals((float) ids.longValue(), vectorValue[0], 0.001f);

View File

@ -39,6 +39,7 @@ import java.util.stream.IntStream;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
@ -47,7 +48,6 @@ import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.junit.BeforeClass;
@ -329,8 +329,8 @@ public class TestVectorScorer extends LuceneTestCase {
}
}
RandomAccessVectorValues vectorValues(
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
KnnVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim)
throws IOException {
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
dims, size, in.slice("byteValues", 0, in.length()), dims, MEMSEG_SCORER, sim);
}

View File

@ -38,6 +38,7 @@ import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
@ -740,7 +741,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
LeafReader leafReader = getOnlyLeafReader(reader);
FieldInfo fi = leafReader.getFieldInfos().fieldInfo("field");
assertNotNull(fi);
DocIdSetIterator vectorValues;
KnnVectorValues vectorValues;
switch (fi.getVectorEncoding()) {
case BYTE:
vectorValues = leafReader.getByteVectorValues("field");
@ -752,7 +753,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
throw new AssertionError();
}
assertNotNull(vectorValues);
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc());
}
}
}

View File

@ -113,7 +113,7 @@ public class TestTimeLimitingBulkScorer extends LuceneTestCase {
private static QueryTimeout countingQueryTimeout(int timeallowed) {
return new QueryTimeout() {
static int counter = 0;
int counter = 0;
@Override
public boolean shouldExit() {

View File

@ -1,89 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import org.apache.lucene.util.BytesRef;
abstract class AbstractMockVectorValues<T> implements RandomAccessVectorValues {
protected final int dimension;
protected final T[] denseValues;
protected final T[] values;
protected final int numVectors;
protected final BytesRef binaryValue;
protected int pos = -1;
AbstractMockVectorValues(T[] values, int dimension, T[] denseValues, int numVectors) {
this.dimension = dimension;
this.values = values;
this.denseValues = denseValues;
// used by tests that build a graph from bytes rather than floats
binaryValue = new BytesRef(dimension);
binaryValue.length = dimension;
this.numVectors = numVectors;
}
@Override
public int size() {
return numVectors;
}
@Override
public int dimension() {
return dimension;
}
public T vectorValue(int targetOrd) {
return denseValues[targetOrd];
}
@Override
public abstract AbstractMockVectorValues<T> copy();
public abstract T vectorValue() throws IOException;
private boolean seek(int target) {
if (target >= 0 && target < values.length && values[target] != null) {
pos = target;
return true;
} else {
return false;
}
}
public int docID() {
return pos;
}
public int nextDoc() {
return advance(pos + 1);
}
public int advance(int target) {
while (++pos < values.length) {
if (seek(pos)) {
return pos;
}
}
return NO_MORE_DOCS;
}
}

View File

@ -56,6 +56,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
@ -97,33 +98,28 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
abstract T randomVector(int dim);
abstract AbstractMockVectorValues<T> vectorValues(int size, int dimension);
abstract KnnVectorValues vectorValues(int size, int dimension);
abstract AbstractMockVectorValues<T> vectorValues(float[][] values);
abstract KnnVectorValues vectorValues(float[][] values);
abstract AbstractMockVectorValues<T> vectorValues(LeafReader reader, String fieldName)
throws IOException;
abstract KnnVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException;
abstract AbstractMockVectorValues<T> vectorValues(
int size,
int dimension,
AbstractMockVectorValues<T> pregeneratedVectorValues,
int pregeneratedOffset);
abstract KnnVectorValues vectorValues(
int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset);
abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction);
abstract RandomAccessVectorValues circularVectorValues(int nDoc);
abstract KnnVectorValues circularVectorValues(int nDoc);
abstract T getTargetVector();
protected RandomVectorScorerSupplier buildScorerSupplier(RandomAccessVectorValues vectors)
protected RandomVectorScorerSupplier buildScorerSupplier(KnnVectorValues vectors)
throws IOException {
return flatVectorScorer.getRandomVectorScorerSupplier(similarityFunction, vectors);
}
protected RandomVectorScorer buildScorer(RandomAccessVectorValues vectors, T query)
throws IOException {
RandomAccessVectorValues vectorsCopy = vectors.copy();
protected RandomVectorScorer buildScorer(KnnVectorValues vectors, T query) throws IOException {
KnnVectorValues vectorsCopy = vectors.copy();
return switch (getVectorEncoding()) {
case BYTE ->
flatVectorScorer.getRandomVectorScorer(similarityFunction, vectorsCopy, (byte[]) query);
@ -134,6 +130,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
// Tests writing segments of various sizes and merging to ensure there are no errors
// in the HNSW graph merging logic.
@SuppressWarnings("unchecked")
public void testRandomReadWriteAndMerge() throws IOException {
int dim = random().nextInt(100) + 1;
int[] segmentSizes =
@ -148,7 +145,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int M = random().nextInt(4) + 2;
int beamWidth = random().nextInt(10) + 5;
long seed = random().nextLong();
AbstractMockVectorValues<T> vectors = vectorValues(numVectors, dim);
KnnVectorValues vectors = vectorValues(numVectors, dim);
HnswGraphBuilder.randSeed = seed;
try (Directory dir = newDirectory()) {
@ -173,7 +170,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
for (int i = 0; i < segmentSizes.length; i++) {
int size = segmentSizes[i];
while (vectors.nextDoc() < size) {
for (int ord = 0; ord < size; ord++) {
if (isSparse[i] && random().nextBoolean()) {
int d = random().nextInt(10) + 1;
for (int j = 0; j < d; j++) {
@ -182,8 +179,24 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
}
Document doc = new Document();
doc.add(knnVectorField("field", vectors.vectorValue(), similarityFunction));
doc.add(new StringField("id", Integer.toString(vectors.docID()), Field.Store.NO));
switch (vectors.getEncoding()) {
case BYTE -> {
doc.add(
knnVectorField(
"field",
(T) ((ByteVectorValues) vectors).vectorValue(ord),
similarityFunction));
}
case FLOAT32 -> {
doc.add(
knnVectorField(
"field",
(T) ((FloatVectorValues) vectors).vectorValue(ord),
similarityFunction));
}
}
;
doc.add(new StringField("id", Integer.toString(vectors.ordToDoc(ord)), Field.Store.NO));
iw.addDocument(doc);
}
iw.commit();
@ -199,13 +212,26 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
try (IndexReader reader = DirectoryReader.open(dir)) {
for (LeafReaderContext ctx : reader.leaves()) {
AbstractMockVectorValues<T> values = vectorValues(ctx.reader(), "field");
KnnVectorValues values = vectorValues(ctx.reader(), "field");
assertEquals(dim, values.dimension());
}
}
}
}
@SuppressWarnings("unchecked")
private T vectorValue(KnnVectorValues vectors, int ord) throws IOException {
switch (vectors.getEncoding()) {
case BYTE -> {
return (T) ((ByteVectorValues) vectors).vectorValue(ord);
}
case FLOAT32 -> {
return (T) ((FloatVectorValues) vectors).vectorValue(ord);
}
}
throw new AssertionError("unknown encoding " + vectors.getEncoding());
}
// test writing out and reading in a graph gives the expected graph
public void testReadWrite() throws IOException {
int dim = random().nextInt(100) + 1;
@ -213,8 +239,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int M = random().nextInt(4) + 2;
int beamWidth = random().nextInt(10) + 5;
long seed = random().nextLong();
AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
AbstractMockVectorValues<T> v2 = vectors.copy(), v3 = vectors.copy();
KnnVectorValues vectors = vectorValues(nDoc, dim);
KnnVectorValues v2 = vectors.copy(), v3 = vectors.copy();
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors.size());
@ -242,15 +268,16 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
});
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
while (v2.nextDoc() != NO_MORE_DOCS) {
while (indexedDoc < v2.docID()) {
KnnVectorValues.DocIndexIterator it2 = v2.iterator();
while (it2.nextDoc() != NO_MORE_DOCS) {
while (indexedDoc < it2.docID()) {
// increment docId in the index by adding empty documents
iw.addDocument(new Document());
indexedDoc++;
}
Document doc = new Document();
doc.add(knnVectorField("field", v2.vectorValue(), similarityFunction));
doc.add(new StoredField("id", v2.docID()));
doc.add(knnVectorField("field", vectorValue(v2, it2.index()), similarityFunction));
doc.add(new StoredField("id", it2.docID()));
iw.addDocument(doc);
nVec++;
indexedDoc++;
@ -258,7 +285,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
try (IndexReader reader = DirectoryReader.open(dir)) {
for (LeafReaderContext ctx : reader.leaves()) {
AbstractMockVectorValues<T> values = vectorValues(ctx.reader(), "field");
KnnVectorValues values = vectorValues(ctx.reader(), "field");
assertEquals(dim, values.dimension());
assertEquals(nVec, values.size());
assertEquals(indexedDoc, ctx.reader().maxDoc());
@ -280,7 +307,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
int dim = random().nextInt(10) + 3;
int nDoc = random().nextInt(200) + 100;
AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
KnnVectorValues vectors = vectorValues(nDoc, dim);
int M = random().nextInt(10) + 5;
int beamWidth = random().nextInt(10) + 10;
@ -323,15 +350,15 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int indexedDoc = 0;
try (IndexWriter iw = new IndexWriter(dir, iwc);
IndexWriter iw2 = new IndexWriter(dir2, iwc2)) {
while (vectors.nextDoc() != NO_MORE_DOCS) {
while (indexedDoc < vectors.docID()) {
for (int ord = 0; ord < vectors.size(); ord++) {
while (indexedDoc < vectors.ordToDoc(ord)) {
// increment docId in the index by adding empty documents
iw.addDocument(new Document());
indexedDoc++;
}
Document doc = new Document();
doc.add(knnVectorField("vector", vectors.vectorValue(), similarityFunction));
doc.add(new StoredField("id", vectors.docID()));
doc.add(knnVectorField("vector", vectorValue(vectors, ord), similarityFunction));
doc.add(new StoredField("id", vectors.ordToDoc(ord)));
doc.add(new NumericDocValuesField("sortkey", random().nextLong()));
iw.addDocument(doc);
iw2.addDocument(doc);
@ -461,7 +488,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
public void testAknnDiverse() throws IOException {
int nDoc = 100;
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
KnnVectorValues vectors = circularVectorValues(nDoc);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
@ -493,7 +520,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
@SuppressWarnings("unchecked")
public void testSearchWithAcceptOrds() throws IOException {
int nDoc = 100;
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
KnnVectorValues vectors = circularVectorValues(nDoc);
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
@ -518,7 +545,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
@SuppressWarnings("unchecked")
public void testSearchWithSelectiveAcceptOrds() throws IOException {
int nDoc = 100;
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
KnnVectorValues vectors = circularVectorValues(nDoc);
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
@ -552,13 +579,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int dim = atLeast(10);
long seed = random().nextLong();
AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim);
KnnVectorValues initializerVectors = vectorValues(initializerSize, dim);
RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors);
HnswGraphBuilder initializerBuilder =
HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed);
OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
AbstractMockVectorValues<T> finalVectorValues =
KnnVectorValues finalVectorValues =
vectorValues(totalSize, dim, initializerVectors, docIdOffset);
int[] initializerOrdMap =
createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
@ -598,13 +625,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int dim = atLeast(10);
long seed = random().nextLong();
AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim);
KnnVectorValues initializerVectors = vectorValues(initializerSize, dim);
RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors);
HnswGraphBuilder initializerBuilder =
HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed);
OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
AbstractMockVectorValues<T> finalVectorValues =
KnnVectorValues finalVectorValues =
vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset);
int[] initializerOrdMap =
createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
@ -688,19 +715,17 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
private int[] createOffsetOrdinalMap(
int docIdSize, AbstractMockVectorValues<T> totalVectorValues, int docIdOffset) {
int docIdSize, KnnVectorValues totalVectorValues, int docIdOffset) throws IOException {
// Compute the offset for the ordinal map to be the number of non-null vectors in the total
// vector values
// before the docIdOffset
// vector values before the docIdOffset
int ordinalOffset = 0;
while (totalVectorValues.nextDoc() < docIdOffset) {
KnnVectorValues.DocIndexIterator it = totalVectorValues.iterator();
while (it.nextDoc() < docIdOffset) {
ordinalOffset++;
}
int[] offsetOrdinalMap = new int[docIdSize];
for (int curr = 0;
totalVectorValues.docID() < docIdOffset + docIdSize;
totalVectorValues.nextDoc()) {
for (int curr = 0; it.docID() < docIdOffset + docIdSize; it.nextDoc()) {
offsetOrdinalMap[curr] = ordinalOffset + curr++;
}
@ -711,7 +736,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
public void testVisitedLimit() throws IOException {
int nDoc = 500;
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
KnnVectorValues vectors = circularVectorValues(nDoc);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
@ -746,7 +771,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int M = randomIntBetween(4, 96);
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
RandomAccessVectorValues vectors = vectorValues(size, dim);
KnnVectorValues vectors = vectorValues(size, dim);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder =
@ -771,7 +796,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
unitVector2d(0.77),
unitVector2d(0.6)
};
AbstractMockVectorValues<T> vectors = vectorValues(values);
KnnVectorValues vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 2, 10, random().nextInt());
@ -825,7 +850,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
{10, 0, 0},
{0, 4, 0}
};
AbstractMockVectorValues<T> vectors = vectorValues(values);
KnnVectorValues vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt());
@ -855,7 +880,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
{0, 0, 20},
{0, 9, 0}
};
AbstractMockVectorValues<T> vectors = vectorValues(values);
KnnVectorValues vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt());
@ -891,7 +916,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
public void testRandom() throws IOException {
int size = atLeast(100);
int dim = atLeast(10);
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
KnnVectorValues vectors = vectorValues(size, dim);
int topK = 5;
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong());
@ -908,15 +933,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
TopDocs topDocs = actual.topDocs();
NeighborQueue expected = new NeighborQueue(topK, false);
for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
if (vectorValue(vectors, j) != null && (acceptOrds == null || acceptOrds.get(j))) {
if (getVectorEncoding() == VectorEncoding.BYTE) {
assert query instanceof byte[];
expected.add(
j, similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(j)));
j, similarityFunction.compare((byte[]) query, (byte[]) vectorValue(vectors, j)));
} else {
assert query instanceof float[];
expected.add(
j, similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(j)));
j, similarityFunction.compare((float[]) query, (float[]) vectorValue(vectors, j)));
}
if (expected.size() > topK) {
expected.pop();
@ -940,7 +963,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
throws IOException, ExecutionException, InterruptedException, TimeoutException {
int size = atLeast(100);
int dim = atLeast(10);
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
KnnVectorValues vectors = vectorValues(size, dim);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
@ -1004,7 +1027,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
public void testConcurrentMergeBuilder() throws IOException {
int size = atLeast(1000);
int dim = atLeast(10);
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
KnnVectorValues vectors = vectorValues(size, dim);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge"));
TaskExecutor taskExecutor = new TaskExecutor(exec);
@ -1033,7 +1056,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
// Search for a large number of results
int topK = size - 1;
AbstractMockVectorValues<T> docVectors = vectorValues(size, dim);
KnnVectorValues docVectors = vectorValues(size, dim);
HnswGraph graph =
HnswGraphBuilder.create(buildScorerSupplier(docVectors), 10, 30, random().nextLong())
.build(size);
@ -1047,8 +1070,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
};
AbstractMockVectorValues<T> queryVectors = vectorValues(1, dim);
RandomVectorScorer queryScorer = buildScorer(docVectors, queryVectors.vectorValue(0));
KnnVectorValues queryVectors = vectorValues(1, dim);
RandomVectorScorer queryScorer = buildScorer(docVectors, vectorValue(queryVectors, 0));
KnnCollector collector = new TopKnnCollector(topK, Integer.MAX_VALUE);
HnswGraphSearcher.search(queryScorer, collector, singleLevelGraph, null);
@ -1076,8 +1099,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
/** Returns vectors evenly distributed around the upper unit semicircle. */
static class CircularFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues.Floats {
static class CircularFloatVectorValues extends FloatVectorValues {
private final int size;
private final float[] value;
@ -1103,22 +1125,18 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return size;
}
@Override
public float[] vectorValue() {
return vectorValue(doc);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() {
return advance(doc + 1);
}
@Override
public int advance(int target) {
if (target >= 0 && target < size) {
doc = target;
@ -1140,8 +1158,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
}
/** Returns vectors evenly distributed around the upper unit semicircle. */
static class CircularByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues.Bytes {
static class CircularByteVectorValues extends ByteVectorValues {
private final int size;
private final float[] value;
private final byte[] bValue;
@ -1169,22 +1186,18 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return size;
}
@Override
public byte[] vectorValue() {
return vectorValue(doc);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() {
return advance(doc + 1);
}
@Override
public int advance(int target) {
if (target >= 0 && target < size) {
doc = target;
@ -1227,27 +1240,25 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return neighbors;
}
void assertVectorsEqual(AbstractMockVectorValues<T> u, AbstractMockVectorValues<T> v)
throws IOException {
void assertVectorsEqual(KnnVectorValues u, KnnVectorValues v) throws IOException {
int uDoc, vDoc;
while (true) {
uDoc = u.nextDoc();
vDoc = v.nextDoc();
assertEquals(u.size(), v.size());
for (int ord = 0; ord < u.size(); ord++) {
uDoc = u.ordToDoc(ord);
vDoc = v.ordToDoc(ord);
assertEquals(uDoc, vDoc);
if (uDoc == NO_MORE_DOCS) {
break;
}
assertNotEquals(NO_MORE_DOCS, uDoc);
switch (getVectorEncoding()) {
case BYTE ->
assertArrayEquals(
"vectors do not match for doc=" + uDoc,
(byte[]) u.vectorValue(),
(byte[]) v.vectorValue());
(byte[]) vectorValue(u, ord),
(byte[]) vectorValue(v, ord));
case FLOAT32 ->
assertArrayEquals(
"vectors do not match for doc=" + uDoc,
(float[]) u.vectorValue(),
(float[]) v.vectorValue(),
(float[]) vectorValue(u, ord),
(float[]) vectorValue(v, ord),
1e-4f);
default ->
throw new IllegalArgumentException("unknown vector encoding: " + getVectorEncoding());

View File

@ -17,11 +17,17 @@
package org.apache.lucene.util.hnsw;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
class MockByteVectorValues extends AbstractMockVectorValues<byte[]>
implements RandomAccessVectorValues.Bytes {
class MockByteVectorValues extends ByteVectorValues {
private final int dimension;
private final byte[][] denseValues;
protected final byte[][] values;
private final int numVectors;
private final BytesRef binaryValue;
private final byte[] scratch;
static MockByteVectorValues fromValues(byte[][] values) {
@ -43,10 +49,26 @@ class MockByteVectorValues extends AbstractMockVectorValues<byte[]>
}
MockByteVectorValues(byte[][] values, int dimension, byte[][] denseValues, int numVectors) {
super(values, dimension, denseValues, numVectors);
this.dimension = dimension;
this.values = values;
this.denseValues = denseValues;
this.numVectors = numVectors;
// used by tests that build a graph from bytes rather than floats
binaryValue = new BytesRef(dimension);
binaryValue.length = dimension;
scratch = new byte[dimension];
}
@Override
public int size() {
return values.length;
}
@Override
public int dimension() {
return dimension;
}
@Override
public MockByteVectorValues copy() {
return new MockByteVectorValues(
@ -55,20 +77,20 @@ class MockByteVectorValues extends AbstractMockVectorValues<byte[]>
@Override
public byte[] vectorValue(int ord) {
return values[ord];
}
@Override
public byte[] vectorValue() {
if (LuceneTestCase.random().nextBoolean()) {
return values[pos];
return values[ord];
} else {
// Sometimes use the same scratch array repeatedly, mimicking what the codec will do.
// This should help us catch cases of aliasing where the same ByteVectorValues source is used
// twice in a single computation.
System.arraycopy(values[pos], 0, scratch, 0, dimension);
System.arraycopy(values[ord], 0, scratch, 0, dimension);
return scratch;
}
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
}
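The scratch-array trick above (MockVectorValues below plays the same game with float[]) exists to surface aliasing bugs. A minimal sketch of the hazard and the defensive copy, using only APIs that appear in this patch; the helper name and ordinals are illustrative, not part of the change:

// Illustrative sketch, not part of the patch: vectorValue(ord) may hand back a shared
// scratch buffer, so a caller that needs two vectors at once must copy the first one.
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.util.ArrayUtil;

class ScratchAliasingSketch {
  static boolean vectorsEqual(ByteVectorValues values, int ord1, int ord2) throws IOException {
    // copy before the second fetch; otherwise 'first' could be overwritten in place
    byte[] first = ArrayUtil.copyOfSubArray(values.vectorValue(ord1), 0, values.dimension());
    byte[] second = values.vectorValue(ord2);
    return Arrays.equals(first, second);
  }
}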

View File

@ -17,11 +17,15 @@
package org.apache.lucene.util.hnsw;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
class MockVectorValues extends AbstractMockVectorValues<float[]>
implements RandomAccessVectorValues.Floats {
class MockVectorValues extends FloatVectorValues {
private final int dimension;
private final float[][] denseValues;
protected final float[][] values;
private final int numVectors;
private final float[] scratch;
static MockVectorValues fromValues(float[][] values) {
@ -43,10 +47,23 @@ class MockVectorValues extends AbstractMockVectorValues<float[]>
}
MockVectorValues(float[][] values, int dimension, float[][] denseValues, int numVectors) {
super(values, dimension, denseValues, numVectors);
this.dimension = dimension;
this.values = values;
this.denseValues = denseValues;
this.numVectors = numVectors;
this.scratch = new float[dimension];
}
@Override
public int size() {
return values.length;
}
@Override
public int dimension() {
return dimension;
}
@Override
public MockVectorValues copy() {
return new MockVectorValues(
@ -54,20 +71,20 @@ class MockVectorValues extends AbstractMockVectorValues<float[]>
}
@Override
public float[] vectorValue() {
public float[] vectorValue(int ord) {
if (LuceneTestCase.random().nextBoolean()) {
return values[pos];
return values[ord];
} else {
// Sometimes use the same scratch array repeatedly, mimicking what the codec will do.
// This should help us catch cases of aliasing where the same vector values source is used
// twice in a single computation.
System.arraycopy(values[pos], 0, scratch, 0, dimension);
System.arraycopy(values[ord], 0, scratch, 0, dimension);
return scratch;
}
}
@Override
public float[] vectorValue(int targetOrd) {
return denseValues[targetOrd];
public DocIndexIterator iterator() {
return createDenseIterator();
}
}

View File

@ -17,13 +17,12 @@
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import com.carrotsearch.randomizedtesting.RandomizedTest;
import java.io.IOException;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -56,7 +55,7 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {
}
@Override
AbstractMockVectorValues<byte[]> vectorValues(int size, int dimension) {
MockByteVectorValues vectorValues(int size, int dimension) {
return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, random()));
}
@ -65,7 +64,7 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {
}
@Override
AbstractMockVectorValues<byte[]> vectorValues(float[][] values) {
MockByteVectorValues vectorValues(float[][] values) {
byte[][] bValues = new byte[values.length][];
// The case when all floats fit within a byte already.
boolean scaleSimple = fitsInByte(values[0][0]);
@ -86,42 +85,35 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {
}
@Override
AbstractMockVectorValues<byte[]> vectorValues(
int size,
int dimension,
AbstractMockVectorValues<byte[]> pregeneratedVectorValues,
int pregeneratedOffset) {
MockByteVectorValues vectorValues(
int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset) {
MockByteVectorValues pvv = (MockByteVectorValues) pregeneratedVectorValues;
byte[][] vectors = new byte[size][];
byte[][] randomVectors =
createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, random());
byte[][] randomVectors = createRandomByteVectors(size - pvv.values.length, dimension, random());
for (int i = 0; i < pregeneratedOffset; i++) {
vectors[i] = randomVectors[i];
}
int currentDoc;
while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) {
vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc];
for (int currentOrd = 0; currentOrd < pvv.size(); currentOrd++) {
vectors[pregeneratedOffset + currentOrd] = pvv.values[currentOrd];
}
for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length;
i < vectors.length;
i++) {
vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length];
for (int i = pregeneratedOffset + pvv.values.length; i < vectors.length; i++) {
vectors[i] = randomVectors[i - pvv.values.length];
}
return MockByteVectorValues.fromValues(vectors);
}
@Override
AbstractMockVectorValues<byte[]> vectorValues(LeafReader reader, String fieldName)
throws IOException {
MockByteVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException {
ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);
byte[][] vectors = new byte[reader.maxDoc()][];
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
vectors[vectorValues.docID()] =
ArrayUtil.copyOfSubArray(
vectorValues.vectorValue(), 0, vectorValues.vectorValue().length);
for (int i = 0; i < vectorValues.size(); i++) {
vectors[vectorValues.ordToDoc(i)] =
ArrayUtil.copyOfSubArray(vectorValues.vectorValue(i), 0, vectorValues.dimension());
}
return MockByteVectorValues.fromValues(vectors);
}
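The extraction loop above is the general recipe for pulling every vector out of a reader with the ordinal API; a standalone sketch (field name, reader variable, and the null check are assumptions, not part of the patch) might look like:

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.util.ArrayUtil;

class CopyVectorsByOrdinalSketch {
  // Returns a docid-indexed copy of all float vectors for 'field', or null if the field has none.
  static float[][] copyAll(LeafReader reader, String field) throws IOException {
    FloatVectorValues vv = reader.getFloatVectorValues(field);
    if (vv == null) {
      return null;
    }
    float[][] byDoc = new float[reader.maxDoc()][];
    for (int ord = 0; ord < vv.size(); ord++) {
      // ordToDoc maps the dense ordinal back to its docid; dimension() bounds the copy
      byDoc[vv.ordToDoc(ord)] = ArrayUtil.copyOfSubArray(vv.vectorValue(ord), 0, vv.dimension());
    }
    return byDoc;
  }
}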

View File

@ -17,13 +17,12 @@
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import com.carrotsearch.randomizedtesting.RandomizedTest;
import java.io.IOException;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -60,52 +59,44 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
}
@Override
AbstractMockVectorValues<float[]> vectorValues(int size, int dimension) {
MockVectorValues vectorValues(int size, int dimension) {
return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, random()));
}
@Override
AbstractMockVectorValues<float[]> vectorValues(float[][] values) {
MockVectorValues vectorValues(float[][] values) {
return MockVectorValues.fromValues(values);
}
@Override
AbstractMockVectorValues<float[]> vectorValues(LeafReader reader, String fieldName)
throws IOException {
MockVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException {
FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldName);
float[][] vectors = new float[reader.maxDoc()][];
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
vectors[vectorValues.docID()] =
ArrayUtil.copyOfSubArray(
vectorValues.vectorValue(), 0, vectorValues.vectorValue().length);
for (int i = 0; i < vectorValues.size(); i++) {
vectors[vectorValues.ordToDoc(i)] =
ArrayUtil.copyOfSubArray(vectorValues.vectorValue(i), 0, vectorValues.dimension());
}
return MockVectorValues.fromValues(vectors);
}
@Override
AbstractMockVectorValues<float[]> vectorValues(
int size,
int dimension,
AbstractMockVectorValues<float[]> pregeneratedVectorValues,
int pregeneratedOffset) {
MockVectorValues vectorValues(
int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset) {
MockVectorValues pvv = (MockVectorValues) pregeneratedVectorValues;
float[][] vectors = new float[size][];
float[][] randomVectors =
createRandomFloatVectors(
size - pregeneratedVectorValues.values.length, dimension, random());
createRandomFloatVectors(size - pvv.values.length, dimension, random());
for (int i = 0; i < pregeneratedOffset; i++) {
vectors[i] = randomVectors[i];
}
int currentDoc;
while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) {
vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc];
for (int currentOrd = 0; currentOrd < pvv.size(); currentOrd++) {
vectors[pregeneratedOffset + currentOrd] = pvv.values[currentOrd];
}
for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length;
i < vectors.length;
i++) {
vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length];
for (int i = pregeneratedOffset + pvv.values.length; i < vectors.length; i++) {
vectors[i] = randomVectors[i - pvv.values.length];
}
return MockVectorValues.fromValues(vectors);
@ -129,7 +120,7 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
public void testSearchWithSkewedAcceptOrds() throws IOException {
int nDoc = 1000;
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
RandomAccessVectorValues.Floats vectors = circularVectorValues(nDoc);
FloatVectorValues vectors = circularVectorValues(nDoc);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());

View File

@ -138,12 +138,6 @@ public class TestHnswUtil extends LuceneTestCase {
}
}
MockGraph graph = new MockGraph(nodes);
/**/
if (i == 2) {
System.out.println("iter " + i);
System.out.print(graph.toString());
}
/**/
assertEquals(isRooted(nodes), HnswUtil.isRooted(graph));
}
}

View File

@ -59,8 +59,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN);
@ -92,8 +91,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
FloatVectorValues floatVectorValues = fromFloatsNormalized(floats, null);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectorsNormalized(
@ -129,8 +127,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT);
@ -162,8 +159,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectors(
@ -242,11 +238,8 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
float[][] floats, Set<Integer> deletedVectors) {
return new TestScalarQuantizer.TestSimpleFloatVectorValues(floats, deletedVectors) {
@Override
public float[] vectorValue() throws IOException {
if (curDoc == -1 || curDoc >= floats.length) {
throw new IOException("Current doc not set or too many iterations");
}
float[] v = ArrayUtil.copyArray(floats[curDoc]);
public float[] vectorValue(int ord) throws IOException {
float[] v = ArrayUtil.copyArray(floats[ordToDoc[ord]]);
VectorUtil.l2normalize(v);
return v;
}

View File

@ -272,14 +272,27 @@ public class TestScalarQuantizer extends LuceneTestCase {
static class TestSimpleFloatVectorValues extends FloatVectorValues {
protected final float[][] floats;
protected final Set<Integer> deletedVectors;
protected final int[] ordToDoc;
protected final int numLiveVectors;
protected int curDoc = -1;
TestSimpleFloatVectorValues(float[][] values, Set<Integer> deletedVectors) {
this.floats = values;
this.deletedVectors = deletedVectors;
this.numLiveVectors =
numLiveVectors =
deletedVectors == null ? values.length : values.length - deletedVectors.size();
ordToDoc = new int[numLiveVectors];
if (deletedVectors == null) {
for (int i = 0; i < numLiveVectors; i++) {
ordToDoc[i] = i;
}
} else {
int ord = 0;
for (int doc = 0; doc < values.length; doc++) {
if (!deletedVectors.contains(doc)) {
ordToDoc[ord++] = doc;
}
}
}
}
@Override
@ -293,40 +306,64 @@ public class TestScalarQuantizer extends LuceneTestCase {
}
@Override
public float[] vectorValue() throws IOException {
if (curDoc == -1 || curDoc >= floats.length) {
throw new IOException("Current doc not set or too many iterations");
}
return floats[curDoc];
public float[] vectorValue(int ord) throws IOException {
return floats[ordToDoc(ord)];
}
@Override
public int docID() {
if (curDoc >= floats.length) {
return NO_MORE_DOCS;
public int ordToDoc(int ord) {
return ordToDoc[ord];
}
return curDoc;
@Override
public DocIndexIterator iterator() {
return new DocIndexIterator() {
int ord = -1;
int doc = -1;
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
while (++curDoc < floats.length) {
if (deletedVectors == null || !deletedVectors.contains(curDoc)) {
return curDoc;
while (doc < floats.length - 1) {
++doc;
if (deletedVectors == null || !deletedVectors.contains(doc)) {
++ord;
return doc;
}
}
return docID();
return doc = NO_MORE_DOCS;
}
@Override
public int index() {
return ord;
}
@Override
public long cost() {
return numLiveVectors;
}
@Override
public int advance(int target) throws IOException {
curDoc = target - 1;
return nextDoc();
throw new UnsupportedOperationException();
}
};
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
@Override
public TestSimpleFloatVectorValues copy() {
return this;
}
}
}
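For reference, consuming a sparse source like TestSimpleFloatVectorValues through its DocIndexIterator follows the same shape used throughout this patch; a sketch with an illustrative callback (the consumer interface is not part of the change):

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;

class SparseIterationSketch {
  interface VectorConsumer {
    void accept(int doc, float[] vector) throws IOException;
  }

  static void forEachVector(FloatVectorValues values, VectorConsumer consumer) throws IOException {
    KnnVectorValues.DocIndexIterator it = values.iterator();
    for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
      // index() is the dense ordinal of the current doc; vectors are always fetched by ordinal now
      consumer.accept(doc, values.vectorValue(it.index()));
    }
  }
}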

View File

@ -2285,7 +2285,6 @@ public class MemoryIndex {
private static final class MemoryFloatVectorValues extends FloatVectorValues {
private final Info info;
private int currentDoc = -1;
MemoryFloatVectorValues(Info info) {
this.info = info;
@ -2302,14 +2301,19 @@ public class MemoryIndex {
}
@Override
public float[] vectorValue() {
if (currentDoc == 0) {
public float[] vectorValue(int ord) {
if (ord == 0) {
return info.floatVectorValues[0];
} else {
return null;
}
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
@Override
public VectorScorer scorer(float[] query) {
if (query.length != info.fieldInfo.getVectorDimension()) {
@ -2320,50 +2324,31 @@ public class MemoryIndex {
+ info.fieldInfo.getVectorDimension());
}
MemoryFloatVectorValues vectorValues = new MemoryFloatVectorValues(info);
DocIndexIterator iterator = vectorValues.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
assert iterator.docID() == 0;
return info.fieldInfo
.getVectorSimilarityFunction()
.compare(vectorValues.vectorValue(), query);
.compare(vectorValues.vectorValue(0), query);
}
@Override
public DocIdSetIterator iterator() {
return vectorValues;
return iterator;
}
};
}
@Override
public int docID() {
return currentDoc;
}
@Override
public int nextDoc() {
int doc = ++currentDoc;
if (doc == 0) {
return doc;
} else {
return NO_MORE_DOCS;
}
}
@Override
public int advance(int target) {
if (target == 0) {
currentDoc = target;
return target;
} else {
return NO_MORE_DOCS;
}
public MemoryFloatVectorValues copy() {
return this;
}
}
private static final class MemoryByteVectorValues extends ByteVectorValues {
private final Info info;
private int currentDoc = -1;
MemoryByteVectorValues(Info info) {
this.info = info;
@ -2380,14 +2365,19 @@ public class MemoryIndex {
}
@Override
public byte[] vectorValue() {
if (currentDoc == 0) {
public byte[] vectorValue(int ord) {
if (ord == 0) {
return info.byteVectorValues[0];
} else {
return null;
}
}
@Override
public DocIndexIterator iterator() {
return createDenseIterator();
}
@Override
public VectorScorer scorer(byte[] query) {
if (query.length != info.fieldInfo.getVectorDimension()) {
@ -2398,44 +2388,26 @@ public class MemoryIndex {
+ info.fieldInfo.getVectorDimension());
}
MemoryByteVectorValues vectorValues = new MemoryByteVectorValues(info);
DocIndexIterator iterator = vectorValues.iterator();
return new VectorScorer() {
@Override
public float score() {
assert iterator.docID() == 0;
return info.fieldInfo
.getVectorSimilarityFunction()
.compare(vectorValues.vectorValue(), query);
.compare(vectorValues.vectorValue(0), query);
}
@Override
public DocIdSetIterator iterator() {
return vectorValues;
return iterator;
}
};
}
@Override
public int docID() {
return currentDoc;
}
@Override
public int nextDoc() {
int doc = ++currentDoc;
if (doc == 0) {
return doc;
} else {
return NO_MORE_DOCS;
}
}
@Override
public int advance(int target) {
if (target == 0) {
currentDoc = target;
return target;
} else {
return NO_MORE_DOCS;
}
public MemoryByteVectorValues copy() {
return this;
}
}
}
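A sketch of how a caller drives the scorers returned above, now that the doc iterator is obtained separately from the values object (the null check and the max-score aggregation are assumptions for illustration, not something MemoryIndex requires):

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;

class ScorerIterationSketch {
  static float maxScore(FloatVectorValues values, float[] query) throws IOException {
    VectorScorer scorer = values.scorer(query);
    if (scorer == null) {
      return Float.NaN; // no vectors to score
    }
    float max = Float.NEGATIVE_INFINITY;
    DocIdSetIterator it = scorer.iterator();
    while (it.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
      max = Math.max(max, scorer.score()); // scores the doc the iterator is positioned on
    }
    return max;
  }
}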

View File

@ -63,6 +63,7 @@ import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
@ -851,9 +852,10 @@ public class TestMemoryIndex extends LuceneTestCase {
.reader()
.getFloatVectorValues(fieldName);
assertNotNull(fvv);
assertEquals(0, fvv.nextDoc());
assertArrayEquals(expected, fvv.vectorValue(), 1e-6f);
assertEquals(DocIdSetIterator.NO_MORE_DOCS, fvv.nextDoc());
KnnVectorValues.DocIndexIterator iterator = fvv.iterator();
assertEquals(0, iterator.nextDoc());
assertArrayEquals(expected, fvv.vectorValue(0), 1e-6f);
assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc());
}
private static void assertFloatVectorScore(
@ -868,7 +870,7 @@ public class TestMemoryIndex extends LuceneTestCase {
.getFloatVectorValues(fieldName);
assertNotNull(fvv);
if (random().nextBoolean()) {
fvv.nextDoc();
fvv.iterator().nextDoc();
}
VectorScorer scorer = fvv.scorer(queryVector);
assertEquals(0, scorer.iterator().nextDoc());
@ -886,9 +888,10 @@ public class TestMemoryIndex extends LuceneTestCase {
.reader()
.getByteVectorValues(fieldName);
assertNotNull(bvv);
assertEquals(0, bvv.nextDoc());
assertArrayEquals(expected, bvv.vectorValue());
assertEquals(DocIdSetIterator.NO_MORE_DOCS, bvv.nextDoc());
KnnVectorValues.DocIndexIterator iterator = bvv.iterator();
assertEquals(0, iterator.nextDoc());
assertArrayEquals(expected, bvv.vectorValue(0));
assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc());
}
private static void assertByteVectorScore(
@ -903,7 +906,7 @@ public class TestMemoryIndex extends LuceneTestCase {
.getByteVectorValues(fieldName);
assertNotNull(bvv);
if (random().nextBoolean()) {
bvv.nextDoc();
bvv.iterator().nextDoc();
}
VectorScorer scorer = bvv.scorer(queryVector);
assertEquals(0, scorer.iterator().nextDoc());

View File

@ -20,6 +20,7 @@ import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
@ -63,11 +64,12 @@ public class ByteKnnVectorFieldSource extends ValueSource {
}
return new VectorFieldFunction(this) {
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
@Override
public byte[] byteVectorVal(int doc) throws IOException {
if (exists(doc)) {
return vectorValues.vectorValue();
return vectorValues.vectorValue(iterator.index());
} else {
return null;
}
@ -75,7 +77,7 @@ public class ByteKnnVectorFieldSource extends ValueSource {
@Override
protected DocIdSetIterator getVectorIterator() {
return vectorValues;
return iterator;
}
};
}

View File

@ -20,6 +20,7 @@ import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
@ -62,11 +63,12 @@ public class FloatKnnVectorFieldSource extends ValueSource {
}
return new VectorFieldFunction(this) {
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
@Override
public float[] floatVectorVal(int doc) throws IOException {
if (exists(doc)) {
return vectorValues.vectorValue();
return vectorValues.vectorValue(iterator.index());
} else {
return null;
}
@ -74,7 +76,7 @@ public class FloatKnnVectorFieldSource extends ValueSource {
@Override
protected DocIdSetIterator getVectorIterator() {
return vectorValues;
return iterator;
}
};
}
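The advance-then-index pattern these value sources rely on can be written down independently; this sketch assumes the iterator supports advance() and that 'targetDoc' has not already been passed, and the helper itself is illustrative rather than part of the patch:

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;

class AdvanceAndReadSketch {
  // Returns the vector for 'targetDoc', or null when that document has no vector.
  static float[] vectorForDoc(
      FloatVectorValues values, KnnVectorValues.DocIndexIterator it, int targetDoc)
      throws IOException {
    int doc = it.docID() < targetDoc ? it.advance(targetDoc) : it.docID();
    if (doc != targetDoc) {
      return null;
    }
    // once positioned, index() yields the ordinal to pass to vectorValue
    return values.vectorValue(it.index());
  }
}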

View File

@ -25,11 +25,11 @@ import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/** KMeans clustering algorithm for vectors */
public class KMeans {
@ -38,7 +38,7 @@ public class KMeans {
public static final int DEFAULT_ITRS = 10;
public static final int DEFAULT_SAMPLE_SIZE = 100_000;
private final RandomAccessVectorValues.Floats vectors;
private final FloatVectorValues vectors;
private final int numVectors;
private final int numCentroids;
private final Random random;
@ -57,9 +57,7 @@ public class KMeans {
* @throws IOException if there is an error accessing vectors
*/
public static Results cluster(
RandomAccessVectorValues.Floats vectors,
VectorSimilarityFunction similarityFunction,
int numClusters)
FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int numClusters)
throws IOException {
return cluster(
vectors,
@ -93,7 +91,7 @@ public class KMeans {
* @throws IOException if there is an error accessing vectors
*/
public static Results cluster(
RandomAccessVectorValues.Floats vectors,
FloatVectorValues vectors,
int numClusters,
boolean assignCentroidsToVectors,
long seed,
@ -124,7 +122,7 @@ public class KMeans {
if (numClusters == 1) {
centroids = new float[1][vectors.dimension()];
} else {
RandomAccessVectorValues.Floats sampleVectors =
FloatVectorValues sampleVectors =
vectors.size() <= sampleSize ? vectors : createSampleReader(vectors, sampleSize, seed);
KMeans kmeans =
new KMeans(sampleVectors, numClusters, random, initializationMethod, restarts, iters);
@ -142,7 +140,7 @@ public class KMeans {
}
private KMeans(
RandomAccessVectorValues.Floats vectors,
FloatVectorValues vectors,
int numCentroids,
Random random,
KmeansInitializationMethod initializationMethod,
@ -276,7 +274,7 @@ public class KMeans {
* @throws IOException if there is an error accessing vector values
*/
private static double runKMeansStep(
RandomAccessVectorValues.Floats vectors,
FloatVectorValues vectors,
float[][] centroids,
short[] docCentroids,
boolean useKahanSummation,
@ -348,9 +346,7 @@ public class KMeans {
* descending distance to the current centroid set
*/
static void assignCentroids(
RandomAccessVectorValues.Floats vectors,
float[][] centroids,
List<Integer> unassignedCentroidsIdxs)
FloatVectorValues vectors, float[][] centroids, List<Integer> unassignedCentroidsIdxs)
throws IOException {
int[] assignedCentroidsIdxs = new int[centroids.length - unassignedCentroidsIdxs.size()];
int assignedIndex = 0;
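Putting the new FloatVectorValues-based signatures together, clustering an in-memory list of vectors might look like the sketch below (placed in the same sandbox package as KMeans, as TestKMeans is; the data list, dimension, and cluster count are placeholders):

// package org.apache.lucene.sandbox.codecs.quantization -- same package as KMeans
import java.io.IOException;
import java.util.List;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;

class KMeansUsageSketch {
  static float[][] clusterCentroids(List<float[]> data, int dims, int nClusters) throws IOException {
    // FloatVectorValues.fromFloats wraps an in-memory list as dense vector values
    FloatVectorValues vectors = FloatVectorValues.fromFloats(data, dims);
    KMeans.Results results = KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters);
    return results.centroids();
  }
}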

View File

@ -20,18 +20,18 @@ package org.apache.lucene.sandbox.codecs.quantization;
import java.io.IOException;
import java.util.Random;
import java.util.function.IntUnaryOperator;
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/** A reader of vector values that samples a subset of the vectors. */
public class SampleReader implements RandomAccessVectorValues.Floats {
private final RandomAccessVectorValues.Floats origin;
public class SampleReader extends FloatVectorValues implements HasIndexSlice {
private final FloatVectorValues origin;
private final int sampleSize;
private final IntUnaryOperator sampleFunction;
SampleReader(
RandomAccessVectorValues.Floats origin, int sampleSize, IntUnaryOperator sampleFunction) {
SampleReader(FloatVectorValues origin, int sampleSize, IntUnaryOperator sampleFunction) {
this.origin = origin;
this.sampleSize = sampleSize;
this.sampleFunction = sampleFunction;
@ -48,13 +48,13 @@ public class SampleReader implements RandomAccessVectorValues.Floats {
}
@Override
public Floats copy() throws IOException {
public FloatVectorValues copy() throws IOException {
throw new IllegalStateException("Not supported");
}
@Override
public IndexInput getSlice() {
return origin.getSlice();
return ((HasIndexSlice) origin).getSlice();
}
@Override
@ -77,8 +77,7 @@ public class SampleReader implements RandomAccessVectorValues.Floats {
throw new IllegalStateException("Not supported");
}
public static SampleReader createSampleReader(
RandomAccessVectorValues.Floats origin, int k, long seed) {
public static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) {
int[] samples = reservoirSample(origin.size(), k, seed);
return new SampleReader(origin, samples.length, i -> samples[i]);
}

View File

@ -20,9 +20,9 @@ package org.apache.lucene.sandbox.codecs.quantization;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
public class TestKMeans extends LuceneTestCase {
@ -32,7 +32,7 @@ public class TestKMeans extends LuceneTestCase {
int dims = random().nextInt(2, 20);
int randIdx = random().nextInt(VectorSimilarityFunction.values().length);
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx];
RandomAccessVectorValues.Floats vectors = generateData(nVectors, dims, nClusters);
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
// default case
{
@ -75,7 +75,7 @@ public class TestKMeans extends LuceneTestCase {
// nClusters > nVectors
int nClusters = 20;
int nVectors = 10;
RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters);
FloatVectorValues vectors = generateData(nVectors, 5, nClusters);
KMeans.Results results =
KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters);
// assert that we get 1 centroid, as nClusters will be adjusted
@ -87,7 +87,7 @@ public class TestKMeans extends LuceneTestCase {
int sampleSize = 2;
int nClusters = 2;
int nVectors = 300;
RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters);
FloatVectorValues vectors = generateData(nVectors, 5, nClusters);
KMeans.KmeansInitializationMethod initializationMethod =
KMeans.KmeansInitializationMethod.PLUS_PLUS;
KMeans.Results results =
@ -108,7 +108,7 @@ public class TestKMeans extends LuceneTestCase {
// test unassigned centroids
int nClusters = 4;
int nVectors = 400;
RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters);
FloatVectorValues vectors = generateData(nVectors, 5, nClusters);
KMeans.Results results =
KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters);
float[][] centroids = results.centroids();
@ -118,8 +118,7 @@ public class TestKMeans extends LuceneTestCase {
}
}
private static RandomAccessVectorValues.Floats generateData(
int nSamples, int nDims, int nClusters) {
private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
List<float[]> vectors = new ArrayList<>(nSamples);
float[][] centroids = new float[nClusters][nDims];
// Generate random centroids
@ -137,6 +136,6 @@ public class TestKMeans extends LuceneTestCase {
}
vectors.add(vector);
}
return RandomAccessVectorValues.fromFloats(vectors, nDims);
return FloatVectorValues.fromFloats(vectors, nDims);
}
}

View File

@ -125,7 +125,7 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
&& fi.getVectorEncoding() == VectorEncoding.FLOAT32;
FloatVectorValues floatValues = delegate.getFloatVectorValues(field);
assert floatValues != null;
assert floatValues.docID() == -1;
assert floatValues.iterator().docID() == -1;
assert floatValues.size() >= 0;
assert floatValues.dimension() > 0;
return floatValues;
@ -139,7 +139,7 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
&& fi.getVectorEncoding() == VectorEncoding.BYTE;
ByteVectorValues values = delegate.getByteVectorValues(field);
assert values != null;
assert values.docID() == -1;
assert values.iterator().docID() == -1;
assert values.size() >= 0;
assert values.dimension() > 0;
return values;

View File

@ -55,6 +55,7 @@ import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MergePolicy;
@ -437,9 +438,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
try (IndexReader reader = DirectoryReader.open(w2)) {
LeafReader r = getOnlyLeafReader(reader);
FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
assertEquals(0, vectorValues.nextDoc());
assertEquals(0, vectorValues.vectorValue()[0], 0);
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
assertEquals(0, iterator.nextDoc());
assertEquals(0, vectorValues.vectorValue(0)[0], 0);
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
}
}
}
@ -462,9 +464,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
try (IndexReader reader = DirectoryReader.open(w2)) {
LeafReader r = getOnlyLeafReader(reader);
FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
assertNotEquals(NO_MORE_DOCS, vectorValues.nextDoc());
assertEquals(0, vectorValues.vectorValue()[0], 0);
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
assertNotEquals(NO_MORE_DOCS, iterator.nextDoc());
assertEquals(0, vectorValues.vectorValue(iterator.index())[0], 0);
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
}
}
}
@ -489,12 +492,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
try (IndexReader reader = DirectoryReader.open(w2)) {
LeafReader r = getOnlyLeafReader(reader);
FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
assertEquals(0, vectorValues.nextDoc());
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
assertEquals(0, iterator.nextDoc());
// The merge order is randomized; we might get 0 first, or 1
float value = vectorValues.vectorValue()[0];
float value = vectorValues.vectorValue(0)[0];
assertTrue(value == 0 || value == 1);
assertEquals(1, vectorValues.nextDoc());
value += vectorValues.vectorValue()[0];
assertEquals(1, iterator.nextDoc());
value += vectorValues.vectorValue(1)[0];
assertEquals(1, value, 0);
}
}
@ -879,8 +883,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues(fieldName);
if (byteVectorValues != null) {
docCount += byteVectorValues.size();
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
checksum += byteVectorValues.vectorValue()[0];
KnnVectorValues.DocIndexIterator iterator = byteVectorValues.iterator();
while (iterator.nextDoc() != NO_MORE_DOCS) {
checksum += byteVectorValues.vectorValue(iterator.index())[0];
}
}
}
@ -890,8 +896,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
if (vectorValues != null) {
docCount += vectorValues.size();
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
checksum += vectorValues.vectorValue()[0];
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
while (iterator.nextDoc() != NO_MORE_DOCS) {
checksum += vectorValues.vectorValue(iterator.index())[0];
}
}
}
@ -950,10 +958,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertSame(iterator, scorer.iterator());
assertNotSame(iterator, scorer);
// verify scorer iteration scores are valid & iteration with vectorValues is consistent
while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) {
KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator();
while (iterator.nextDoc() != NO_MORE_DOCS && valuesIterator.nextDoc() != NO_MORE_DOCS) {
float score = scorer.score();
assertTrue(score >= 0f);
assertEquals(iterator.docID(), vectorValues.docID());
assertEquals(iterator.docID(), valuesIterator.docID());
}
// verify that a new scorer can be obtained after iteration
VectorScorer newScorer = vectorValues.scorer(vectorToScore);
@ -1009,10 +1019,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertSame(iterator, scorer.iterator());
assertNotSame(iterator, scorer);
// verify scorer iteration scores are valid & iteration with vectorValues is consistent
while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) {
KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator();
while (iterator.nextDoc() != NO_MORE_DOCS && valuesIterator.nextDoc() != NO_MORE_DOCS) {
float score = scorer.score();
assertTrue(score >= 0f);
assertEquals(iterator.docID(), vectorValues.docID());
assertEquals(iterator.docID(), valuesIterator.docID());
}
// verify that a new scorer can be obtained after iteration
VectorScorer newScorer = vectorValues.scorer(vectorToScore);
@ -1118,12 +1130,16 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
LeafReader r = getOnlyLeafReader(reader);
FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
assertEquals(3, vectorValues.size());
vectorValues.nextDoc();
assertEquals(1, vectorValues.vectorValue()[0], 0);
vectorValues.nextDoc();
assertEquals(1, vectorValues.vectorValue()[0], 0);
vectorValues.nextDoc();
assertEquals(2, vectorValues.vectorValue()[0], 0);
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
iterator.nextDoc();
assertEquals(0, iterator.index());
assertEquals(1, vectorValues.vectorValue(0)[0], 0);
iterator.nextDoc();
assertEquals(1, iterator.index());
assertEquals(1, vectorValues.vectorValue(1)[0], 0);
iterator.nextDoc();
assertEquals(2, iterator.index());
assertEquals(2, vectorValues.vectorValue(2)[0], 0);
}
}
}
@ -1146,13 +1162,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
FloatVectorValues vectorValues = leaf.getFloatVectorValues(fieldName);
assertEquals(2, vectorValues.dimension());
assertEquals(3, vectorValues.size());
assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(-1f, vectorValues.vectorValue()[0], 0);
assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(1, vectorValues.vectorValue()[0], 0);
assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(0, vectorValues.vectorValue()[0], 0);
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
assertEquals("1", storedFields.document(iterator.nextDoc()).get("id"));
assertEquals(-1f, vectorValues.vectorValue(0)[0], 0);
assertEquals("2", storedFields.document(iterator.nextDoc()).get("id"));
assertEquals(1, vectorValues.vectorValue(1)[0], 0);
assertEquals("4", storedFields.document(iterator.nextDoc()).get("id"));
assertEquals(0, vectorValues.vectorValue(2)[0], 0);
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
}
}
}
@ -1175,13 +1192,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
ByteVectorValues vectorValues = leaf.getByteVectorValues(fieldName);
assertEquals(2, vectorValues.dimension());
assertEquals(3, vectorValues.size());
assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(-1, vectorValues.vectorValue()[0], 0);
assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(1, vectorValues.vectorValue()[0], 0);
assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(0, vectorValues.vectorValue()[0], 0);
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
assertEquals("1", storedFields.document(vectorValues.iterator().nextDoc()).get("id"));
assertEquals(-1, vectorValues.vectorValue(0)[0], 0);
assertEquals("2", storedFields.document(vectorValues.iterator().nextDoc()).get("id"));
assertEquals(1, vectorValues.vectorValue(1)[0], 0);
assertEquals("4", storedFields.document(vectorValues.iterator().nextDoc()).get("id"));
assertEquals(0, vectorValues.vectorValue(2)[0], 0);
assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc());
}
}
}
@ -1211,27 +1228,30 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
FloatVectorValues vectorValues = leaf.getFloatVectorValues("field1");
assertEquals(2, vectorValues.dimension());
assertEquals(2, vectorValues.size());
vectorValues.nextDoc();
assertEquals(1f, vectorValues.vectorValue()[0], 0);
vectorValues.nextDoc();
assertEquals(2f, vectorValues.vectorValue()[0], 0);
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
iterator.nextDoc();
assertEquals(1f, vectorValues.vectorValue(0)[0], 0);
iterator.nextDoc();
assertEquals(2f, vectorValues.vectorValue(1)[0], 0);
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
FloatVectorValues vectorValues2 = leaf.getFloatVectorValues("field2");
KnnVectorValues.DocIndexIterator it2 = vectorValues2.iterator();
assertEquals(4, vectorValues2.dimension());
assertEquals(2, vectorValues2.size());
vectorValues2.nextDoc();
assertEquals(2f, vectorValues2.vectorValue()[1], 0);
vectorValues2.nextDoc();
assertEquals(2f, vectorValues2.vectorValue()[1], 0);
assertEquals(NO_MORE_DOCS, vectorValues2.nextDoc());
it2.nextDoc();
assertEquals(2f, vectorValues2.vectorValue(0)[1], 0);
it2.nextDoc();
assertEquals(2f, vectorValues2.vectorValue(1)[1], 0);
assertEquals(NO_MORE_DOCS, it2.nextDoc());
FloatVectorValues vectorValues3 = leaf.getFloatVectorValues("field3");
assertEquals(4, vectorValues3.dimension());
assertEquals(1, vectorValues3.size());
vectorValues3.nextDoc();
assertEquals(1f, vectorValues3.vectorValue()[0], 0.1);
assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc());
KnnVectorValues.DocIndexIterator it3 = vectorValues3.iterator();
it3.nextDoc();
assertEquals(1f, vectorValues3.vectorValue(0)[0], 0.1);
assertEquals(NO_MORE_DOCS, it3.nextDoc());
}
}
}
@ -1295,13 +1315,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
totalSize += vectorValues.size();
StoredFields storedFields = ctx.reader().storedFields();
int docId;
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
float[] v = vectorValues.vectorValue();
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
while ((docId = iterator.nextDoc()) != NO_MORE_DOCS) {
float[] v = vectorValues.vectorValue(iterator.index());
assertEquals(dimension, v.length);
String idString = storedFields.document(docId).getField("id").stringValue();
int id = Integer.parseInt(idString);
if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) {
assertArrayEquals(idString, values[id], v, 0);
assertArrayEquals(idString + " " + docId, values[id], v, 0);
++valueCount;
} else {
++numDeletes;
@ -1375,8 +1397,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
totalSize += vectorValues.size();
StoredFields storedFields = ctx.reader().storedFields();
int docId;
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
byte[] v = vectorValues.vectorValue();
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
while ((docId = iterator.nextDoc()) != NO_MORE_DOCS) {
byte[] v = vectorValues.vectorValue(iterator.index());
assertEquals(dimension, v.length);
String idString = storedFields.document(docId).getField("id").stringValue();
int id = Integer.parseInt(idString);
@ -1495,8 +1519,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
StoredFields storedFields = ctx.reader().storedFields();
int docId;
int numLiveDocsWithVectors = 0;
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
float[] v = vectorValues.vectorValue();
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
while ((docId = iterator.nextDoc()) != NO_MORE_DOCS) {
float[] v = vectorValues.vectorValue(iterator.index());
assertEquals(dimension, v.length);
String idString = storedFields.document(docId).getField("id").stringValue();
int id = Integer.parseInt(idString);
@ -1703,25 +1729,27 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
int[] vectorDocs = new int[vectorValues.size() + 1];
int cur = -1;
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
while (++cur < vectorValues.size() + 1) {
vectorDocs[cur] = vectorValues.nextDoc();
vectorDocs[cur] = iterator.nextDoc();
if (cur != 0) {
assertTrue(vectorDocs[cur] > vectorDocs[cur - 1]);
}
}
vectorValues = r.getFloatVectorValues(fieldName);
DocIdSetIterator iter = vectorValues.iterator();
cur = -1;
for (int i = 0; i < numdocs; i++) {
// randomly advance to i
if (random().nextInt(4) == 3) {
while (vectorDocs[++cur] < i) {}
assertEquals(vectorDocs[cur], vectorValues.advance(i));
assertEquals(vectorDocs[cur], vectorValues.docID());
if (vectorValues.docID() == NO_MORE_DOCS) {
assertEquals(vectorDocs[cur], iter.advance(i));
assertEquals(vectorDocs[cur], iter.docID());
if (iter.docID() == NO_MORE_DOCS) {
break;
}
// make i equal to docid so that it is greater than docId in the next loop iteration
i = vectorValues.docID();
i = iter.docID();
}
}
}
@ -1772,6 +1800,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
double checksum = 0;
int docCount = 0;
long sumDocIds = 0;
long sumOrdToDocIds = 0;
switch (vectorEncoding) {
case BYTE -> {
for (LeafReaderContext ctx : r.leaves()) {
@ -1779,11 +1808,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
if (byteVectorValues != null) {
docCount += byteVectorValues.size();
StoredFields storedFields = ctx.reader().storedFields();
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
checksum += byteVectorValues.vectorValue()[0];
Document doc = storedFields.document(byteVectorValues.docID(), Set.of("id"));
KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator();
for (iter.nextDoc(); iter.docID() != NO_MORE_DOCS; iter.nextDoc()) {
int ord = iter.index();
checksum += byteVectorValues.vectorValue(ord)[0];
Document doc = storedFields.document(iter.docID(), Set.of("id"));
sumDocIds += Integer.parseInt(doc.get("id"));
}
for (int ord = 0; ord < byteVectorValues.size(); ord++) {
Document doc =
storedFields.document(byteVectorValues.ordToDoc(ord), Set.of("id"));
sumOrdToDocIds += Integer.parseInt(doc.get("id"));
}
}
}
}
@ -1793,11 +1829,17 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
if (vectorValues != null) {
docCount += vectorValues.size();
StoredFields storedFields = ctx.reader().storedFields();
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
checksum += vectorValues.vectorValue()[0];
Document doc = storedFields.document(vectorValues.docID(), Set.of("id"));
KnnVectorValues.DocIndexIterator iter = vectorValues.iterator();
for (iter.nextDoc(); iter.docID() != NO_MORE_DOCS; iter.nextDoc()) {
int ord = iter.index();
checksum += vectorValues.vectorValue(ord)[0];
Document doc = storedFields.document(iter.docID(), Set.of("id"));
sumDocIds += Integer.parseInt(doc.get("id"));
}
for (int ord = 0; ord < vectorValues.size(); ord++) {
Document doc = storedFields.document(vectorValues.ordToDoc(ord), Set.of("id"));
sumOrdToDocIds += Integer.parseInt(doc.get("id"));
}
}
}
}
@ -1809,6 +1851,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
vectorEncoding == VectorEncoding.BYTE ? numDocs * 0.2 : 1e-5);
assertEquals(fieldDocCount, docCount);
assertEquals(fieldSumDocIDs, sumDocIds);
assertEquals(fieldSumDocIDs, sumOrdToDocIds);
}
}
}
@ -1839,25 +1882,27 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
ByteVectorValues byteVectors = leafReader.getByteVectorValues("byte");
assertNotNull(byteVectors);
assertEquals(0, byteVectors.nextDoc());
assertArrayEquals(new byte[] {42}, byteVectors.vectorValue());
assertEquals(1, byteVectors.nextDoc());
assertArrayEquals(new byte[] {42}, byteVectors.vectorValue());
assertEquals(DocIdSetIterator.NO_MORE_DOCS, byteVectors.nextDoc());
KnnVectorValues.DocIndexIterator iter = byteVectors.iterator();
assertEquals(0, iter.nextDoc());
assertArrayEquals(new byte[] {42}, byteVectors.vectorValue(0));
assertEquals(1, iter.nextDoc());
assertArrayEquals(new byte[] {42}, byteVectors.vectorValue(1));
assertEquals(DocIdSetIterator.NO_MORE_DOCS, iter.nextDoc());
FloatVectorValues floatVectors = leafReader.getFloatVectorValues("float");
assertNotNull(floatVectors);
assertEquals(0, floatVectors.nextDoc());
float[] vector = floatVectors.vectorValue();
iter = floatVectors.iterator();
assertEquals(0, iter.nextDoc());
float[] vector = floatVectors.vectorValue(0);
assertEquals(2, vector.length);
assertEquals(1f, vector[0], 0f);
assertEquals(2f, vector[1], 0f);
assertEquals(1, floatVectors.nextDoc());
vector = floatVectors.vectorValue();
assertEquals(1, iter.nextDoc());
vector = floatVectors.vectorValue(1);
assertEquals(2, vector.length);
assertEquals(1f, vector[0], 0f);
assertEquals(2f, vector[1], 0f);
assertEquals(DocIdSetIterator.NO_MORE_DOCS, floatVectors.nextDoc());
assertEquals(DocIdSetIterator.NO_MORE_DOCS, iter.nextDoc());
IOUtils.close(reader, w2, dir1, dir2);
}

View File

@ -183,7 +183,7 @@ public class AssertingScorer extends Scorer {
} else {
state = IteratorState.ITERATING;
}
assert in.docID() == advanced;
assert in.docID() == advanced : in.docID() + " != " + advanced + " in " + in;
assert AssertingScorer.this.in.docID() == in.docID();
return doc = advanced;
}